本篇内容

  1. 介绍 CountDownLatch 及使用场景
  2. 提供几个使用示例介绍 CountDownLatch 的使用
  3. 手写一个并行处理任务的工具类

    如何实现这个需求?

    假如有这样一个需求,当我们需要解析一个 Excel 里多个 sheet 的数据时,可以考虑使用多线程,每个线程解析一个 sheet 里的数据,等到所有的 sheet 都解析完之后,程序需要统计解析总耗时。
    分析一下:解析每个 sheet 耗时可能不一样,总耗时就是最长耗时的那个操作。
    我们能够想到的最简单的做法是使用 join,代码如下:
    package com.itsoku.chat13;

import java.util.concurrent.TimeUnit;

/
微信公众号:程序员路人
/
public class Demo1** {

  1. **public** **static** **class** **T** **extends** **Thread** {<br /> //休眠时间(秒)<br /> **int** sleepSeconds;
  2. **public** **T**(String name, **int** sleepSeconds) {<br /> **super**(name);<br /> **this**.sleepSeconds = sleepSeconds;<br /> }
  3. @Override<br /> **public** **void** **run**() {<br /> Thread ct = Thread.currentThread();<br /> **long** startTime = System.currentTimeMillis();<br /> System.out.println(startTime + "," + ct.getName() + ",开始处理!");<br /> **try** {<br /> //模拟耗时操作,休眠sleepSeconds秒<br /> TimeUnit.SECONDS.sleep(**this**.sleepSeconds);<br /> } **catch** (InterruptedException e) {<br /> e.printStackTrace();<br /> }<br /> **long** endTime = System.currentTimeMillis();<br /> System.out.println(endTime + "," + ct.getName() + ",处理完毕,耗时:" + (endTime - startTime));<br /> }<br /> }
  4. **public** **static** **void** **main**(String[] args) **throws** InterruptedException {<br /> **long** starTime = System.currentTimeMillis();<br /> T t1 = **new** T("解析sheet1线程", 2);<br /> t1.start();
  5. T t2 = **new** T("解析sheet2线程", 5);<br /> t2.start();
  6. t1.join();<br /> t2.join();<br /> **long** endTime = System.currentTimeMillis();<br /> System.out.println("总耗时:" + (endTime - starTime));
  7. }<br />}

输出:
1563767560271,解析sheet1线程,开始处理!
1563767560272,解析sheet2线程,开始处理!
1563767562273,解析sheet1线程,处理完毕,耗时:2002
1563767565274,解析sheet2线程,处理完毕,耗时:5002
总耗时:5005

代码中启动了 2 个解析 sheet 的线程,第一个耗时 2 秒,第二个耗时 5 秒,最终结果中总耗时:5 秒。上面的关键技术点是线程的join()方法,此方法会让当前线程等待被调用的线程完成之后才能继续。可以看一下 join 的源码,内部其实是在 synchronized 方法中调用了线程的 wait 方法,最后被调用的线程执行完毕之后,由 jvm 自动调用其 notifyAll()方法,唤醒所有等待中的线程。这个 notifyAll()方法是由 jvm 内部自动调用的,jdk 源码中是看不到的,需要看 jvm 源码,有兴趣的同学可以去查一下。所以 JDK 不推荐在线程上调用 wait、notify、notifyAll 方法。
而在 JDK1.5 之后的并发包中提供的 CountDownLatch 也可以实现 join 的这个功能。

CountDownLatch 介绍

CountDownLatch 称之为闭锁,它可以使一个或一批线程在闭锁上等待,等到其他线程执行完相应操作后,闭锁打开,这些等待的线程才可以继续执行。确切的说,闭锁在内部维护了一个倒计数器。通过该计数器的值来决定闭锁的状态,从而决定是否允许等待的线程继续执行。
常用方法:
public CountDownLatch(int count):构造方法,count 表示计数器的值,不能小于 0,否则会报异常。
public void await() throws InterruptedException:调用 await()会让当前线程等待,直到计数器为 0 的时候,方法才会返回,此方法会响应线程中断操作。
public boolean await(long timeout, TimeUnit unit) throws InterruptedException:限时等待,在超时之前,计数器变为了 0,方法返回 true,否则直到超时,返回 false,此方法会响应线程中断操作。
public void countDown():让计数器减 1
CountDownLatch 使用步骤:

  1. 创建 CountDownLatch 对象
  2. 调用其实例方法await(),让当前线程等待
  3. 调用countDown()方法,让计数器减 1
  4. 当计数器变为 0 的时候,await()方法会返回

    示例 1:一个简单的示例

    我们使用 CountDownLatch 来完成上面示例中使用 join 实现的功能,代码如下:
    package com.itsoku.chat13;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/
微信公众号:程序员路人
/
public class Demo2** {

  1. **public** **static** **class** **T** **extends** **Thread** {<br /> //休眠时间(秒)<br /> **int** sleepSeconds;<br /> CountDownLatch countDownLatch;
  2. **public** **T**(String name, **int** sleepSeconds, CountDownLatch countDownLatch) {<br /> **super**(name);<br /> **this**.sleepSeconds = sleepSeconds;<br /> **this**.countDownLatch = countDownLatch;<br /> }
  3. @Override<br /> **public** **void** **run**() {<br /> Thread ct = Thread.currentThread();<br /> **long** startTime = System.currentTimeMillis();<br /> System.out.println(startTime + "," + ct.getName() + ",开始处理!");<br /> **try** {<br /> //模拟耗时操作,休眠sleepSeconds秒<br /> TimeUnit.SECONDS.sleep(**this**.sleepSeconds);<br /> } **catch** (InterruptedException e) {<br /> e.printStackTrace();<br /> } **finally** {<br /> countDownLatch.countDown();<br /> }<br /> **long** endTime = System.currentTimeMillis();<br /> System.out.println(endTime + "," + ct.getName() + ",处理完毕,耗时:" + (endTime - startTime));<br /> }<br /> }
  4. **public** **static** **void** **main**(String[] args) **throws** InterruptedException {<br /> System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 start!");<br /> CountDownLatch countDownLatch = **new** CountDownLatch(2);
  5. **long** starTime = System.currentTimeMillis();<br /> T t1 = **new** T("解析sheet1线程", 2, countDownLatch);<br /> t1.start();
  6. T t2 = **new** T("解析sheet2线程", 5, countDownLatch);<br /> t2.start();
  7. countDownLatch.await();<br /> System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 end!");<br /> **long** endTime = System.currentTimeMillis();<br /> System.out.println("总耗时:" + (endTime - starTime));
  8. }<br />}

输出:
1563767580511,main线程 start!
1563767580513,解析sheet1线程,开始处理!
1563767580513,解析sheet2线程,开始处理!
1563767582515,解析sheet1线程,处理完毕,耗时:2002
1563767585515,解析sheet2线程,处理完毕,耗时:5002
1563767585515,main线程 end!
总耗时:5003

从结果中看出,效果和 join 实现的效果一样,代码中创建了计数器为 2 的CountDownLatch,主线程中调用countDownLatch.await();会让主线程等待,t1、t2 线程中模拟执行耗时操作,最终在 finally 中调用了countDownLatch.countDown();,此方法每调用一次,CountDownLatch 内部计数器会减 1,当计数器变为 0 的时候,主线程中的 await()会返回,然后继续执行。注意:上面的countDown()这个是必须要执行的方法,所以放在 finally 中执行。

示例 2:等待指定的时间

还是上面的示例,2 个线程解析 2 个 sheet,主线程等待 2 个 sheet 解析完成。主线程说,我等待 2 秒,你们还是无法处理完成,就不等待了,直接返回。如下代码:
package com.itsoku.chat13;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/
微信公众号:程序员路人
/
public class Demo3** {

  1. **public** **static** **class** **T** **extends** **Thread** {<br /> //休眠时间(秒)<br /> **int** sleepSeconds;<br /> CountDownLatch countDownLatch;
  2. **public** **T**(String name, **int** sleepSeconds, CountDownLatch countDownLatch) {<br /> **super**(name);<br /> **this**.sleepSeconds = sleepSeconds;<br /> **this**.countDownLatch = countDownLatch;<br /> }
  3. @Override<br /> **public** **void** **run**() {<br /> Thread ct = Thread.currentThread();<br /> **long** startTime = System.currentTimeMillis();<br /> System.out.println(startTime + "," + ct.getName() + ",开始处理!");<br /> **try** {<br /> //模拟耗时操作,休眠sleepSeconds秒<br /> TimeUnit.SECONDS.sleep(**this**.sleepSeconds);<br /> } **catch** (InterruptedException e) {<br /> e.printStackTrace();<br /> } **finally** {<br /> countDownLatch.countDown();<br /> }<br /> **long** endTime = System.currentTimeMillis();<br /> System.out.println(endTime + "," + ct.getName() + ",处理完毕,耗时:" + (endTime - startTime));<br /> }<br /> }
  4. **public** **static** **void** **main**(String[] args) **throws** InterruptedException {<br /> System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 start!");<br /> CountDownLatch countDownLatch = **new** CountDownLatch(2);
  5. **long** starTime = System.currentTimeMillis();<br /> T t1 = **new** T("解析sheet1线程", 2, countDownLatch);<br /> t1.start();
  6. T t2 = **new** T("解析sheet2线程", 5, countDownLatch);<br /> t2.start();
  7. **boolean** result = countDownLatch.await(2, TimeUnit.SECONDS);
  8. System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 end!");<br /> **long** endTime = System.currentTimeMillis();<br /> System.out.println("主线程耗时:" + (endTime - starTime) + ",result:" + result);
  9. }<br />}

输出:
1563767637316,main线程 start!
1563767637320,解析sheet1线程,开始处理!
1563767637320,解析sheet2线程,开始处理!
1563767639321,解析sheet1线程,处理完毕,耗时:2001
1563767639322,main线程 end!
主线程耗时:2004,result:false
1563767642322,解析sheet2线程,处理完毕,耗时:5002

从输出结果中可以看出,线程 2 耗时了 5 秒,主线程耗时了 2 秒,主线程中调用countDownLatch.await(2, TimeUnit.SECONDS);,表示最多等 2 秒,不管计数器是否为 0,await 方法都会返回,若等待时间内,计数器变为 0 了,立即返回 true,否则超时后返回 false。

示例 3:2 个 CountDown 结合使用的示例

有 3 个人参见跑步比赛,需要先等指令员发指令枪后才能开跑,所有人都跑完之后,指令员喊一声,大家跑完了。
示例代码:
package com.itsoku.chat13;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/
微信公众号:程序员路人
/
public class Demo4** {

  1. **public** **static** **class** **T** **extends** **Thread** {<br /> //跑步耗时(秒)<br /> **int** runCostSeconds;<br /> CountDownLatch commanderCd;<br /> CountDownLatch countDown;
  2. **public** **T**(String name, **int** runCostSeconds, CountDownLatch commanderCd, CountDownLatch countDown) {<br /> **super**(name);<br /> **this**.runCostSeconds = runCostSeconds;<br /> **this**.commanderCd = commanderCd;<br /> **this**.countDown = countDown;<br /> }
  3. @Override<br /> **public** **void** **run**() {<br /> //等待指令员枪响<br /> **try** {<br /> commanderCd.await();<br /> } **catch** (InterruptedException e) {<br /> e.printStackTrace();<br /> }<br /> Thread ct = Thread.currentThread();<br /> **long** startTime = System.currentTimeMillis();<br /> System.out.println(startTime + "," + ct.getName() + ",开始跑!");<br /> **try** {<br /> //模拟耗时操作,休眠runCostSeconds秒<br /> TimeUnit.SECONDS.sleep(**this**.runCostSeconds);<br /> } **catch** (InterruptedException e) {<br /> e.printStackTrace();<br /> } **finally** {<br /> countDown.countDown();<br /> }<br /> **long** endTime = System.currentTimeMillis();<br /> System.out.println(endTime + "," + ct.getName() + ",跑步结束,耗时:" + (endTime - startTime));<br /> }<br /> }
  4. **public** **static** **void** **main**(String[] args) **throws** InterruptedException {<br /> System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 start!");<br /> CountDownLatch commanderCd = **new** CountDownLatch(1);<br /> CountDownLatch countDownLatch = **new** CountDownLatch(3);
  5. **long** starTime = System.currentTimeMillis();<br /> T t1 = **new** T("小张", 2, commanderCd, countDownLatch);<br /> t1.start();
  6. T t2 = **new** T("小李", 5, commanderCd, countDownLatch);<br /> t2.start();
  7. T t3 = **new** T("路人甲", 10, commanderCd, countDownLatch);<br /> t3.start();
  8. //主线程休眠5秒,模拟指令员准备发枪耗时操作<br /> TimeUnit.SECONDS.sleep(5);<br /> System.out.println(System.currentTimeMillis() + ",枪响了,大家开始跑");<br /> commanderCd.countDown();
  9. countDownLatch.await();<br /> **long** endTime = System.currentTimeMillis();<br /> System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "所有人跑完了,主线程耗时:" + (endTime - starTime));
  10. }<br />}

输出:
1563767691087,main线程 start!
1563767696092,枪响了,大家开始跑
1563767696092,小张,开始跑!
1563767696092,小李,开始跑!
1563767696092,路人甲,开始跑!
1563767698093,小张,跑步结束,耗时:2001
1563767701093,小李,跑步结束,耗时:5001
1563767706093,路人甲,跑步结束,耗时:10001
1563767706093,main所有人跑完了,主线程耗时:15004

代码中,t1、t2、t3 启动之后,都阻塞在commanderCd.await();,主线程模拟发枪准备操作耗时 5 秒,然后调用commanderCd.countDown();模拟发枪操作,此方法被调用以后,阻塞在commanderCd.await();的 3 个线程会向下执行。主线程调用countDownLatch.await();之后进行等待,每个人跑完之后,调用countDown.countDown();通知一下countDownLatch让计数器减 1,最后 3 个人都跑完了,主线程从countDownLatch.await();返回继续向下执行。

手写一个并行处理任务的工具类

package com.itsoku.chat13;

import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/
微信公众号:程序员路人
/
public class TaskDisposeUtils {
//并行线程数
public static final int** POOL_SIZE;

  1. **static** {<br /> POOL_SIZE = Integer.max(Runtime.getRuntime().availableProcessors(), 5);<br /> }
  2. /**<br /> * 并行处理,并等待结束<br /> *<br /> * **@param** taskList 任务列表<br /> * **@param** consumer 消费者<br /> * **@param** <T><br /> * **@throws** InterruptedException<br /> */<br /> **public** **static** <T> **void** **dispose**(List<T> taskList, Consumer<T> consumer) **throws** InterruptedException {<br /> dispose(**true**, POOL_SIZE, taskList, consumer);<br /> }
  3. /**<br /> * 并行处理,并等待结束<br /> *<br /> * **@param** moreThread 是否多线程执行<br /> * **@param** poolSize 线程池大小<br /> * **@param** taskList 任务列表<br /> * **@param** consumer 消费者<br /> * **@param** <T><br /> * **@throws** InterruptedException<br /> */<br /> **public** **static** <T> **void** **dispose**(**boolean** moreThread, **int** poolSize, List<T> taskList, Consumer<T> consumer) **throws** InterruptedException {<br /> **if** (CollectionUtils.isEmpty(taskList)) {<br /> **return**;<br /> }<br /> **if** (moreThread && poolSize > 1) {<br /> poolSize = Math.min(poolSize, taskList.size());<br /> ExecutorService executorService = **null**;<br /> **try** {<br /> executorService = Executors.newFixedThreadPool(poolSize);<br /> CountDownLatch countDownLatch = **new** CountDownLatch(taskList.size());<br /> **for** (T item : taskList) {<br /> executorService.execute(() -> {<br /> **try** {<br /> consumer.accept(item);<br /> } **finally** {<br /> countDownLatch.countDown();<br /> }<br /> });<br /> }<br /> countDownLatch.await();<br /> } **finally** {<br /> **if** (executorService != **null**) {<br /> executorService.shutdown();<br /> }<br /> }<br /> } **else** {<br /> **for** (T item : taskList) {<br /> consumer.accept(item);<br /> }<br /> }<br /> }
  4. **public** **static** **void** **main**(String[] args) **throws** InterruptedException {<br /> //生成1-10的10个数字,放在list中,相当于10个任务<br /> List<Integer> list = Stream.iterate(1, a -> a + 1).limit(10).collect(Collectors.toList());<br /> //启动多线程处理list中的数据,每个任务休眠时间为list中的数值<br /> TaskDisposeUtils.dispose(list, item -> {<br /> **try** {<br /> **long** startTime = System.currentTimeMillis();<br /> TimeUnit.SECONDS.sleep(item);<br /> **long** endTime = System.currentTimeMillis();
  5. System.out.println(System.currentTimeMillis() + ",任务" + item + "执行完毕,耗时:" + (endTime - startTime));<br /> } **catch** (InterruptedException e) {<br /> e.printStackTrace();<br /> }<br /> });<br /> //上面所有任务处理完毕完毕之后,程序才能继续<br /> System.out.println(list + "中的任务都处理完毕!");<br /> }<br />}

运行代码输出:
1563769828130,任务1执行完毕,耗时:1000
1563769829130,任务2执行完毕,耗时:2000
1563769830131,任务3执行完毕,耗时:3001
1563769831131,任务4执行完毕,耗时:4001
1563769832131,任务5执行完毕,耗时:5001
1563769833130,任务6执行完毕,耗时:6000
1563769834131,任务7执行完毕,耗时:7001
1563769835131,任务8执行完毕,耗时:8001
1563769837131,任务9执行完毕,耗时:9001
1563769839131,任务10执行完毕,耗时:10001
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]中的任务都处理完毕!

TaskDisposeUtils 是一个并行处理的工具类,可以传入 n 个任务内部使用线程池进行处理,等待所有任务都处理完成之后,方法才会返回。比如我们发送短信,系统中有 1 万条短信,我们使用上面的工具,每次取 100 条并行发送,待 100 个都处理完毕之后,再取一批按照同样的逻辑发送。