@[toc]


1. 引入

假设,现在需要求解集合给定区间[start , end]内连续元素的和,你会怎么做呢?首先,一种最简单的方法就是在main方法中直接计算,或者新开一个线程来计算,如下所示:

  1. /**
  2. * @Author dyliang
  3. * @Date 2020/9/9 14:20
  4. * @Version 1.0
  5. */
  6. public class Demo {
  7. public static void main(String[] args) {
  8. int result = compute(1, 10000);
  9. System.out.println(result); // 5050
  10. }
  11. public static int compute(int start, int end){
  12. int result = 0;
  13. for(int i = start; i <= end; i ++){
  14. result += i;
  15. }
  16. return result;
  17. }
  18. }

当需要计算的数据比较少时,程序运行耗时较短,但是当数据量很大时,运行可能会很慢。那有没有办法提升程序的效率呢?分析需求可以知道,区间可以继续分段,直到数据可以两两相加,最后将片段计算的结果逐层相加,采用分治的思想来进行解决。在多线程环境下,我们可以使用Java中提供的Fork/Join框架来完成上述的任务。


2. Fork/Join

2.1 概念

Fork/Join中JDK 1.7 之后引入的一个用于并行执行任务的框架,它借助分治的思想,将一个大任务分割为若干个小任务,直到不能拆分可以直接求解,最后汇总每个小任务的结果后得到大任务的结果。例如,使用Fork/Join框架计算1 + 2 + ... + 6的示意图如下所示:
你会使用Fork_Join框架更好的解决并发问题嘛? - 图1

Fork/Join框架会将每个任务的分解和合并交给不同的线程来完成,进一步的提升运算的效率,Fork/Join框架默认会创建与CPU核心数相同数量的线程。

2.2 执行过程

使用Fork/Join求解大任务的流程可以分为两步:

  • 分割任务:即大任务拆分为小任务的过程,拆分进行知道分割得到的子任务足够的小。Fork/Join对应的类为ForkJoinTask,它提供了任务中fork和join的执行机制,常用的两个子类有:

    • RecursiveAction:没有返回结果
    • RecusiveTask:有返回结果

compute()

  • 执行任务并合并结果:上一步分割得到的众多子任务会分别放在双端队列中,然后启动几个线程分别从队列中获取任务并执行。而且,子任务执行的结果也会放在一个队列中,启动一个线程从队列中拿数据,最后合并这些数据得到最终的结果。Fork/Join对应的类为ForkJoinPool,当一个工作线程的队列中暂时没有任务时,它会随机的从其他工作线程的队列的尾部获取一个任务帮助执行,这也称为工作窃取

    工作窃取指线程从其他队列中窃取任务执行的过程。对大任务进行拆分时,通常将拆分后的小任务分别放入不同的队列(通常使用双端队列,窃取线程只能从尾部取任务,本身的线程只能从头部取任务)中,然后每个队列都有各自的线程来执行任务。但是,有时某个线程提前执行完了自己的任务,它就会到其它还包含任务的队列中继续取任务来执行,直到所有队列中的任务都执行完毕。

通过工作窃取可以进一步的有效利用多线程的优势,队列的存在也可以避免线程间的竞争。但并不能完全避免竞争,例如当队列中只有一个任务时,窃取的线程和队列本身的线程都想要获取执行权,不免就会发生竞争。另外,队列和线程的创建与撤销也会消耗一定的系统资源。

2.3 使用

下面我们使用Fork/Join框架来解决第一部分的求和问题,代码如下:

  1. /**
  2. * @Author dyliang
  3. * @Date 2020/9/9 15:01
  4. * @Version 1.0
  5. */
  6. class ForkJoinDemo{
  7. public static void main(String[] args) {
  8. Task task = new Task(1, 100);
  9. // 创建ForkJoinPool,任务交给它执行
  10. ForkJoinPool pool = new ForkJoinPool();
  11. Integer r = pool.invoke(task);
  12. System.out.println(r);
  13. }
  14. }
  15. // 继承需要返回结果的RecursiveTask,并重写compute方法
  16. class Task extends RecursiveTask<Integer> {
  17. private int start;
  18. private int end;
  19. public Task(int start, int end) {
  20. this.start = start;
  21. this.end = end;
  22. }
  23. @Override
  24. protected Integer compute() {
  25. int result = 0;
  26. // 拆分的终止条件
  27. if(end - start <= 2){
  28. for (int i = start; i <= end ; i++) {
  29. result += i;
  30. }
  31. } else {
  32. // 任务拆分
  33. int mid = (end + start) / 2;
  34. Task left = new Task(start, mid);
  35. Task right = new Task(mid + 1, end);
  36. invokeAll(left, right);
  37. // 错误写法
  38. // t.fork();
  39. // ht.fork();
  40. // 结果合并
  41. Integer lr = left.join();
  42. Integer rr = right.join();
  43. result = lr + rr;
  44. }
  45. return result;
  46. }
  47. }

3. 源码分析

你会使用Fork_Join框架更好的解决并发问题嘛? - 图2

从上面示例中代码的执行过程来看,Rork/Join和线程池的执行有些类似,下面看一下fork()join()的实现,从而验证一下我们的直觉。

ForkJoinTask是一个抽象类,它的定义如下:

  1. public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
  2. volatile int status; // accessed directly by pool and workers
  3. static final int DONE_MASK = 0xf0000000; // mask out non-completion bits
  4. static final int NORMAL = 0xf0000000; // must be negative
  5. static final int CANCELLED = 0xc0000000; // must be < NORMAL
  6. static final int EXCEPTIONAL = 0x80000000; // must be < CANCELLED
  7. static final int SIGNAL = 0x00010000; // must be >= 1 << 16
  8. static final int SMASK = 0x0000ffff; // short bits for tags
  9. // 其他代码
  10. }

其中fork()的源码实现如下,它用于将当前的任务推到当前工作线程的工作队列中

  1. public final ForkJoinTask<V> fork() {
  2. // 创建一个线程对象
  3. Thread t;
  4. // 判断,如果当前执行fork操作的线程是ForkJoinWorkerThread中创建的线程
  5. // 则将fork任务压入到ForkJoinWorkerThread的工作队列中
  6. if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
  7. ((ForkJoinWorkerThread)t).workQueue.push(this);
  8. else
  9. ForkJoinPool.common.externalPush(this);
  10. return this;
  11. }

其中ForkJoinWorkerThreadForkJoinPool中重要的元素之一,它负责执行ForkJoinPoolForkJoinTask数组中存放的任务,它本身是Thread类的一个子类实现。workQueueForkJoinPool定义的一个静态内部类,用于任务接收和释放的队列使用,即前面所说的双端队列。其中push()的源码如下所示:

  1. final void push(ForkJoinTask<?> task) {
  2. ForkJoinTask<?>[] a;
  3. ForkJoinPool p;
  4. int b = base, s = top, n;
  5. // 将任务放入到ForkJoinTask数组中
  6. if ((a = array) != null) { // ignore if queue removed
  7. int m = a.length - 1; // fenced write for task visibility
  8. U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
  9. U.putOrderedInt(this, QTOP, s + 1);
  10. if ((n = s - b) <= 1) {
  11. if ((p = pool) != null)
  12. // 调用ForkJoinPool的signalWork方法唤醒或创建一个工作线程执行任务
  13. p.signalWork(p.workQueues, this);
  14. }
  15. else if (n >= m)
  16. growArray();
  17. }
  18. }

join()的源码实现如下,它用于阻塞当前线程并等待获取结果,类似于Thread类中的join()

  1. public final V join() {
  2. int s;
  3. if ((s = doJoin() & DONE_MASK) != NORMAL)
  4. reportException(s);
  5. return getRawResult();
  6. }

首先调用doJoin()获取当前任务的执行状态,状态定义在类的属性字段,包含4中状态:

  • NORMAL:任务已完成,返回结果
  • CANCELLED:任务被撤销,抛出CancellationException异常
  • SIGNAL:信号
  • EXCEPTIONAL:执行任务时出现异常,直接抛异常

方法的源码为:

  1. private int doJoin() {
  2. int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
  3. // 检查任务的状态,如果执行完毕直接返回状态
  4. return (s = status) < 0 ? s :
  5. // 如果没有执行完毕
  6. ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
  7. // 从队列中取出任务并执行
  8. (w = (wt = (ForkJoinWorkerThread)t).workQueue).
  9. tryUnpush(this) && (s = doExec()) < 0 ? s :
  10. wt.pool.awaitJoin(w, this, 0L) :
  11. externalAwaitDone();
  12. }

任务执行如果顺利完成,则设置任务状态为NORMAL;如果出现异常,则记录异常,并将任务状态设置为EXCEPTIONAL。

基本的执行流程如下所示:

  • 检查调用方法的线程是否是ForkJoinThread线程,如果不是则阻塞当前线程,等待任务完成;如果是,则不阻塞进行下一步
  • 查看当前任务的任务状态,如果是NORMAL,则直接返回执行结果
  • 如果任务还没有完成,但处于自己的WorkQueue中,则执行任务
  • 如果任务已经被其他的工作线程偷走,则窃取这个小偷的工作队列内的任务(以 FIFO 方式)执行,以期帮助它早日完成欲 join 的任务
  • 如果偷走任务的小偷也已经把自己的任务全部做完,正在等待需要 join 的任务时,则找到小偷的小偷,帮助它完成它的任务
  • 递归地执行后两步

4. 参考

Java 并发编程笔记:如何使用 ForkJoinPool 以及原理

Java的Fork/Join任务,你写对了吗?