场景

在读取文件、或则需要使用多线程批量入库的时候,往往是需要我们自己来写多线程的调度完成多线程批量入库的功能;

难点:多线程的调度、数据分批的逻辑

:::info 不仅仅用于数据库插入,只要在 固定数量 多线程处理 的场景都适用 :::

解决的问题

  1. 多线程调度、数据分批的逻辑
  2. 提供多线程批量插入/处理
  3. 提供多线程单条插入/处理

实现思路

  1. 使用 ArrayBlockingQueue 来调节生产方和消费方速度不一致的情况:使用阻塞的 put 来达到让生产方阻塞等待
  2. 使用 CountDownLatch 来实现,等待多线程将所有生产的数据都入库/处理完成

工具类实现

依赖

  1. // lombok,使用了里面的 @Slf4j 日志工具
  2. compileOnly 'org.projectlombok:lombok:1.18.18'
  3. testCompileOnly 'org.projectlombok:lombok:1.18.18'
  4. annotationProcessor 'org.projectlombok:lombok:1.18.18'
  5. // hutool 工具类
  6. // 比如 ExceptionUtil 来自于 https://www.hutool.cn/ 工具包中,完全可以换成手动 new 异常
  7. implementation 'cn.hutool:hutool-all:5.8.3'
  1. package cn.mrcode;
  2. import cn.hutool.core.exceptions.ExceptionUtil;
  3. import cn.hutool.core.util.StrUtil;
  4. import lombok.extern.slf4j.Slf4j;
  5. import java.sql.Struct;
  6. import java.util.ArrayList;
  7. import java.util.List;
  8. import java.util.concurrent.ArrayBlockingQueue;
  9. import java.util.concurrent.CountDownLatch;
  10. import java.util.concurrent.TimeUnit;
  11. import java.util.stream.Collectors;
  12. import java.util.stream.IntStream;
  13. /**
  14. * 多线程批处理器
  15. * @author mrcode
  16. * @date 2021/6/2 17:52
  17. */
  18. @Slf4j
  19. public class BatchProcessor<T> {
  20. /**
  21. * 线程名称前缀,可自定义
  22. */
  23. private String threadNamePrefix = "reportInsert-";
  24. // 是否已经开始处理
  25. private boolean started;
  26. // 用于等待线程处理结束后的收尾处理
  27. private CountDownLatch cdl;
  28. // 是否还会产生数据: 用于配合 queue.size() 判断线程是否该结束
  29. private volatile boolean isProduceData = true;
  30. // 实体数据容器队列,队列满,则限制生产方的生产速度
  31. private ArrayBlockingQueue<T> queue;
  32. // 消费到一条实体数据,就调用该方法给使用方,使用方可以调用存储接口存储
  33. private StorageConsumer<T> consumer;
  34. // 批量插入时,每次最多插入多少条
  35. private int maxItemCount;
  36. private List<WorkThread> workThreads;
  37. public BatchProcessor() {
  38. this(1000);
  39. }
  40. /**
  41. * <pre>
  42. * capacity :利用队列的阻塞 put,来调节生产速度和消费速度的差别
  43. * 当生产速度明显大于插入速度时,该参数用来限制生产的速度,达到该上限时,生成方就会阻塞,知道有新的容量空闲出来
  44. * </pre>
  45. *
  46. * @param capacity 队列能接收的最大容量
  47. */
  48. public BatchProcessor(int capacity) {
  49. queue = new ArrayBlockingQueue<>(capacity);
  50. }
  51. /**
  52. * 配置线程名称前缀
  53. *
  54. * @param threadNamePrefix
  55. */
  56. public synchronized void setThreadNamePrefix(String threadNamePrefix) {
  57. if (started) {
  58. throw new RuntimeException("已经开始处理,不能再线程名称前缀");
  59. }
  60. this.threadNamePrefix = threadNamePrefix;
  61. }
  62. /**
  63. * 默认 4 个线程,每个线程每次处理一条数据
  64. *
  65. * @param consumer 每次达到消费条数时,消费方的消费回调逻辑
  66. */
  67. public void start(StorageConsumer<T> consumer) {
  68. this.start(consumer, 4);
  69. }
  70. /**
  71. * 默认每个线程每次处理 1 条数据
  72. *
  73. * @param consumer 每次达到消费条数时,消费方的消费回调逻辑
  74. * @param workThreadCount 需要并行处理的线程数量,必须大于 0
  75. */
  76. public void start(StorageConsumer<T> consumer, int workThreadCount) {
  77. this.start(consumer, workThreadCount, 0);
  78. }
  79. /**
  80. * @param consumer 每次达到消费条数时,消费方的消费回调逻辑
  81. * @param workThreadCount 需要并行处理的线程数量,必须大于 0
  82. * @param maxItemCount 每次每个线程希望的消费数据条数, 0:每个线程每次消费 1 条数据,大于 0 则按照期望的条数进行消费
  83. */
  84. public synchronized void start(StorageConsumer<T> consumer,
  85. int workThreadCount,
  86. int maxItemCount) {
  87. this.start(consumer, workThreadCount, maxItemCount, null);
  88. }
  89. /**
  90. * @param consumer 每次达到消费条数时,消费方的消费回调逻辑
  91. * 由于是线程处理,所有在消费逻辑处理的时候,建议消费方一定要将逻辑都 try 一下,否则就会进入 uncaughtExceptionHandler 处理异常,并且该工作线程退出工作
  92. * @param workThreadCount 需要并行处理的线程数量,必须大于 0
  93. * @param maxItemCount 每次每个线程希望的消费数据条数, 0:每个线程每次消费 1 条数据,大于 0 则按照期望的条数进行消费
  94. * @param uncaughtExceptionHandler 当抛出异常的时候,该异常如何处理,可以为 null, 如果为 null, 将使用 @Slf4j 日志打印
  95. */
  96. public synchronized void start(StorageConsumer<T> consumer,
  97. int workThreadCount,
  98. int maxItemCount,
  99. Thread.UncaughtExceptionHandler uncaughtExceptionHandler) {
  100. if (started) {
  101. throw new RuntimeException("处理中");
  102. }
  103. if (workThreadCount <= 0) {
  104. throw new IllegalArgumentException("workThreadCount 必须大于 0");
  105. }
  106. if (maxItemCount < 0) {
  107. throw new IllegalArgumentException("maxItemCount 必须大于等于 0");
  108. }
  109. started = true;
  110. this.consumer = consumer;
  111. this.maxItemCount = maxItemCount;
  112. this.cdl = new CountDownLatch(workThreadCount);
  113. if (uncaughtExceptionHandler == null) {
  114. uncaughtExceptionHandler = (t, e) -> {
  115. log.error(StrUtil.format("工作线程异常退出,threadName={}", t.getName()), e);
  116. };
  117. }
  118. Thread.UncaughtExceptionHandler finalUncaughtExceptionHandler = uncaughtExceptionHandler;
  119. workThreads = IntStream.range(0, workThreadCount)
  120. .mapToObj(i -> {
  121. final WorkThread workThread = new WorkThread(threadNamePrefix + i, maxItemCount);
  122. workThread.start();
  123. // 如果不设置异常处理器,那么当 run 方法抛出异常的时候,会被 java.lang.ThreadGroup.uncaughtException 处理
  124. // 然后 ThreadGroup.uncaughtException 的默认处理是使用 System.error 打印错误,和调用 e.printStackTrace(System.err);
  125. // 这就会导致在生产环境中使用日志框架的时候,在日志框架里面看不到打印的错误信息,看起来就像异常被吞了
  126. workThread.setUncaughtExceptionHandler(finalUncaughtExceptionHandler);
  127. return workThread;
  128. })
  129. .collect(Collectors.toList());
  130. }
  131. /**
  132. * 将实体交给处理器,处理器的线程会消费该实体;
  133. * <pre>
  134. * 当容器队列已满时,则会阻塞,以此达到生产方暂停生产的目的;可以防止生产速度过快(消费速度过慢),导致占用过多内存
  135. * </pre>
  136. *
  137. * @param entity
  138. */
  139. public void put(T entity) {
  140. try {
  141. queue.put(entity);
  142. } catch (InterruptedException e) {
  143. ExceptionUtil.wrapAndThrow(e);
  144. }
  145. }
  146. /**
  147. * 等待,处理器处理完成;此方法会阻塞
  148. */
  149. public void await() {
  150. if (!started) {
  151. throw new RuntimeException("还未运行");
  152. }
  153. try {
  154. isProduceData = false;
  155. cdl.await();
  156. for (WorkThread workThread : workThreads) {
  157. workThread.clearEntity();
  158. }
  159. } catch (InterruptedException e) {
  160. ExceptionUtil.wrapAndThrow(e);
  161. }
  162. }
  163. /**
  164. * 立即停止,只适合在生产方不生产数据时,调用
  165. */
  166. public void stop() {
  167. if (!started) {
  168. throw new RuntimeException("还未运行");
  169. }
  170. isProduceData = false;
  171. queue.clear();
  172. }
  173. private class WorkThread extends Thread {
  174. // 批量插入时,用于缓存实体的容器
  175. private List<T> batchCacheContainer;
  176. private int maxItemCount;
  177. public WorkThread(String name, int maxItemCount) {
  178. super(name);
  179. this.maxItemCount = maxItemCount;
  180. if (maxItemCount > 0) {
  181. batchCacheContainer = new ArrayList<>(maxItemCount);
  182. }
  183. }
  184. @Override
  185. public void run() {
  186. try {
  187. doRun();
  188. } catch (InterruptedException e) {
  189. log.debug("工作线程收到中断异常退出", e);
  190. } finally {
  191. cdl.countDown();
  192. }
  193. }
  194. private void doRun() throws InterruptedException {
  195. while (true) {
  196. // 如果不产生数据了,队列也会空,则退出线程
  197. if (!isProduceData && queue.size() == 0) {
  198. break;
  199. }
  200. final T entity;
  201. entity = queue.poll(500, TimeUnit.MILLISECONDS);
  202. if (entity == null) {
  203. continue;
  204. }
  205. if (maxItemCount > 0) {
  206. batchCacheContainer.add(entity);
  207. if (batchCacheContainer.size() >= maxItemCount) {
  208. consumer.accept(null, batchCacheContainer);
  209. batchCacheContainer.clear();
  210. }
  211. } else {
  212. consumer.accept(entity, null);
  213. }
  214. }
  215. }
  216. public void clearEntity() {
  217. if (maxItemCount > 0 && batchCacheContainer.size() > 0) {
  218. consumer.accept(null, batchCacheContainer);
  219. batchCacheContainer.clear();
  220. }
  221. }
  222. }
  223. public interface StorageConsumer<T> {
  224. /**
  225. * 需要使用方存储数据时,会调用该方法
  226. *
  227. * @param t
  228. * @param ts
  229. */
  230. void accept(T t, List<T> ts);
  231. }
  232. }

用法测试

  1. package cn.mrcode;
  2. import org.junit.jupiter.api.Test;
  3. /**
  4. * @author mrcode
  5. * @date 2021/6/3 23:24
  6. */
  7. class BatchProcessorTest {
  8. /**
  9. * 批量插入测试
  10. */
  11. @Test
  12. public void batchInsert() {
  13. final BatchProcessor<DemoEntity> work = new BatchProcessor<>();
  14. work.start((t, ts) -> {
  15. System.out.println("插入数据库条数:" + ts.size());
  16. },
  17. 4, 4);
  18. // 模拟生产数据
  19. try {
  20. for (int i = 0; i < 21; i++) {
  21. work.put(new DemoEntity(i, i + " name"));
  22. }
  23. // 等待入库完成
  24. work.await();
  25. } catch (Exception e) {
  26. // 如果生产过程中有异常,立即停止掉处理器,不再入库
  27. work.stop();
  28. }
  29. }
  30. /**
  31. * 单条插入测试
  32. */
  33. @Test
  34. public void insert() {
  35. final BatchProcessor<DemoEntity> work = new BatchProcessor<>();
  36. work.start((t, ts) -> {
  37. System.out.println("插入数据库:" + t);
  38. },
  39. 4, 0);
  40. // 模拟生产数据
  41. try {
  42. for (int i = 0; i < 5; i++) {
  43. work.put(new DemoEntity(i, i + " name"));
  44. }
  45. // 等待入库完成
  46. work.await();
  47. } catch (Exception e) {
  48. // 如果生产过程中有异常,立即停止掉处理器,不再入库
  49. work.stop();
  50. }
  51. }
  52. /**
  53. * 异常测试 - 不自定义异常处理器
  54. */
  55. @Test
  56. public void exceptionTest() {
  57. // 看看在消费逻辑中发现业务异常,会发生什么事情
  58. final BatchProcessor<DemoEntity> work = new BatchProcessor<>();
  59. work.start((t, ts) -> {
  60. if (true) {
  61. // 会抛出 ArithmeticException: / by zero 异常
  62. int a = 1 / 0;
  63. }
  64. System.out.println("插入数据库:" + t);
  65. },
  66. 2, 0,
  67. // 异常处理器,如果为 null, BatchProcessor 工具会捕获,并使用 Slf4j error 级别打印日志
  68. // 如果框架不做这个处理,jdk 会使用 System.err.out 打印到控制台,所以在线上生产环境,就不会记录到日志文件中
  69. // 当出现问题的时候,就很难发现出现了什么问题
  70. null);
  71. // 模拟生产数据
  72. try {
  73. for (int i = 0; i < 5; i++) {
  74. work.put(new DemoEntity(i, i + " name"));
  75. }
  76. // 等待入库完成
  77. work.await();
  78. System.out.println("处理完成");
  79. } catch (Exception e) {
  80. // 如果生产过程中有异常,立即停止掉处理器,不再入库
  81. System.err.println("异常处理完成");
  82. work.stop();
  83. }
  84. }
  85. /**
  86. * 异常测试 - 自定义异常处理器
  87. */
  88. @Test
  89. public void exceptionHandlerTest() {
  90. // 看看在消费逻辑中发现业务异常,会发生什么事情
  91. final BatchProcessor<DemoEntity> work = new BatchProcessor<>();
  92. work.start((t, ts) -> {
  93. if (true) {
  94. // 会抛出 ArithmeticException: / by zero 异常
  95. int a = 1 / 0;
  96. }
  97. System.out.println("插入数据库:" + t);
  98. },
  99. 2, 0,
  100. // 自定义异常处理器
  101. new Thread.UncaughtExceptionHandler() {
  102. @Override
  103. public void uncaughtException(Thread t, Throwable e) {
  104. System.out.println("工作线程异常退出");
  105. e.printStackTrace();
  106. }
  107. });
  108. // 模拟生产数据
  109. try {
  110. for (int i = 0; i < 5; i++) {
  111. work.put(new DemoEntity(i, i + " name"));
  112. }
  113. // 等待入库完成
  114. work.await();
  115. System.out.println("处理完成");
  116. } catch (Exception e) {
  117. // 如果生产过程中有异常,立即停止掉处理器,不再入库
  118. System.err.println("异常处理完成");
  119. work.stop();
  120. }
  121. }
  122. /**
  123. * 测试 实体
  124. */
  125. private class DemoEntity {
  126. private int id;
  127. private String name;
  128. public DemoEntity(int id, String name) {
  129. this.id = id;
  130. this.name = name;
  131. }
  132. public int getId() {
  133. return id;
  134. }
  135. public void setId(int id) {
  136. this.id = id;
  137. }
  138. public String getName() {
  139. return name;
  140. }
  141. public void setName(String name) {
  142. this.name = name;
  143. }
  144. @Override
  145. public String toString() {
  146. return "DemoEntity{" +
  147. "id=" + id +
  148. ", name='" + name + '\'' +
  149. '}';
  150. }
  151. }
  152. }

测试输出

  1. // 批量插入输出
  2. 插入数据库条数:4
  3. 插入数据库条数:4
  4. 插入数据库条数:4
  5. 插入数据库条数:4
  6. 插入数据库条数:3
  7. 插入数据库条数:2
  8. // 单条插入输出
  9. 插入数据库:DemoEntity{id=0, name='0 name'}
  10. 插入数据库:DemoEntity{id=1, name='1 name'}
  11. 插入数据库:DemoEntity{id=3, name='3 name'}
  12. 插入数据库:DemoEntity{id=4, name='4 name'}
  13. 插入数据库:DemoEntity{id=2, name='2 name'}

异常处理器相关测试

  1. // 异常测试 - 不自定义异常处理器
  2. 处理完成
  3. 10:52:47.847 [reportInsert-0] ERROR cn.mrcode.BatchProcessor - 工作线程异常退出,threadName=reportInsert-0
  4. java.lang.ArithmeticException: / by zero
  5. at cn.mrcode.BatchProcessorTest.lambda$exceptionTest$2(BatchProcessorTest.java:67)
  6. at cn.mrcode.BatchProcessor$WorkThread.doRun(BatchProcessor.java:270)
  7. at cn.mrcode.BatchProcessor$WorkThread.run(BatchProcessor.java:244)
  8. 10:52:47.847 [reportInsert-1] ERROR cn.mrcode.BatchProcessor - 工作线程异常退出,threadName=reportInsert-1
  9. java.lang.ArithmeticException: / by zero
  10. at cn.mrcode.BatchProcessorTest.lambda$exceptionTest$2(BatchProcessorTest.java:67)
  11. at cn.mrcode.BatchProcessor$WorkThread.doRun(BatchProcessor.java:270)
  12. at cn.mrcode.BatchProcessor$WorkThread.run(BatchProcessor.java:244)
  13. // 异常测试 - 自定义异常处理器
  14. 工作线程异常退出
  15. 工作线程异常退出
  16. 处理完成
  17. java.lang.ArithmeticException: / by zero
  18. at cn.mrcode.BatchProcessorTest.lambda$exceptionHandlerTest$3(BatchProcessorTest.java:102)
  19. at cn.mrcode.BatchProcessor$WorkThread.doRun(BatchProcessor.java:270)
  20. at cn.mrcode.BatchProcessor$WorkThread.run(BatchProcessor.java:244)
  21. java.lang.ArithmeticException: / by zero
  22. at cn.mrcode.BatchProcessorTest.lambda$exceptionHandlerTest$3(BatchProcessorTest.java:102)
  23. at cn.mrcode.BatchProcessor$WorkThread.doRun(BatchProcessor.java:270)
  24. at cn.mrcode.BatchProcessor$WorkThread.run(BatchProcessor.java:244)

注意事项

作为批量入库的注意事项

  1. /**
  2. * 多线程批量插入测试
  3. */
  4. @Test
  5. public void batchInsert() {
  6. final BatchInsertProcessor<DemoEntity> work = new BatchInsertProcessor<>();
  7. work.start((t, ts) -> {
  8. System.out.println("插入数据库条数:" + ts.size());
  9. 在这里调用数据库批量插入,记住是 批量插入,将 ts 整个一次性插入到数据库
  10. },
  11. 4, 4);
  12. // 模拟生产数据
  13. try {
  14. for (int i = 0; i < 21; i++) {
  15. work.put(new DemoEntity(i, i + " name"));
  16. }
  17. // 等待入库完成
  18. work.await();
  19. } catch (Exception e) {
  20. // 如果生产过程中有异常,立即停止掉处理器,不再入库
  21. work.stop();
  22. }
  23. }

在 mybatis 中批量插入如下所示

  1. <insert id="batchInsertProductCategory" parameterType="java.util.List">
  2. INSERT INTO
  3. tb_product_category(product_category_name,priority,shop_id)
  4. VALUES
  5. <foreach collection="list" item="productCategory" index="index" separator=",">
  6. (
  7. #{productCategory.productCategoryName},
  8. #{productCategory.priority},
  9. #{productCategory.shopId}
  10. )
  11. </foreach>
  12. </insert>