场景
在读取文件、或则需要使用多线程批量入库的时候,往往是需要我们自己来写多线程的调度完成多线程批量入库的功能;
难点:多线程的调度、数据分批的逻辑
:::info 不仅仅用于数据库插入,只要在 固定数量 的 多线程处理 的场景都适用 :::
解决的问题
- 多线程调度、数据分批的逻辑
 - 提供多线程批量插入/处理
 - 提供多线程单条插入/处理
 
实现思路
- 使用 ArrayBlockingQueue 来调节生产方和消费方速度不一致的情况:使用阻塞的 put 来达到让生产方阻塞等待
 - 使用 CountDownLatch 来实现,等待多线程将所有生产的数据都入库/处理完成
 
工具类实现
依赖
// lombok,使用了里面的 @Slf4j 日志工具compileOnly 'org.projectlombok:lombok:1.18.18'testCompileOnly 'org.projectlombok:lombok:1.18.18'annotationProcessor 'org.projectlombok:lombok:1.18.18'// hutool 工具类// 比如 ExceptionUtil 来自于 https://www.hutool.cn/ 工具包中,完全可以换成手动 new 异常implementation 'cn.hutool:hutool-all:5.8.3'
package cn.mrcode;import cn.hutool.core.exceptions.ExceptionUtil;import cn.hutool.core.util.StrUtil;import lombok.extern.slf4j.Slf4j;import java.sql.Struct;import java.util.ArrayList;import java.util.List;import java.util.concurrent.ArrayBlockingQueue;import java.util.concurrent.CountDownLatch;import java.util.concurrent.TimeUnit;import java.util.stream.Collectors;import java.util.stream.IntStream;/*** 多线程批处理器* @author mrcode* @date 2021/6/2 17:52*/@Slf4jpublic class BatchProcessor<T> {/*** 线程名称前缀,可自定义*/private String threadNamePrefix = "reportInsert-";// 是否已经开始处理private boolean started;// 用于等待线程处理结束后的收尾处理private CountDownLatch cdl;// 是否还会产生数据: 用于配合 queue.size() 判断线程是否该结束private volatile boolean isProduceData = true;// 实体数据容器队列,队列满,则限制生产方的生产速度private ArrayBlockingQueue<T> queue;// 消费到一条实体数据,就调用该方法给使用方,使用方可以调用存储接口存储private StorageConsumer<T> consumer;// 批量插入时,每次最多插入多少条private int maxItemCount;private List<WorkThread> workThreads;public BatchProcessor() {this(1000);}/*** <pre>* capacity :利用队列的阻塞 put,来调节生产速度和消费速度的差别* 当生产速度明显大于插入速度时,该参数用来限制生产的速度,达到该上限时,生成方就会阻塞,知道有新的容量空闲出来* </pre>** @param capacity 队列能接收的最大容量*/public BatchProcessor(int capacity) {queue = new ArrayBlockingQueue<>(capacity);}/*** 配置线程名称前缀** @param threadNamePrefix*/public synchronized void setThreadNamePrefix(String threadNamePrefix) {if (started) {throw new RuntimeException("已经开始处理,不能再线程名称前缀");}this.threadNamePrefix = threadNamePrefix;}/*** 默认 4 个线程,每个线程每次处理一条数据** @param consumer 每次达到消费条数时,消费方的消费回调逻辑*/public void start(StorageConsumer<T> consumer) {this.start(consumer, 4);}/*** 默认每个线程每次处理 1 条数据** @param consumer 每次达到消费条数时,消费方的消费回调逻辑* @param workThreadCount 需要并行处理的线程数量,必须大于 0*/public void start(StorageConsumer<T> consumer, int workThreadCount) {this.start(consumer, workThreadCount, 0);}/*** @param consumer 每次达到消费条数时,消费方的消费回调逻辑* @param workThreadCount 需要并行处理的线程数量,必须大于 0* @param maxItemCount 每次每个线程希望的消费数据条数, 0:每个线程每次消费 1 条数据,大于 0 则按照期望的条数进行消费*/public synchronized void start(StorageConsumer<T> consumer,int workThreadCount,int maxItemCount) {this.start(consumer, workThreadCount, maxItemCount, null);}/*** @param consumer 每次达到消费条数时,消费方的消费回调逻辑* 由于是线程处理,所有在消费逻辑处理的时候,建议消费方一定要将逻辑都 try 一下,否则就会进入 uncaughtExceptionHandler 处理异常,并且该工作线程退出工作* @param workThreadCount 需要并行处理的线程数量,必须大于 0* @param maxItemCount 每次每个线程希望的消费数据条数, 0:每个线程每次消费 1 条数据,大于 0 则按照期望的条数进行消费* @param uncaughtExceptionHandler 当抛出异常的时候,该异常如何处理,可以为 null, 如果为 null, 将使用 @Slf4j 日志打印*/public synchronized void start(StorageConsumer<T> consumer,int workThreadCount,int maxItemCount,Thread.UncaughtExceptionHandler uncaughtExceptionHandler) {if (started) {throw new RuntimeException("处理中");}if (workThreadCount <= 0) {throw new IllegalArgumentException("workThreadCount 必须大于 0");}if (maxItemCount < 0) {throw new IllegalArgumentException("maxItemCount 必须大于等于 0");}started = true;this.consumer = consumer;this.maxItemCount = maxItemCount;this.cdl = new CountDownLatch(workThreadCount);if (uncaughtExceptionHandler == null) {uncaughtExceptionHandler = (t, e) -> {log.error(StrUtil.format("工作线程异常退出,threadName={}", t.getName()), e);};}Thread.UncaughtExceptionHandler finalUncaughtExceptionHandler = uncaughtExceptionHandler;workThreads = IntStream.range(0, workThreadCount).mapToObj(i -> {final WorkThread workThread = new WorkThread(threadNamePrefix + i, maxItemCount);workThread.start();// 如果不设置异常处理器,那么当 run 方法抛出异常的时候,会被 java.lang.ThreadGroup.uncaughtException 处理// 然后 ThreadGroup.uncaughtException 的默认处理是使用 System.error 打印错误,和调用 e.printStackTrace(System.err);// 这就会导致在生产环境中使用日志框架的时候,在日志框架里面看不到打印的错误信息,看起来就像异常被吞了workThread.setUncaughtExceptionHandler(finalUncaughtExceptionHandler);return workThread;}).collect(Collectors.toList());}/*** 将实体交给处理器,处理器的线程会消费该实体;* <pre>* 当容器队列已满时,则会阻塞,以此达到生产方暂停生产的目的;可以防止生产速度过快(消费速度过慢),导致占用过多内存* </pre>** @param entity*/public void put(T entity) {try {queue.put(entity);} catch (InterruptedException e) {ExceptionUtil.wrapAndThrow(e);}}/*** 等待,处理器处理完成;此方法会阻塞*/public void await() {if (!started) {throw new RuntimeException("还未运行");}try {isProduceData = false;cdl.await();for (WorkThread workThread : workThreads) {workThread.clearEntity();}} catch (InterruptedException e) {ExceptionUtil.wrapAndThrow(e);}}/*** 立即停止,只适合在生产方不生产数据时,调用*/public void stop() {if (!started) {throw new RuntimeException("还未运行");}isProduceData = false;queue.clear();}private class WorkThread extends Thread {// 批量插入时,用于缓存实体的容器private List<T> batchCacheContainer;private int maxItemCount;public WorkThread(String name, int maxItemCount) {super(name);this.maxItemCount = maxItemCount;if (maxItemCount > 0) {batchCacheContainer = new ArrayList<>(maxItemCount);}}@Overridepublic void run() {try {doRun();} catch (InterruptedException e) {log.debug("工作线程收到中断异常退出", e);} finally {cdl.countDown();}}private void doRun() throws InterruptedException {while (true) {// 如果不产生数据了,队列也会空,则退出线程if (!isProduceData && queue.size() == 0) {break;}final T entity;entity = queue.poll(500, TimeUnit.MILLISECONDS);if (entity == null) {continue;}if (maxItemCount > 0) {batchCacheContainer.add(entity);if (batchCacheContainer.size() >= maxItemCount) {consumer.accept(null, batchCacheContainer);batchCacheContainer.clear();}} else {consumer.accept(entity, null);}}}public void clearEntity() {if (maxItemCount > 0 && batchCacheContainer.size() > 0) {consumer.accept(null, batchCacheContainer);batchCacheContainer.clear();}}}public interface StorageConsumer<T> {/*** 需要使用方存储数据时,会调用该方法** @param t* @param ts*/void accept(T t, List<T> ts);}}
用法测试
package cn.mrcode;import org.junit.jupiter.api.Test;/*** @author mrcode* @date 2021/6/3 23:24*/class BatchProcessorTest {/*** 批量插入测试*/@Testpublic void batchInsert() {final BatchProcessor<DemoEntity> work = new BatchProcessor<>();work.start((t, ts) -> {System.out.println("插入数据库条数:" + ts.size());},4, 4);// 模拟生产数据try {for (int i = 0; i < 21; i++) {work.put(new DemoEntity(i, i + " name"));}// 等待入库完成work.await();} catch (Exception e) {// 如果生产过程中有异常,立即停止掉处理器,不再入库work.stop();}}/*** 单条插入测试*/@Testpublic void insert() {final BatchProcessor<DemoEntity> work = new BatchProcessor<>();work.start((t, ts) -> {System.out.println("插入数据库:" + t);},4, 0);// 模拟生产数据try {for (int i = 0; i < 5; i++) {work.put(new DemoEntity(i, i + " name"));}// 等待入库完成work.await();} catch (Exception e) {// 如果生产过程中有异常,立即停止掉处理器,不再入库work.stop();}}/*** 异常测试 - 不自定义异常处理器*/@Testpublic void exceptionTest() {// 看看在消费逻辑中发现业务异常,会发生什么事情final BatchProcessor<DemoEntity> work = new BatchProcessor<>();work.start((t, ts) -> {if (true) {// 会抛出 ArithmeticException: / by zero 异常int a = 1 / 0;}System.out.println("插入数据库:" + t);},2, 0,// 异常处理器,如果为 null, BatchProcessor 工具会捕获,并使用 Slf4j error 级别打印日志// 如果框架不做这个处理,jdk 会使用 System.err.out 打印到控制台,所以在线上生产环境,就不会记录到日志文件中// 当出现问题的时候,就很难发现出现了什么问题null);// 模拟生产数据try {for (int i = 0; i < 5; i++) {work.put(new DemoEntity(i, i + " name"));}// 等待入库完成work.await();System.out.println("处理完成");} catch (Exception e) {// 如果生产过程中有异常,立即停止掉处理器,不再入库System.err.println("异常处理完成");work.stop();}}/*** 异常测试 - 自定义异常处理器*/@Testpublic void exceptionHandlerTest() {// 看看在消费逻辑中发现业务异常,会发生什么事情final BatchProcessor<DemoEntity> work = new BatchProcessor<>();work.start((t, ts) -> {if (true) {// 会抛出 ArithmeticException: / by zero 异常int a = 1 / 0;}System.out.println("插入数据库:" + t);},2, 0,// 自定义异常处理器new Thread.UncaughtExceptionHandler() {@Overridepublic void uncaughtException(Thread t, Throwable e) {System.out.println("工作线程异常退出");e.printStackTrace();}});// 模拟生产数据try {for (int i = 0; i < 5; i++) {work.put(new DemoEntity(i, i + " name"));}// 等待入库完成work.await();System.out.println("处理完成");} catch (Exception e) {// 如果生产过程中有异常,立即停止掉处理器,不再入库System.err.println("异常处理完成");work.stop();}}/*** 测试 实体*/private class DemoEntity {private int id;private String name;public DemoEntity(int id, String name) {this.id = id;this.name = name;}public int getId() {return id;}public void setId(int id) {this.id = id;}public String getName() {return name;}public void setName(String name) {this.name = name;}@Overridepublic String toString() {return "DemoEntity{" +"id=" + id +", name='" + name + '\'' +'}';}}}
测试输出
// 批量插入输出插入数据库条数:4插入数据库条数:4插入数据库条数:4插入数据库条数:4插入数据库条数:3插入数据库条数:2// 单条插入输出插入数据库:DemoEntity{id=0, name='0 name'}插入数据库:DemoEntity{id=1, name='1 name'}插入数据库:DemoEntity{id=3, name='3 name'}插入数据库:DemoEntity{id=4, name='4 name'}插入数据库:DemoEntity{id=2, name='2 name'}
异常处理器相关测试
// 异常测试 - 不自定义异常处理器处理完成10:52:47.847 [reportInsert-0] ERROR cn.mrcode.BatchProcessor - 工作线程异常退出,threadName=reportInsert-0java.lang.ArithmeticException: / by zeroat cn.mrcode.BatchProcessorTest.lambda$exceptionTest$2(BatchProcessorTest.java:67)at cn.mrcode.BatchProcessor$WorkThread.doRun(BatchProcessor.java:270)at cn.mrcode.BatchProcessor$WorkThread.run(BatchProcessor.java:244)10:52:47.847 [reportInsert-1] ERROR cn.mrcode.BatchProcessor - 工作线程异常退出,threadName=reportInsert-1java.lang.ArithmeticException: / by zeroat cn.mrcode.BatchProcessorTest.lambda$exceptionTest$2(BatchProcessorTest.java:67)at cn.mrcode.BatchProcessor$WorkThread.doRun(BatchProcessor.java:270)at cn.mrcode.BatchProcessor$WorkThread.run(BatchProcessor.java:244)// 异常测试 - 自定义异常处理器工作线程异常退出工作线程异常退出处理完成java.lang.ArithmeticException: / by zeroat cn.mrcode.BatchProcessorTest.lambda$exceptionHandlerTest$3(BatchProcessorTest.java:102)at cn.mrcode.BatchProcessor$WorkThread.doRun(BatchProcessor.java:270)at cn.mrcode.BatchProcessor$WorkThread.run(BatchProcessor.java:244)java.lang.ArithmeticException: / by zeroat cn.mrcode.BatchProcessorTest.lambda$exceptionHandlerTest$3(BatchProcessorTest.java:102)at cn.mrcode.BatchProcessor$WorkThread.doRun(BatchProcessor.java:270)at cn.mrcode.BatchProcessor$WorkThread.run(BatchProcessor.java:244)
注意事项
作为批量入库的注意事项
/*** 多线程批量插入测试*/@Testpublic void batchInsert() {final BatchInsertProcessor<DemoEntity> work = new BatchInsertProcessor<>();work.start((t, ts) -> {System.out.println("插入数据库条数:" + ts.size());在这里调用数据库批量插入,记住是 批量插入,将 ts 整个一次性插入到数据库},4, 4);// 模拟生产数据try {for (int i = 0; i < 21; i++) {work.put(new DemoEntity(i, i + " name"));}// 等待入库完成work.await();} catch (Exception e) {// 如果生产过程中有异常,立即停止掉处理器,不再入库work.stop();}}
在 mybatis 中批量插入如下所示
<insert id="batchInsertProductCategory" parameterType="java.util.List">INSERT INTOtb_product_category(product_category_name,priority,shop_id)VALUES<foreach collection="list" item="productCategory" index="index" separator=",">(#{productCategory.productCategoryName},#{productCategory.priority},#{productCategory.shopId})</foreach></insert>
