CountDownLatch 这个类能够使一个线程等待其他线程完成各自的工作后再执行。例如,应用程序的主线程希望在负责启动框架服务的线程已经启动所有的框架服务之后再执行。

简单示例


陪女朋友看病

如果你女朋友去看病,医院里边排队的人很多,如果你女朋友一个人去的话,要先看大夫,看完大夫再去排队交钱取药。当你陪你女朋友一起去,你女朋友可以去看病,而你可以去排队交钱取药。当两件事同时完成后,你才能和你女朋友一起回家。

女朋友去看病

  1. public class SeeDoctorTask implements Runnable {
  2. private CountDownLatch countDownLatch;
  3. public SeeDoctorTask(CountDownLatch countDownLatch){
  4. this.countDownLatch = countDownLatch;
  5. }
  6. public void run() {
  7. try {
  8. System.out.println("开始看医生");
  9. Thread.sleep(2000);
  10. System.out.println("看医生结束,准备离开病房");
  11. } catch (InterruptedException e) {
  12. e.printStackTrace();
  13. }finally {
  14. if (countDownLatch != null)
  15. //看病完成,计数器减1
  16. countDownLatch.countDown();
  17. }
  18. }
  19. }

你排队去拿药

  1. public class QueueTask implements Runnable {
  2. private CountDownLatch countDownLatch;
  3. public QueueTask(CountDownLatch countDownLatch){
  4. this.countDownLatch = countDownLatch;
  5. }
  6. public void run() {
  7. try {
  8. System.out.println("开始在医院药房排队买药....");
  9. Thread.sleep(5000);
  10. System.out.println("排队成功,可以开始缴费买药");
  11. } catch (InterruptedException e) {
  12. e.printStackTrace();
  13. }finally {
  14. if (countDownLatch != null)
  15. //拿药完成后,计数器减1
  16. countDownLatch.countDown();
  17. }
  18. }
  19. }

当你女朋友看完病并且你拿完药后,你们可以一起回家

  1. public class CountDownLaunchRunner {
  2. public static void main(String[] args) throws InterruptedException {
  3. CountDownLatch countDownLatch = new CountDownLatch(2);
  4. new Thread(new SeeDoctorTask(countDownLatch)).start();
  5. new Thread(new QueueTask(countDownLatch)).start();
  6. //只有女朋友看完病和你拿完药后,你们才可以一起回家
  7. countDownLatch.await();
  8. System.out.println("over,回家 cost:"+(System.currentTimeMillis()-now));
  9. }
  10. }

模拟高并发

CountDownLantch 反着来用还可以模拟高并发场景

  1. public class CountDownLaunchRunner {
  2. static int sub = 0;
  3. static Object object = new Object();
  4. public static void main(String[] args) throws InterruptedException {
  5. long now = System.currentTimeMillis();
  6. CountDownLatch countDownLatch = new CountDownLatch(1);
  7. //开启10个线程
  8. for(int i=0;i<10;i++){
  9. new Thread(new Runnable() {
  10. @Override
  11. public void run() {
  12. try {
  13. //每个线程都阻塞,等待主线程进行countDown()操作,10个线程可以并发执行
  14. countDownLatch.await();
  15. } catch (InterruptedException e) {
  16. e.printStackTrace();
  17. }
  18. synchronized (object){
  19. for(int j=0;j<1000;j++){
  20. sub++;
  21. }
  22. }
  23. }
  24. });
  25. }
  26. Thread.sleep(3000);
  27. //执行countDown(),上面阻塞的所有线程都可以开始执行
  28. countDownLatch.countDown();
  29. System.out.println("over,回家 cost:"+(System.currentTimeMillis()-now));
  30. }
  31. }

实现原理

CountDownLatch是通过一个计数器来实现的,计数器的初始值为线程的数量。每当一个线程完成了自己的任务后,计数器的值就会减1。当计数器值到达0时,它表示所有的线程已经完成了任务,然后在闭锁上等待的线程就可以恢复执行任务。

  1. public CountDownLatch(int count) {
  2. if (count < 0) throw new IllegalArgumentException("count < 0");
  3. this.sync = new Sync(count);
  4. }
  5. ----------------------------
  6. Sync(int count) {
  7. //设置AQS中state的值为我们指定的参数
  8. setState(count);
  9. }

await()——-阻塞等待

  1. public void await() throws InterruptedException {
  2. sync.acquireSharedInterruptibly(1);
  3. }
  4. -------------------------
  5. public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
  6. if (Thread.interrupted())
  7. throw new InterruptedException();
  8. if (tryAcquireShared(arg) < 0)
  9. doAcquireSharedInterruptibly(arg);
  10. }

调用tryAcquireShared(),如果当前计数器(state)不为0,说明还有线程未执行完

  1. protected int tryAcquireShared(int acquires) {
  2. //如果当前state=0,说明其他线程的任务已完成,返回1
  3. //如果当前state>0,说明还有线程任务未完成,返回-1
  4. return (getState() == 0) ? 1 : -1;
  5. }

与其他AQS工具的做法一样,创建节点加入CLH队列,并且将线程阻塞

  1. private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
  2. final Node node = addWaiter(Node.SHARED);
  3. boolean failed = true;
  4. try {
  5. for (;;) {
  6. final Node p = node.predecessor();
  7. if (p == head) {
  8. //如果state=0,返回1
  9. int r = tryAcquireShared(arg);
  10. //r=1,说明其他所有线程任务执行完
  11. if (r >= 0) {
  12. //以广播的方式唤醒所有调用await()方法阻塞的线程
  13. setHeadAndPropagate(node, r);
  14. p.next = null; // help GC
  15. failed = false;
  16. return;
  17. }
  18. }
  19. if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt())
  20. throw new InterruptedException();
  21. }
  22. } finally {
  23. if (failed)
  24. cancelAcquire(node);
  25. }
  26. }

当state=0,说明其他线程的任务执行完,则会以广播的方式通知所有调用 await() 方法阻塞等待的线程

countDown()——-计数器减1

当任务线程任务执行完毕,调用 countDown() 方法的时候,会对计数器(state)进行减1操作

  1. public void countDown() {
  2. sync.releaseShared(1);
  3. }
  4. ---------------------
  5. public final boolean releaseShared(int arg) {
  6. //将state减1
  7. if (tryReleaseShared(arg)) {
  8. //将之前调用 await() 方法阻塞的线程唤醒,继续检查state是否为0
  9. doReleaseShared();
  10. return true;
  11. }
  12. return false;
  13. }

tryReleaseShared() 尝试将 state 减1

  1. protected boolean tryReleaseShared(int releases) {
  2. for (;;) {
  3. //获取当前state的值
  4. int c = getState();
  5. //state=0 直接返回false
  6. if (c == 0)
  7. return false;
  8. //state减1
  9. int nextc = c-1;
  10. //cas修改state的值
  11. if (compareAndSetState(c, nextc))
  12. return nextc == 0;
  13. }
  14. }

成功修改 state 的值后,将之前调用 await() 方法阻塞的线程唤醒,继续检查state是否为0

通过广播的形式唤醒节点线程具体实现可以阅读

从上面源码中可以看到,state的值在调用 countDown() 方法减1后,后面调用await() 方法并没有加回去,所以 CountDownLantch 所起作用是一次性的。