一、 CyclicBarrier 的介绍

1.1 介绍

CyclicBarrier的字面意思是可循环使用(Cyclic)的屏障(Barrier)。它要做的事情是,让一组线程到达一个屏障(common barrier point)时被阻塞,直到最后一个线程到达屏障时,屏障才会开门,所有被屏障拦截的线程才会继续运行。

1.2 CountDownLatch 与 CyclicBarrier 的区别?

  • CountDownLatch的作用是允许1或N个线程等待其他线程完成执行;而CyclicBarrier则是允许N个线程相互等待。
  • CountDownLatch的计数器无法被重置;CyclicBarrier的计数器可以被重置后使用,因此它被称为是循环的barrier。
  • CountDownLatch 内部自行采用 AQS实现的共享锁 ;而 CyclicBarrier内部采用 可重入锁 ReentrantLock 和Condition

1.3 CyclicBarrier 的API

  1. //构造方法
  2. //创建一个新的CyclicBarrier, 当给定数量parties的线程全部到达Barrier时
  3. //Barrier(栏栅)将会被绊倒(放行全部线程执行), 而预定义执行线程为空
  4. public CyclicBarrier(int parties);
  5. //创建一个新的CyclicBarrier, 当给定数量parties的线程全部到达Barrier时
  6. //Barrier(栏栅)将会被绊倒(放行全部线程执行),
  7. //而当/Barrier(栏栅)将会被绊倒后, 预定义执行线程barrierAction将会执行
  8. public CyclicBarrier(int parties, Runnable barrierAction);
  9. //返回要求启动此 barrier 的参与者数目。
  10. public int getParties();
  11. //在所有参与者都已经在此 barrier 上调用 await 方法之前,将一直等待。
  12. public int await();
  13. //在所有参与者都已经在此屏障上调用 await 方法之前将一直等待,或者超出了指定的等待时间。
  14. public int await(long timeout, TimeUnit unit);
  15. //查询此屏障是否处于损坏状态。
  16. public boolean isBroken()
  17. public void reset();
  18. //返回当前在屏障处等待的参与者数目。
  19. public int getNumberWaiting();

二、 CyclicBarrier 的内部结构

image.png

CyclicBarrier是包含了”ReentrantLock对象lock(不公平锁)”和”Condition对象trip”,它是通过独占锁实现的。

三、CyclicBarrier 源码解析

3.1 成员变量

  1. //屏障 可重入锁
  2. private final ReentrantLock lock = new ReentrantLock();
  3. //条件等待,直到屏障被绊倒
  4. private final Condition trip = lock.newCondition();
  5. //启动屏障的数量
  6. private final int parties;
  7. //屏障启动时执行的线程
  8. private final Runnable barrierCommand;
  9. //当前一代
  10. //屏障每次被绊倒,都会改变generation
  11. private Generation generation = new Generation();
  12. //返回当前在屏障处等待的参与者数目。
  13. private int count;

3.2 构造方法

  1. public CyclicBarrier(int parties) {
  2. this(parties, null);
  3. }
  4. //parties:在屏障处等待的参与者数目
  5. //barrierAction: 在屏障被绊倒时,执行的预定义操作
  6. public CyclicBarrier(int parties, Runnable barrierAction) {
  7. if (parties <= 0) throw new IllegalArgumentException();
  8. this.parties = parties;
  9. //返回当前在屏障处等待的参与者数目。
  10. this.count = parties;
  11. this.barrierCommand = barrierAction;
  12. }

3.3 await 方法

  • await 方法图示

  • 源码解析

    1. //await()是通过dowait()实现的。
    2. private int dowait(boolean timed, long nanos)
    3. throws InterruptedException, BrokenBarrierException,
    4. TimeoutException {
    5. final ReentrantLock lock = this.lock;
    6. //获取独占锁
    7. lock.lock();
    8. try {
    9. //保存当前generation
    10. final Generation g = generation;
    11. //判断当前栏栅是否 ”被损坏“
    12. if (g.broken)
    13. throw new BrokenBarrierException();
    14. //判断当前线程是否被打断
    15. if (Thread.interrupted()) {
    16. // 则通过breakBarrier()终止CyclicBarrier,唤醒CyclicBarrier中所有等待线程。
    17. breakBarrier();
    18. throw new InterruptedException();
    19. }
    20. //当前在屏障处等待的参与者数目 -1
    21. int index = --count;
    22. //如果 当前在屏障处等待的参与者数目等于 0
    23. if (index == 0) { // tripped
    24. boolean ranAction = false;
    25. try {
    26. final Runnable command = barrierCommand;
    27. //预定义操作不为空,则执行
    28. if (command != null)
    29. command.run();
    30. ranAction = true;
    31. // 唤醒所有等待线程,并更新generation。
    32. nextGeneration();
    33. return 0;
    34. } finally {
    35. if (!ranAction)
    36. breakBarrier();
    37. }
    38. }
    39. // loop until tripped, broken, interrupted, or timed out
    40. // 当前线程一直阻塞,直到以下情况
    41. //1. “有parties个线程到达barrier”
    42. //2. “当前线程 或等待线程 被中断”
    43. //3. “超时等待”
    44. //4. CyclicBarrier 被重置
    45. // 当前线程才继续执行。
    46. for (;;) {
    47. try {
    48. // 如果不是“超时等待”,则调用awati()进行等待;否则,调用awaitNanos()进行等待。
    49. if (!timed)
    50. trip.await();
    51. else if (nanos > 0L)
    52. nanos = trip.awaitNanos(nanos);
    53. } catch (InterruptedException ie) {
    54. // 如果等待过程中,线程被中断,则执行下面的函数。
    55. if (g == generation && ! g.broken) {
    56. //breakBarrier 主要设置当前CyclicBarrier为 broken
    57. //唤醒所有等待线程
    58. breakBarrier();
    59. throw ie;
    60. } else {
    61. // We're about to finish waiting even if we had not
    62. // been interrupted, so this interrupt is deemed to
    63. // "belong" to subsequent execution.
    64. Thread.currentThread().interrupt();
    65. }
    66. }
    67. // 如果“当前generation已经损坏”,则抛出异常。
    68. if (g.broken)
    69. throw new BrokenBarrierException();
    70. // 如果“generation已经换代”,则返回index。(CyclicBarrier被重置)
    71. if (g != generation)
    72. return index;
    73. // 如果是“超时等待”,并且时间已到,
    74. //则通过breakBarrier()终止CyclicBarrier,唤醒CyclicBarrier中所有等待线程。
    75. if (timed && nanos <= 0L) {
    76. breakBarrier();
    77. throw new TimeoutException();
    78. }
    79. }
    80. } finally {
    81. //释放锁
    82. lock.unlock();
    83. }
    84. }

四、CyclicBarrier 的使用示例

4.1 用于多线程计算数据,最后合并计算结果的场景

  1. //结果
  2. ConcurrentHashMap<String, Integer> result = new ConcurrentHashMap<>();
  3. //屏障
  4. CyclicBarrier cyclicBarrier = new CyclicBarrier(4, new Runnable() {
  5. @Override
  6. public void run() {
  7. int total = 0;
  8. //撤除屏障,汇总结果
  9. for (Map.Entry<String, Integer> entry : result.entrySet()) {
  10. total += entry.getValue();
  11. }
  12. System.out.println("总结果==>" + total);
  13. }
  14. });
  15. //创建4个线程计算
  16. for (int i = 0; i < 4; i++) {
  17. new Thread(new Runnable() {
  18. @Override
  19. public void run() {
  20. try {
  21. TimeUnit.SECONDS.sleep(new Random().nextInt(5));
  22. //计算
  23. result.put(Thread.currentThread().getName(), 1);
  24. System.out.println(Thread.currentThread().getName() + "计算完成, 等待");
  25. //计算完成,插入屏障
  26. cyclicBarrier.await();
  27. } catch (InterruptedException e) {
  28. e.printStackTrace();
  29. } catch (BrokenBarrierException e) {
  30. e.printStackTrace();
  31. }
  32. }
  33. }).start();
  34. }
  • CountDownLatch 版 ```java ConcurrentHashMap result = new ConcurrentHashMap<>();

CountDownLatch countDownLatch = new CountDownLatch(4); //创建4个线程计算 for (int i = 0; i < 4; i++) { new Thread(new Runnable() { @Override public void run() { try { TimeUnit.SECONDS.sleep(new Random().nextInt(5)); //计算 result.put(Thread.currentThread().getName(), 1); System.out.println(Thread.currentThread().getName() + “计算完成, 等待”);

  1. countDownLatch.countDown();
  2. } catch (InterruptedException e) {
  3. e.printStackTrace();
  4. }
  5. }
  6. }).start();

} new Thread(new Runnable() { @Override public void run() { try { countDownLatch.await(); } catch (InterruptedException e) { e.printStackTrace(); }

  1. int total = 0;
  2. //撤除屏障,汇总结果
  3. for (Map.Entry<String, Integer> entry : result.entrySet()) {
  4. total += entry.getValue();
  5. }
  6. System.out.println("总结果==>" + total);
  7. }

}).start(); ```

参考