一,源码

image.png

对于 CountDownLatch,我们仅仅需要关心两个方法,一个是 countDown() 方法,另一个是 await() 方法。countDown() 方法每次调用都会将 state 减 1,直到state 的值为 0;而 await 是一个阻塞方法,当 state 减为 0 的时候,await 方法才会返回。await 可以被多个线程调用,所有调用了await 方法的线程阻塞在 AQS 的阻塞队列中,等待条件满足(state == 0),将线程从队列中一个个唤醒过来。
cdl.jpg

1.内部类

  1. private static final class Sync extends AbstractQueuedSynchronizer {
  2. private static final long serialVersionUID = 4982264981922014374L;
  3. //CountDownLatch中的计数其实就是AQS的state
  4. Sync(int count) {
  5. setState(count);
  6. }
  7. int getCount() {
  8. return getState();
  9. }
  10. protected int tryAcquireShared(int acquires) {
  11. //如果state =0 返回1 否则返回 0
  12. return (getState() == 0) ? 1 : -1;
  13. }
  14. protected boolean tryReleaseShared(int releases) {
  15. for (;;) {
  16. //获取最新的state
  17. int c = getState();
  18. //如果state==0 返回false
  19. if (c == 0)
  20. return false;
  21. int nextc = c-1;
  22. //如果cas成功,且c=1,返回true
  23. if (compareAndSetState(c, nextc))
  24. return nextc == 0;
  25. }
  26. }
  27. }

2.构造函数

  1. private final Sync sync;
  2. //构造函数
  3. public CountDownLatch(int count) {
  4. //边界值判断
  5. if (count < 0) throw new IllegalArgumentException("count < 0");
  6. //初始化Sync
  7. this.sync = new Sync(count);
  8. }

3.await

使当前线程挂起,直到计数器减为0或者当前线程被中断。

  1. public void await() throws InterruptedException {
  2. //执行aqs.acquireSharedInterruptibly()
  3. sync.acquireSharedInterruptibly(1);
  4. }

4.AQS.acquireSharedInterruptibly

countdownlatch 也用到了 AQS,在 CountDownLatch 内部写了一个 Sync 并且继承了AQS 这个抽象类重写了 AQS中的共享锁方法。首先看到下面这个代码,这块代码主要是 判 断 当 前 线 程 是 否 获 取 到 了 共 享 锁 ; ( 在CountDownLatch 中 , 使 用 的是 共 享 锁 机 制 , 因 为CountDownLatch 并不需要实现互斥的特性)。

  1. public final void acquireSharedInterruptibly(long arg) throws InterruptedException {
  2. //如果当前线程被中断,抛出中断异常
  3. if (Thread.interrupted())
  4. throw new InterruptedException();
  5. //条件成立:说明此时state>0将线程入队,然后等待唤醒
  6. //条件不成立:说明此时state=0,说明此时阻塞已经放开,当前线程不会被阻塞
  7. if (tryAcquireShared(arg) < 0)
  8. //将当前线程加入到共享锁队列
  9. doAcquireSharedInterruptibly(arg);
  10. }

5.tryAcquireShared

判断state状态.

  1. protected int tryAcquireShared(int acquires) {
  2. //如果state =0 返回1 否则返回 0
  3. return (getState() == 0) ? 1 : -1;
  4. }

6.AQS.doAcquireSharedInterruptibly

  1. addWaiter 设置为 shared 模式。

  2. tryAcquire 和 tryAcquireShared 的返回值不同,因此会多出一个判断过程。

  3. 在 判 断 前 驱 节 点 是 头 节 点 后 , 调 用 了setHeadAndPropagate 方法,而不是简单的更新一下头节点。

  1. private void doAcquireSharedInterruptibly(long arg)
  2. throws InterruptedException {
  3. //将当前线程封装成节点入队,共享节点,使用的是state的高16位运算
  4. final Node node = addWaiter(Node.SHARED);
  5. boolean failed = true;
  6. try {
  7. //自旋
  8. for (;;) {
  9. //获取当前节点的前驱
  10. final Node p = node.predecessor();
  11. //如果前驱节点是头节点
  12. if (p == head) {
  13. //当前节点就可以尝试去抢锁
  14. long r = tryAcquireShared(arg);
  15. //此时说明抢到锁了
  16. if (r >= 0) {
  17. //修改头节点的值
  18. setHeadAndPropagate(node, r);
  19. //头节点出队
  20. p.next = null; // help GC
  21. //代表抢锁成功
  22. failed = false;
  23. return;
  24. }
  25. }
  26. //否则的话,线程在这里park,如果线程中断信号=true,就会抛出中断异常
  27. if (shouldParkAfterFailedAcquire(p, node) &&
  28. parkAndCheckInterrupt())
  29. throw new InterruptedException();
  30. }
  31. } finally {
  32. //如果抢锁失败了,就走取消竞争锁的逻辑
  33. if (failed)
  34. cancelAcquire(node);
  35. }
  36. }

假如这个时候有 3 个线程调用了 await 方法,由于这个时候 state 的值还不为 0,所以这三个线程都会加入到 AQS队列中。并且三个线程都处于阻塞状态。
1.jpg

7.countDown

递减锁计数,如果锁计数为0,释放所有阻塞线程。

  1. public void countDown() {
  2. sync.releaseShared(1);
  3. }

8.AQS.releaseShared

由于线程被 await 方法阻塞了,所以只有等到countdown 方法使得 state=0 的时候才会被唤醒。

  1. 只有当 state 减为 0 的时候,tryReleaseShared 才返回 true, 否则只是简单的 state = state - 1。

  2. 如果 state=0, 则调用 doReleaseShared唤醒处于 await 状态下的线程。

  1. public final boolean releaseShared(int arg) {
  2. //执行子类重写的方法,state=0的时候,执行doReleaseShared
  3. //条件成立:说明当前调用latch.countDown()方法的线程,正好是state-1 == 0 的这个线程,需要做触发唤醒await状态的线程。
  4. if (tryReleaseShared(arg)) {
  5. //调用countDown()方法的线程,只有一个线程会进入到这个if块里面,执行下面的方法
  6. doReleaseShared();
  7. return true;
  8. }
  9. return false;
  10. }

9.tryReleaseShared

自旋释放锁,释放完了返回true,否则返回false。

  1. protected boolean tryReleaseShared(int releases) {
  2. for (;;) {
  3. //获取最新的state
  4. int c = getState();
  5. //如果state==0 返回false
  6. if (c == 0)
  7. return false;
  8. int nextc = c-1;
  9. //如果cas成功,且c=1,返回true
  10. if (compareAndSetState(c, nextc))
  11. return nextc == 0;
  12. }
  13. }

10.AQS.doReleaseShared

共享锁的释放和独占锁的释放有一定的差别

前面唤醒锁的逻辑和独占锁是一样,先判断头结点是不是SIGNAL 状态,如果是,则修改为 0,并且唤醒头结点的下一个节点。

PROPAGATE : 标识为 PROPAGATE 状态的节点,是共享锁模式下的节点状态,处于这个状态下的节点,会对线程的唤醒进行传播

  1. private void doReleaseShared() {
  2. //自旋
  3. for (;;) {
  4. //获取头节点的引用
  5. Node h = head;
  6. //如果头节点不为空 && 头节点不等于尾结点
  7. //条件一成立:说明阻塞队列不为空
  8. //什么时候不成立?latch创建出来以后,没有任何线程调用过await方法之前,就有线程调用countDown操作,并且触发了唤醒阻塞节点的逻辑
  9. //条件二成立:说明当前队列除了头节点还有其他节点
  10. //什么时候不成立?
  11. //1.正常唤醒情况:依次获取共享锁,当前线程执行到这里的时候是tail节点
  12. //2.第一个调用await的线程与调用countDown的线程并发了
  13. if (h != null && h != tail) {
  14. int ws = h.waitStatus;
  15. //如果头结点的转态=-1
  16. if (ws == Node.SIGNAL) {
  17. //cas设置头节点的状态失败
  18. if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
  19. continue;
  20. //cas成功,唤醒头节点的下一个节点
  21. unparkSuccessor(h);
  22. }
  23. //cas失败走到这里,
  24. //执行到这里的时候,刚好有一个节点入队,入队会将这个 ws 设置为 -1
  25. else if (ws == 0 &&
  26. !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
  27. continue;
  28. }
  29. /*
  30. 条件成立:
  31. 1.说明刚刚唤醒的后继节点,还没将自己设置为头节点,没执行到呢....
  32. 这个时候,当前线程直接跳出去结束了
  33. 此时并不需要担心 唤醒逻辑 在这里断开 ,因为被唤醒的线程,早晚会执行到doReleaseShared方法
  34. 2.head==null
  35. latch创建出来以后,没有任何线程调用过await方法之前,就有线程调用countDown操作,并且触发了唤醒阻塞节点的逻辑
  36. 3.h==tail
  37. break
  38. 条件不成立:
  39. 条件成立1的相反情况,此时唤醒他的节点 执行 h == head 不成立,此时 原头节点不会跳出,会继续唤醒新的头节点的后继节点。
  40. */
  41. if (h == head)
  42. break;
  43. }
  44. }

11.AQS.doAcquireSharedInterruptibly

一旦 ThreadA 被唤醒,代码又会继续回到doAcquireSharedInterruptibly 中来执行。如果当前 state满足=0 的条件,则会执行 setHeadAndPropagate 方法。

  1. private void doAcquireSharedInterruptibly(int arg)
  2. throws InterruptedException {
  3. final Node node = addWaiter(Node.SHARED);
  4. //创建一个共享模式的节点添加到队列中
  5. boolean failed = true;
  6. try {
  7. for (;;) {//被唤醒的线程进入下一次循环继续判断
  8. final Node p = node.predecessor();
  9. if (p == head) {
  10. int r = tryAcquireShared(arg);//就判断尝试获取锁
  11. if (r >= 0) {//r>=0 表示获取到了执行权限,这个时候因为 state!=0,所以不会执行这段代码
  12. setHeadAndPropagate(node, r);
  13. p.next = null; // help GC 把当前节点移除 aqs 队列
  14. failed = false;
  15. return;
  16. }
  17. }
  18. //阻塞线程
  19. if (shouldParkAfterFailedAcquire(p, node) &&
  20. parkAndCheckInterrupt())
  21. throw new InterruptedException();
  22. }
  23. } finally {
  24. if (failed)
  25. cancelAcquire(node);
  26. }
  27. }

12.setHeadAndPropagate

这个方法的主要作用是把被唤醒的节点,设置成 head 节点。 然后继续唤醒队列中的其他线程。由于现在队列中有 3 个线程处于阻塞状态,一旦 ThreadA被唤醒,并且设置为 head 之后,会继续唤醒后续的ThreadB。

  1. private void setHeadAndPropagate(Node node, int propagate) {
  2. Node h = head; // Record old head for check below
  3. //将当前节点设置为头节点
  4. setHead(node);
  5. //1>0
  6. if (propagate > 0 || h == null || h.waitStatus < 0 ||
  7. (h = head) == null || h.waitStatus < 0) {
  8. Node s = node.next;
  9. //条件一:s==null 什么时候成立呢? 当前node节点已经是tail节点了,
  10. //条件二的前置条件:s!=null 要求s的模式是共享模式
  11. if (s == null || s.isShared())
  12. //继续向后唤醒
  13. doReleaseShared();
  14. }
  15. }

1.jpg

13.流程图

cdl流程.jpg

二,使用

countdownlatch 是一个同步工具类,它允许一个或多个线程一直等待,直到其他线程的操作执行完毕再执行。从命名可以解读到 countdown 是倒数的意思,类似于倒计时的概念。

countdownlatch 提供了两个方法,一个是 countDown,一个是 await, countdownlatch 初始化的时候需要传入一个整数,在这个整数倒数到 0 之前,调用了 await 方法的程序都必须要等待,然后通过 countDown 来倒数。

  1. /**
  2. * @author 二十
  3. * @since 2021/9/6 2:00 下午
  4. */
  5. public class DemoA {
  6. private static CountDownLatch c = new CountDownLatch(6);
  7. private static ThreadPoolExecutor executor = new ThreadPoolExecutor(
  8. 6,
  9. 6,
  10. 1,
  11. TimeUnit.SECONDS,
  12. new ArrayBlockingQueue<>(1),
  13. new MyDefaultFactory(),
  14. new ThreadPoolExecutor.AbortPolicy()
  15. );
  16. public static void main(String[] args)throws Exception {
  17. for (int i = 0; i < 6; i++)
  18. executor.submit(()->{
  19. System.out.println(Thread.currentThread().getName() + "国被灭!");
  20. c.countDown();
  21. });
  22. c.await();
  23. if (Thread.currentThread().getName().equals("main")) System.out.println("main线程执行结束:" + Thread.currentThread().getName() );
  24. }
  25. private static class MyDefaultFactory implements ThreadFactory{
  26. private static Queue<String> queue = new LinkedList();
  27. static {
  28. for (int i = 1; i <= 6; i++) queue.add(Objects.requireNonNull(Message.foreach_CountryEnum(i)).message);
  29. }
  30. @Override
  31. public Thread newThread(Runnable r) {
  32. return new Thread(r,"thread-"+queue.poll() +"-er_shi");
  33. }
  34. }
  35. enum Message {
  36. ONE(1, "齐"), TWO(2, "楚"), THREE(3, "燕"), FOUR(4, "赵"), FIVE(5, "魏"), SIX(6, "韩");
  37. private int code;
  38. private String message;
  39. Message(int code, String message) {
  40. this.code = code;
  41. this.message = message;
  42. }
  43. public int getCode() {
  44. return code;
  45. }
  46. public void setCode(int code) {
  47. this.code = code;
  48. }
  49. public String getMessage() {
  50. return message;
  51. }
  52. public void setMessage(String message) {
  53. this.message = message;
  54. }
  55. public static Message foreach_CountryEnum(int index) {
  56. Message[] countryEnums = Message.values();
  57. for (Message countryEnum : countryEnums) {
  58. if (countryEnum.getCode() == index) {
  59. return countryEnum;
  60. }
  61. }
  62. return null;
  63. }
  64. }
  65. }