Redisson2.0源码分析14-分布式AtomicLong以及CountDownLatch
1. 分布式 CountDownLatch RCountDownLatch
跟 Redisson V1 版本差不多,不过有些逻辑改用 lua 脚本实现了。基本套路仍然是每一个 CountDownLatch 有一个唯一标识(即它的名称),通过订阅该标识对应的 channel,接收开锁/闭锁消息。
1.1 源码分析
package org.redisson;
import java.util.Collections;
import java.util.UUID;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import org.redisson.client.BaseRedisPubSubListener;
import org.redisson.client.RedisPubSubListener;
import org.redisson.client.codec.LongCodec;
import org.redisson.client.protocol.RedisCommands;
import org.redisson.client.protocol.pubsub.PubSubType;
import org.redisson.core.RCountDownLatch;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.PlatformDependent;
/**
 * Distributed count-down latch whose count is stored in a single Redis key ({@code getName()}).
 *
 * <p>Waiters subscribe to a per-latch pub/sub channel (see {@code getChannelName()}). A message
 * equal to {@code zeroCountMessage} opens the local latch of every subscriber; a message equal to
 * {@code newCountMessage} (published by {@code trySetCount} and {@code deleteAsync}) closes it again.
 * Subscription bookkeeping lives in the static {@code ENTRIES} map, so it is shared by every
 * instance in this process that uses the same entry name.
 */
public class RedissonCountDownLatch extends RedissonObject implements RCountDownLatch {

    /** Published when the counter reaches exactly zero; subscribers open their local latch. */
    private static final Integer zeroCountMessage = 0;
    /** Published when the counter is initialized or deleted; subscribers close their local latch. */
    private static final Integer newCountMessage = 1;

    /** Reference-counted subscription entries, keyed by {@code getEntryName()}. */
    private static final ConcurrentMap<String, RedissonCountDownLatchEntry> ENTRIES = PlatformDependent.newConcurrentHashMap();

    // Prefix of the entry name; NOTE(review): presumably one id per Redisson client — confirm with the caller.
    private final UUID id;

    protected RedissonCountDownLatch(CommandExecutor commandExecutor, String name, UUID id) {
        super(commandExecutor, name);
        this.id = id;
    }

    /**
     * Subscribes to this latch's channel, or joins an already registered subscription for the same
     * entry name by bumping its reference count.
     *
     * @return a future that is completed once the channel subscription is confirmed
     */
    private Future<Boolean> subscribe() {
        // Fast path: an entry already exists, just take another reference on it.
        Promise<Boolean> promise = aquire();
        if (promise != null) {
            return promise;
        }

        Promise<Boolean> newPromise = newPromise();
        final RedissonCountDownLatchEntry value = new RedissonCountDownLatchEntry(newPromise);
        value.aquire();
        RedissonCountDownLatchEntry oldValue = ENTRIES.putIfAbsent(getEntryName(), value);
        if (oldValue != null) {
            // Another thread registered an entry first: join it, or start over if it was released meanwhile.
            Promise<Boolean> oldPromise = aquire();
            if (oldPromise == null) {
                return subscribe();
            }
            return oldPromise;
        }

        // Open or close the local latch according to counter-change notifications.
        RedisPubSubListener<Integer> listener = new BaseRedisPubSubListener<Integer>() {

            @Override
            public void onMessage(String channel, Integer message) {
                if (!getChannelName().equals(channel)) {
                    return;
                }
                // Counter reached zero: release local waiters.
                if (message.equals(zeroCountMessage)) {
                    value.getLatch().open();
                }
                // Counter was (re)initialized or deleted: block local waiters again.
                if (message.equals(newCountMessage)) {
                    value.getLatch().close();
                }
            }

            @Override
            public boolean onStatus(PubSubType type, String channel) {
                // Complete the subscription future on the first status event for our channel.
                if (channel.equals(getChannelName()) && !value.getPromise().isSuccess()) {
                    value.getPromise().setSuccess(true);
                    return true;
                }
                return false;
            }
        };

        // subscribe() and unsubscribe() both lock ENTRIES around the connection-manager call.
        synchronized (ENTRIES) {
            commandExecutor.getConnectionManager().subscribe(listener, getChannelName());
        }
        return newPromise;
    }

    /**
     * Releases one reference on this latch's entry. When the entry becomes free it is removed and
     * the channel is unsubscribed, unless a new entry was registered in the meantime.
     */
    private void unsubscribe() {
        while (true) {
            RedissonCountDownLatchEntry entry = ENTRIES.get(getEntryName());
            if (entry == null) {
                return;
            }
            // Entries are replaced with a modified copy rather than mutated in place.
            RedissonCountDownLatchEntry newEntry = new RedissonCountDownLatchEntry(entry);
            newEntry.release();
            if (ENTRIES.replace(getEntryName(), entry, newEntry)) {
                if (newEntry.isFree()
                        && ENTRIES.remove(getEntryName(), newEntry)) {
                    synchronized (ENTRIES) {
                        // maybe added during subscription
                        if (!ENTRIES.containsKey(getEntryName())) {
                            commandExecutor.getConnectionManager().unsubscribe(getChannelName());
                        }
                    }
                }
                return;
            }
            // Lost a race with a concurrent update: retry against the current entry.
        }
    }

    /**
     * Takes one more reference on the existing entry, retrying while concurrent updates win the
     * {@code replace}.
     *
     * @return the entry's subscription promise, or {@code null} if no entry is registered
     */
    private Promise<Boolean> aquire() {
        while (true) {
            RedissonCountDownLatchEntry entry = ENTRIES.get(getEntryName());
            if (entry != null) {
                RedissonCountDownLatchEntry newEntry = new RedissonCountDownLatchEntry(entry);
                newEntry.aquire();
                if (ENTRIES.replace(getEntryName(), entry, newEntry)) {
                    return newEntry.getPromise();
                }
            } else {
                return null;
            }
        }
    }

    /**
     * Blocks the current thread until the counter reaches zero.
     *
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    public void await() throws InterruptedException {
        Future<Boolean> promise = subscribe();
        try {
            promise.await();

            // While the counter is positive, wait for the local latch to be opened by a zero message.
            while (getCountInner() > 0) {
                // waiting for open state
                RedissonCountDownLatchEntry entry = ENTRIES.get(getEntryName());
                if (entry != null) {
                    entry.getLatch().await();
                }
            }
        } finally {
            unsubscribe();
        }
    }

    /**
     * Blocks until the counter reaches zero or the timeout elapses.
     *
     * <p>The timeout bounds the whole call, including the time spent establishing the channel
     * subscription. Elapsed time is measured with the monotonic {@link System#nanoTime()} clock.
     *
     * @param time maximum time to wait
     * @param unit unit of {@code time}
     * @return {@code true} if the counter reached zero, {@code false} if the timeout elapsed first
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    @Override
    public boolean await(long time, TimeUnit unit) throws InterruptedException {
        long remainingNanos = unit.toNanos(time);
        long startNanos = System.nanoTime();
        Future<Boolean> promise = subscribe();
        try {
            // Wait for the channel subscription to be confirmed.
            if (!promise.await(time, unit)) {
                return false;
            }
            // Time spent subscribing counts against the caller's timeout.
            remainingNanos -= System.nanoTime() - startNanos;

            // Keep waiting while the counter is still positive.
            while (getCountInner() > 0) {
                if (remainingNanos <= 0) {
                    return false;
                }
                long waitStartNanos = System.nanoTime();
                RedissonCountDownLatchEntry entry = ENTRIES.get(getEntryName());
                if (entry != null) {
                    entry.getLatch().await(remainingNanos, TimeUnit.NANOSECONDS);
                }
                remainingNanos -= System.nanoTime() - waitStartNanos;
            }
            return true;
        } finally {
            unsubscribe();
        }
    }

    /**
     * Decrements the counter. When it reaches exactly zero the key is deleted and
     * {@code zeroCountMessage} is published to wake waiters. Does nothing when the counter is
     * already zero (or the key is absent).
     */
    @Override
    public void countDown() {
        if (getCount() <= 0) {
            return;
        }

        // DECR the key; delete it once it is at or below zero, and publish only on exactly zero
        // so waiters are notified once.
        commandExecutor.evalWrite(getName(), RedisCommands.EVAL_BOOLEAN,
                "local v = redis.call('decr', KEYS[1]);" +
                "if v <= 0 then redis.call('del', KEYS[1]) end;" +
                "if v == 0 then redis.call('publish', ARGV[2], ARGV[1]) end;" +
                "return true",
                Collections.<Object>singletonList(getName()), zeroCountMessage, getChannelName());
    }

    /** Key of this latch in {@code ENTRIES}: the instance id followed by the latch name. */
    private String getEntryName() {
        return id + getName();
    }

    /** Pub/sub channel on which counter changes for this latch are announced. */
    private String getChannelName() {
        return "redisson_countdownlatch_{" + getName() + "}";
    }

    @Override
    public long getCount() {
        return getCountInner();
    }

    /** Reads the counter from Redis; a missing key is reported as zero. */
    private long getCountInner() {
        Long val = commandExecutor.read(getName(), LongCodec.INSTANCE, RedisCommands.GET, getName());
        if (val == null) {
            return 0;
        }
        return val;
    }

    /**
     * Initializes the counter to {@code count} only if the key does not exist yet, and on success
     * publishes {@code newCountMessage} so subscribers close their local latch.
     *
     * @return {@code true} if the counter was set, {@code false} if it already existed
     */
    @Override
    public boolean trySetCount(long count) {
        return commandExecutor.evalWrite(getName(), RedisCommands.EVAL_BOOLEAN,
                "if redis.call('exists', KEYS[1]) == 0 then redis.call('set', KEYS[1], ARGV[2]); redis.call('publish', ARGV[3], ARGV[1]); return true else return false end",
                Collections.<Object>singletonList(getName()), newCountMessage, count, getChannelName());
    }

    /**
     * Deletes the counter key. If a key was actually removed, publishes {@code newCountMessage},
     * which subscribers handle by closing their local latch.
     *
     * @return future of {@code true} if a key was deleted, {@code false} otherwise
     */
    @Override
    public Future<Boolean> deleteAsync() {
        return commandExecutor.evalWriteAsync(getName(), RedisCommands.EVAL_BOOLEAN,
                "if redis.call('del', KEYS[1]) == 1 then redis.call('publish', ARGV[2], ARGV[1]); return true else return false end",
                Collections.<Object>singletonList(getName()), newCountMessage, getChannelName());
    }

}
1.2 核心 lua 源码分析
三段核心的 lua 脚本
1.2.1 减计数器值
假设计数器名为 myCountDownLatch,计数器归 0 时发送到 channel 的消息为 0,channel 名为 redisson_countdownlatch_{myCountDownLatch},参数示例:
- KEYS[1]:myCountDownLatch
- ARGV[1]:0
- ARGV[2]:redisson_countdownlatch_{myCountDownLatch}
-- countDown: decrement the latch counter and notify waiters when it hits zero.
-- KEYS[1]: counter key; ARGV[1]: message published on reaching zero; ARGV[2]: channel name.
local v = redis.call('decr', KEYS[1]); --- decrement by 1 and keep the new value
if v <= 0 then redis.call('del', KEYS[1]) end; --- at or below 0, delete the key so a finished latch leaves no key behind
if v == 0 then redis.call('publish', ARGV[2], ARGV[1]) end; --- only the decrement that lands exactly on 0 publishes the "count reached zero" message
return true
1.2.2 设置计数器值
- KEYS[1]:myCountDownLatch 计数器的Redis key
- ARGV[1]:1 表示计数器已被设置的消息
- ARGV[2]:5 新的计数器值
- ARGV[3]:redisson_countdownlatch_{myCountDownLatch} channel 名称
-- trySetCount: initialize the counter only if it does not exist yet.
if redis.call('exists', KEYS[1]) == 0 then --- key does NOT exist: set the counter to ARGV[2] and publish ARGV[1] (the "new count" message) to the channel
redis.call('set', KEYS[1], ARGV[2]);
redis.call('publish', ARGV[3], ARGV[1]);
return true
else
return false --- key already exists: leave the current count untouched and return false
end
1.2.3 删除计数器值
- KEYS[1]:myCountDownLatch 计数器的Redis key
- ARGV[1]:1 表示计数器已被删除的消息(与设置计数器时发布的 newCountMessage 相同,订阅方收到后会关闭本地 latch)
- ARGV[2]:redisson_countdownlatch_{myCountDownLatch} channel 名称
-- delete: remove the counter key and announce it on the channel.
if redis.call('del', KEYS[1]) == 1 then --- a key was deleted: publish ARGV[1] (the same "new count" message trySetCount publishes) to channel ARGV[2]
redis.call('publish', ARGV[2], ARGV[1]);
return true
else
return false --- nothing was deleted (key absent): publish nothing and return false
end
2. 分布式计数器 RAtomicLong
这个类是基于 Redis 实现的分布式原子长整型计数器,支持原子性的自增、自减、累加以及比较并设置(compareAndSet)等操作。
2.1 源码分析
package org.redisson;
import java.util.Collections;
import org.redisson.client.codec.StringCodec;
import org.redisson.client.protocol.RedisCommands;
import org.redisson.core.RAtomicLong;
import io.netty.util.concurrent.Future;
public class RedissonAtomicLong extends RedissonExpirable implements RAtomicLong {
protected RedissonAtomicLong(CommandExecutor commandExecutor, String name) {
super(commandExecutor, name);
}
// 使用 Redis 命令 INCRBY 增加值
@Override
public long addAndGet(long delta) {
return get(addAndGetAsync(delta));
}
@Override
public Future<Long> addAndGetAsync(long delta) {
return commandExecutor.writeAsync(getName(), StringCodec.INSTANCE, RedisCommands.INCRBY, getName(), delta);
}
@Override
public boolean compareAndSet(long expect, long update) {
return get(compareAndSetAsync(expect, update));
}
// 先比较,符合预期再设置值
@Override
public Future<Boolean> compareAndSetAsync(long expect, long update) {
// 执行 lua 脚本
// 先检查当前值是否等于 expect value 。
// 如果相等,则更新值为 update value 。 不相等不处理
return commandExecutor.evalWriteAsync(getName(), StringCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN,
"if redis.call('get', KEYS[1]) == ARGV[1] then "
+ "redis.call('set', KEYS[1], ARGV[2]); "
+ "return true "
+ "else "
+ "return false end",
Collections.<Object>singletonList(getName()), expect, update);
}
// 当前值减1并返回更新后的值
@Override
public long decrementAndGet() {
// 使用Redis命令 DECR 减少值
return get(decrementAndGetAsync());
}
@Override
public Future<Long> decrementAndGetAsync() {
return commandExecutor.writeAsync(getName(), StringCodec.INSTANCE, RedisCommands.DECR, getName());
}
@Override
public long get() {
// 通过addAndGet(0)获取当前值,确保与Redis同步
return addAndGet(0);
}
@Override
public Future<Long> getAsync() {
return addAndGetAsync(0);
}
// 先获取当前值,然后增加指定的值
@Override
public long getAndAdd(long delta) {
return get(getAndAddAsync(delta));
}
@Override
public Future<Long> getAndAddAsync(long delta) {
// 使用 Lua 脚本获取当前值并增加指定数值,再返回旧值。
return commandExecutor.evalWriteAsync(getName(),
StringCodec.INSTANCE, RedisCommands.EVAL_INTEGER,
"local v = redis.call('get', KEYS[1]) or 0; "
+ "redis.call('set', KEYS[1], v + ARGV[1]); "
+ "return tonumber(v)",
Collections.<Object>singletonList(getName()), delta);
}
@Override
public long getAndSet(long newValue) {
return get(getAndSetAsync(newValue));
}
@Override
public Future<Long> getAndSetAsync(long newValue) {
// 使用 Lua 脚本获取当前值并设置为 newValue,再返回旧值
return commandExecutor.evalWriteAsync(getName(),
StringCodec.INSTANCE, RedisCommands.EVAL_INTEGER,
"local v = redis.call('get', KEYS[1]) or 0; redis.call('set', KEYS[1], ARGV[1]); return tonumber(v)",
Collections.<Object>singletonList(getName()), newValue);
}
// 将当前值加1并返回更新后的值
@Override
public long incrementAndGet() {
return get(incrementAndGetAsync());
}
@Override
public Future<Long> incrementAndGetAsync() {
return commandExecutor.writeAsync(getName(), StringCodec.INSTANCE, RedisCommands.INCR, getName());
}
// 先获取当前值再将当前值加1
@Override
public long getAndIncrement() {
return getAndAdd(1);
}
@Override
public Future<Long> getAndIncrementAsync() {
return getAndAddAsync(1);
}
// 先获取当前值再将当前值减1
@Override
public long getAndDecrement() {
return getAndAdd(-1);
}
@Override
public Future<Long> getAndDecrementAsync() {
return getAndAddAsync(-1);
}
// 设置为目标数值
@Override
public void set(long newValue) {
get(setAsync(newValue));
}
@Override
public Future<Void> setAsync(long newValue) {
// 使用 Redis 命令 SET 设置值
return commandExecutor.writeAsync(getName(), StringCodec.INSTANCE, RedisCommands.SET, getName(), newValue);
}
public String toString() {
return Long.toString(get());
}
}
2.2 核心 lua 源码分析
2.2.1 比较再设置值 compareAndSet
- KEYS[1]:myAtomicLong 【Redis key名称】
- ARGV[1]:100 【期望值】
- ARGV[2]:200 【更新值】
-- compareAndSet: overwrite KEYS[1] with ARGV[2] only if its current value equals ARGV[1] (string comparison).
-- NOTE(review): for a missing key GET yields false in Lua, so an expected value of "0" does not match an
-- absent key, whereas the get() path reports such a key as 0 — confirm whether that mismatch is intended.
if redis.call('get', KEYS[1]) == ARGV[1] then
redis.call('set', KEYS[1], ARGV[2]);
return true
else
return false
end
2.2.2 获取旧的值再累加值 getAndAdd
- KEYS[1]:myAtomicLong 【Redis key名称】
- ARGV[1]:5 【要增加的数值】
-- getAndAdd: add ARGV[1] to KEYS[1] and return the value it held before the addition.
local v = redis.call('get', KEYS[1]) or 0; --- a missing key is treated as 0
redis.call('set', KEYS[1], v + ARGV[1]); --- Lua coerces the string operands to numbers for '+'; NOTE(review): Lua numbers are doubles, so very large values may lose precision — confirm acceptable range
return tonumber(v)
2.2.3 获取旧的值再设置新的值 getAndSet
- KEYS[1]:myAtomicLong 【Redis key名称】
- ARGV[1]:300 【新的值】
-- getAndSet: store ARGV[1] in KEYS[1] and return the previous value (0 if the key was missing).
local v = redis.call('get', KEYS[1]) or 0;
redis.call('set', KEYS[1], ARGV[1]);
return tonumber(v)