七的博客

Redisson2.0源码分析14-分布式AtomicLong以及CountDownLatch

源码分析

Redisson2.0源码分析14-分布式AtomicLong以及CountDownLatch

1. 分布式 CountDownLatch RCountDownLatch

跟 Redisson V1 版本差不多,不过有些逻辑改用 lua 脚本实现了。基本套路也是每一个 CountDownLatch 有一个唯一标识,通过订阅这个标识对应的 channel,接收开锁/闭锁消息。

1.1 源码分析

package org.redisson;

import java.util.Collections;
import java.util.UUID;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;

import org.redisson.client.BaseRedisPubSubListener;
import org.redisson.client.RedisPubSubListener;
import org.redisson.client.codec.LongCodec;
import org.redisson.client.protocol.RedisCommands;
import org.redisson.client.protocol.pubsub.PubSubType;
import org.redisson.core.RCountDownLatch;

import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.PlatformDependent;


/**
 * Distributed {@link RCountDownLatch} whose count is stored under the Redis key {@link #getName()}.
 * <p>
 * Count changes are announced on the channel returned by {@link #getChannelName()}. Waiting threads
 * subscribe to that channel and block on a local latch. A {@code zeroCountMessage} opens that latch;
 * a {@code newCountMessage} closes it again.
 */
public class RedissonCountDownLatch extends RedissonObject implements RCountDownLatch {

    /** Published when the counter reaches zero; opens local latches. */
    private static final Integer zeroCountMessage = 0;
    /** Published when the counter is set or deleted; closes local latches. */
    private static final Integer newCountMessage = 1;

    /** Subscription entries of this JVM, keyed by {@code id + name} (see {@link #getEntryName()}). */
    private static final ConcurrentMap<String, RedissonCountDownLatchEntry> ENTRIES = PlatformDependent.newConcurrentHashMap();

    private final UUID id;

    protected RedissonCountDownLatch(CommandExecutor commandExecutor, String name, UUID id) {
        super(commandExecutor, name);
        this.id = id;
    }

    /**
     * Subscribes to the latch channel, or joins an existing subscription for the same entry name.
     *
     * @return a future that is completed with {@code true} once the channel subscription is confirmed
     */
    private Future<Boolean> subscribe() {
        Promise<Boolean> promise = aquire();
        if (promise != null) {
            return promise;
        }

        Promise<Boolean> newPromise = newPromise();
        final RedissonCountDownLatchEntry value = new RedissonCountDownLatchEntry(newPromise);
        value.aquire();
        RedissonCountDownLatchEntry oldValue = ENTRIES.putIfAbsent(getEntryName(), value);
        if (oldValue != null) {
            // Another thread registered an entry first: join it, or retry if it was removed meanwhile.
            Promise<Boolean> oldPromise = aquire();
            if (oldPromise == null) {
                return subscribe();
            }
            return oldPromise;
        }

        // React to counter changes published on the channel.
        RedisPubSubListener<Integer> listener = new BaseRedisPubSubListener<Integer>() {

            @Override
            public void onMessage(String channel, Integer message) {
                if (!getChannelName().equals(channel)) {
                    return;
                }
                // Counter reached zero: open the local latch.
                if (message.equals(zeroCountMessage)) {
                    value.getLatch().open();
                }
                // Counter was set again (or deleted): close the local latch.
                if (message.equals(newCountMessage)) {
                    value.getLatch().close();
                }
            }

            @Override
            public boolean onStatus(PubSubType type, String channel) {
                // Complete the subscription future on the first status event for our channel.
                if (channel.equals(getChannelName()) && !value.getPromise().isSuccess()) {
                    value.getPromise().setSuccess(true);
                    return true;
                }
                return false;
            }

        };

        synchronized (ENTRIES) {
            commandExecutor.getConnectionManager().subscribe(listener, getChannelName());
        }
        return newPromise;
    }

    /**
     * Releases this caller's share of the subscription entry. When the entry becomes free and is
     * removed, the channel is unsubscribed, unless a new entry was registered in the meantime.
     */
    private void unsubscribe() {
        while (true) {
            RedissonCountDownLatchEntry entry = ENTRIES.get(getEntryName());
            if (entry == null) {
                return;
            }
            RedissonCountDownLatchEntry newEntry = new RedissonCountDownLatchEntry(entry);
            newEntry.release();
            if (ENTRIES.replace(getEntryName(), entry, newEntry)) {
                if (newEntry.isFree()
                        && ENTRIES.remove(getEntryName(), newEntry)) {
                    synchronized (ENTRIES) {
                        // maybe added during subscription
                        if (!ENTRIES.containsKey(getEntryName())) {
                            commandExecutor.getConnectionManager().unsubscribe(getChannelName());
                        }
                    }
                }
                return;
            }
        }
    }

    /**
     * Acquires a share of the existing subscription entry, if there is one.
     *
     * @return the entry's subscription promise, or {@code null} if no entry is registered
     */
    private Promise<Boolean> aquire() {
        while (true) {
            RedissonCountDownLatchEntry entry = ENTRIES.get(getEntryName());
            if (entry != null) {
                RedissonCountDownLatchEntry newEntry = new RedissonCountDownLatchEntry(entry);
                newEntry.aquire();
                if (ENTRIES.replace(getEntryName(), entry, newEntry)) {
                    return newEntry.getPromise();
                }
            } else {
                return null;
            }
        }
    }

    /**
     * Blocks the current thread until the count reaches zero.
     *
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    public void await() throws InterruptedException {
        Future<Boolean> promise = subscribe();
        try {

            promise.await();

            // While the counter is positive, block until the local latch is opened.
            while (getCountInner() > 0) {
                // waiting for open state
                RedissonCountDownLatchEntry entry = ENTRIES.get(getEntryName());
                if (entry != null) {
                    entry.getLatch().await();
                }
            }
        } finally {
            unsubscribe();
        }
    }

    /**
     * Blocks until the count reaches zero or the timeout elapses.
     * <p>
     * The timeout is a single budget for the whole call. It includes the time spent waiting for the
     * channel subscription to be confirmed.
     *
     * @param time maximum time to wait
     * @param unit unit of {@code time}
     * @return {@code true} if the count reached zero, {@code false} if the timeout elapsed first
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    @Override
    public boolean await(long time, TimeUnit unit) throws InterruptedException {
        // Monotonic deadline: the subscription wait and the latch waits share one time budget.
        final long deadline = System.nanoTime() + unit.toNanos(time);
        Future<Boolean> promise = subscribe();
        try {
            // Wait for the channel subscription to be confirmed.
            if (!promise.await(time, unit)) {
                return false;
            }

            // Re-check the counter after every wake-up, as long as it is still positive.
            while (getCountInner() > 0) {
                long remainingNanos = deadline - System.nanoTime();
                if (remainingNanos <= 0) {
                    return false;
                }

                // Wait for the latch to open, but no longer than the remaining budget.
                RedissonCountDownLatchEntry entry = ENTRIES.get(getEntryName());
                if (entry != null) {
                    entry.getLatch().await(remainingNanos, TimeUnit.NANOSECONDS);
                }
            }

            return true;
        } finally {
            unsubscribe();
        }
    }

    /**
     * Decrements the counter by one.
     * <p>
     * It does nothing when the current count is already zero or below. That pre-check is a separate
     * read, so it is only a shortcut. The script itself deletes the key whenever the result is zero or
     * below, and it publishes {@code zeroCountMessage} only when the result is exactly zero.
     */
    @Override
    public void countDown() {
        if (getCount() <= 0) {
            return;
        }

        commandExecutor.evalWrite(getName(), RedisCommands.EVAL_BOOLEAN,
                "local v = redis.call('decr', KEYS[1]);" +
                        "if v <= 0 then redis.call('del', KEYS[1]) end;" +
                        "if v == 0 then redis.call('publish', ARGV[2], ARGV[1]) end;" +
                        "return true",
                 Collections.<Object>singletonList(getName()), zeroCountMessage, getChannelName());
    }

    /** Key of this latch in {@link #ENTRIES}: the instance id followed by the latch name. */
    private String getEntryName() {
        return id + getName();
    }

    /** Pub/sub channel on which count changes for this latch are published. */
    private String getChannelName() {
        return "redisson_countdownlatch_{" + getName() + "}";
    }

    /**
     * Returns the current count.
     *
     * @return the stored count, or 0 if the key does not exist
     */
    @Override
    public long getCount() {
        return getCountInner();
    }

    /** Reads the counter key with GET; a missing key counts as 0. */
    private long getCountInner() {
        Long val = commandExecutor.read(getName(), LongCodec.INSTANCE, RedisCommands.GET, getName());
        if (val == null) {
            return 0;
        }
        return val;
    }

    /**
     * Sets the count, but only if the latch key does not exist yet. On success it publishes
     * {@code newCountMessage}, which closes local latches.
     *
     * @param count the count to store
     * @return {@code true} if the count was set, {@code false} if the key already existed
     */
    @Override
    public boolean trySetCount(long count) {
        return commandExecutor.evalWrite(getName(), RedisCommands.EVAL_BOOLEAN,
                "if redis.call('exists', KEYS[1]) == 0 then redis.call('set', KEYS[1], ARGV[2]); redis.call('publish', ARGV[3], ARGV[1]); return true else return false end",
                 Collections.<Object>singletonList(getName()), newCountMessage, count, getChannelName());
    }

    /**
     * Deletes the counter key. If a key was actually deleted, publishes {@code newCountMessage},
     * which closes local latches.
     *
     * @return a future completed with {@code true} if the key existed and was deleted
     */
    @Override
    public Future<Boolean> deleteAsync() {
        return commandExecutor.evalWriteAsync(getName(), RedisCommands.EVAL_BOOLEAN,
                "if redis.call('del', KEYS[1]) == 1 then redis.call('publish', ARGV[2], ARGV[1]); return true else return false end",
                 Collections.<Object>singletonList(getName()), newCountMessage, getChannelName());
    }

}

1.2 核心 lua 源码分析

三段核心的 lua 脚本

1.2.1 减计数器值

假设计数器名为 myCountDownLatch,计数器归 0 时发送到 channel 的消息为 0,channel 名为 redisson_countdownlatch_{myCountDownLatch},参数示例:

  • KEYS[1]:myCountDownLatch
  • ARGV[1]:0
  • ARGV[2]:redisson_countdownlatch_{myCountDownLatch}
-- Count down the latch by one.
-- KEYS[1]: counter key; ARGV[1]: "count reached zero" message; ARGV[2]: channel name.
local v = redis.call('decr', KEYS[1]);  --- decrement the counter and keep the new value
if v <= 0 then redis.call('del', KEYS[1]) end; --- at or below zero: delete the key (a missing key reads back as count 0)
if v == 0 then redis.call('publish', ARGV[2], ARGV[1]) end;  --- exactly zero: publish the "count reached zero" message to the channel
return true

1.2.2 设置计数器值

  • KEYS[1]:myCountDownLatch 计数器的Redis key
  • ARGV[1]:1 表示计数器已被设置的消息
  • ARGV[2]:5 新的计数器值
  • ARGV[3]:redisson_countdownlatch_{myCountDownLatch} channel 名称
-- Set the initial count, but only if the latch key does not exist yet.
-- KEYS[1]: counter key; ARGV[1]: "new count" message; ARGV[2]: new count; ARGV[3]: channel name.
if redis.call('exists', KEYS[1]) == 0 then   --- key does NOT exist: set the count and publish the "new count" message
    redis.call('set', KEYS[1], ARGV[2]); 
    redis.call('publish', ARGV[3], ARGV[1]); 
    return true 
else 
    return false   --- key already exists: leave the current count untouched and return false
end

1.2.3 删除计数器值

  • KEYS[1]:myCountDownLatch 计数器的Redis key
  • ARGV[1]:1 表示计数器已被删除的消息
  • ARGV[2]:redisson_countdownlatch_{myCountDownLatch} channel 名称
-- Delete the latch counter.
-- KEYS[1]: counter key; ARGV[1]: message to publish (the same value 1 that trySetCount publishes); ARGV[2]: channel name.
if redis.call('del', KEYS[1]) == 1 then  --- the key was deleted: publish the message to the channel
    redis.call('publish', ARGV[2], ARGV[1]); 
    return true 
else 
    return false   --- the key did not exist: nothing was deleted, return false
end

2. 分布式计数器 RAtomicLong

这个类主要是用于原子性自增值的一个工具类。

2.1 源码分析

package org.redisson;

import java.util.Collections;

import org.redisson.client.codec.StringCodec;
import org.redisson.client.protocol.RedisCommands;
import org.redisson.core.RAtomicLong;

import io.netty.util.concurrent.Future;


public class RedissonAtomicLong extends RedissonExpirable implements RAtomicLong {

    protected RedissonAtomicLong(CommandExecutor commandExecutor, String name) {
        super(commandExecutor, name);
    }

    // 使用 Redis 命令 INCRBY 增加值
    @Override
    public long addAndGet(long delta) {
        return get(addAndGetAsync(delta));
    }

    @Override
    public Future<Long> addAndGetAsync(long delta) {
        return commandExecutor.writeAsync(getName(), StringCodec.INSTANCE, RedisCommands.INCRBY, getName(), delta);
    }

    @Override
    public boolean compareAndSet(long expect, long update) {
        return get(compareAndSetAsync(expect, update));
    }

    // 先比较,符合预期再设置值
    @Override
    public Future<Boolean> compareAndSetAsync(long expect, long update) {
        // 执行 lua 脚本
        // 先检查当前值是否等于 expect value 。
        // 如果相等,则更新值为 update value 。 不相等不处理
        return commandExecutor.evalWriteAsync(getName(), StringCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN,
                "if redis.call('get', KEYS[1]) == ARGV[1] then "
                     + "redis.call('set', KEYS[1], ARGV[2]); "
                     + "return true "
                   + "else "
                     + "return false end",
                Collections.<Object>singletonList(getName()), expect, update);
    }

    // 当前值减1并返回更新后的值
    @Override
    public long decrementAndGet() {
        // 使用Redis命令 DECR 减少值
        return get(decrementAndGetAsync());
    }

    @Override
    public Future<Long> decrementAndGetAsync() {
        return commandExecutor.writeAsync(getName(), StringCodec.INSTANCE, RedisCommands.DECR, getName());
    }

    @Override
    public long get() {
        // 通过addAndGet(0)获取当前值,确保与Redis同步
        return addAndGet(0);
    }

    @Override
    public Future<Long> getAsync() {
        return addAndGetAsync(0);
    }

    // 先获取当前值,然后增加指定的值
    @Override
    public long getAndAdd(long delta) {
        return get(getAndAddAsync(delta));
    }

    @Override
    public Future<Long> getAndAddAsync(long delta) {
        // 使用 Lua 脚本获取当前值并增加指定数值,再返回旧值。

        return commandExecutor.evalWriteAsync(getName(),
                StringCodec.INSTANCE, RedisCommands.EVAL_INTEGER,
                "local v = redis.call('get', KEYS[1]) or 0; "
                + "redis.call('set', KEYS[1], v + ARGV[1]); "
                + "return tonumber(v)",
                Collections.<Object>singletonList(getName()), delta);
    }


    @Override
    public long getAndSet(long newValue) {
        return get(getAndSetAsync(newValue));
    }

    @Override
    public Future<Long> getAndSetAsync(long newValue) {
        // 使用 Lua 脚本获取当前值并设置为 newValue,再返回旧值
        return commandExecutor.evalWriteAsync(getName(),
                StringCodec.INSTANCE, RedisCommands.EVAL_INTEGER,
                "local v = redis.call('get', KEYS[1]) or 0; redis.call('set', KEYS[1], ARGV[1]); return tonumber(v)",
                Collections.<Object>singletonList(getName()), newValue);
    }

    // 将当前值加1并返回更新后的值
    @Override
    public long incrementAndGet() {
        return get(incrementAndGetAsync());
    }

    @Override
    public Future<Long> incrementAndGetAsync() {
        return commandExecutor.writeAsync(getName(), StringCodec.INSTANCE, RedisCommands.INCR, getName());
    }

    // 先获取当前值再将当前值加1
    @Override
    public long getAndIncrement() {
        return getAndAdd(1);
    }

    @Override
    public Future<Long> getAndIncrementAsync() {
        return getAndAddAsync(1);
    }

    // 先获取当前值再将当前值减1
    @Override
    public long getAndDecrement() {
        return getAndAdd(-1);
    }

    @Override
    public Future<Long> getAndDecrementAsync() {
        return getAndAddAsync(-1);
    }

    // 设置为目标数值
    @Override
    public void set(long newValue) {
        get(setAsync(newValue));
    }

    @Override
    public Future<Void> setAsync(long newValue) {
        // 使用 Redis 命令 SET 设置值
        return commandExecutor.writeAsync(getName(), StringCodec.INSTANCE, RedisCommands.SET, getName(), newValue);
    }

    public String toString() {
        return Long.toString(get());
    }

}

2.2 核心 lua 源码分析

2.2.1 比较再设置值 compareAndSet

  • KEYS[1]:myAtomicLong 【Redis key名称】
  • ARGV[1]:100 【期望值】
  • ARGV[2]:200 【更新值】
-- Compare-and-set. KEYS[1]: counter key; ARGV[1]: expected value; ARGV[2]: new value.
-- GET returns the stored string, so this is a string comparison with ARGV[1].
-- For a missing key GET yields Lua false, which never equals ARGV[1], so an expected value of 0 does not match a missing key.
if redis.call('get', KEYS[1]) == ARGV[1] then 
    redis.call('set', KEYS[1], ARGV[2]); 
    return true 
else 
    return false 
end

2.2.2 获取旧的值再累加值 getAndAdd

  • KEYS[1]:myAtomicLong 【Redis key名称】
  • ARGV[1]:5 【要增加的数值】
-- Get-and-add. KEYS[1]: counter key; ARGV[1]: delta to add. Returns the old value (0 if the key was missing).
-- NOTE: v + ARGV[1] is Lua number (double) arithmetic, so the stored sum is exact only up to 2^53.
local v = redis.call('get', KEYS[1]) or 0; 
redis.call('set', KEYS[1], v + ARGV[1]); 
return tonumber(v)

2.2.3 获取旧的值再设置新的值 getAndSet

  • KEYS[1]:myAtomicLong 【Redis key名称】
  • ARGV[1]:300 【新的值】
-- Get-and-set. KEYS[1]: counter key; ARGV[1]: new value.
-- Stores ARGV[1] and returns the previous value (0 if the key was missing).
local v = redis.call('get', KEYS[1]) or 0; 
redis.call('set', KEYS[1], ARGV[1]); 
return tonumber(v)