把一个大任务拆成多个子任务进行并行计算再把拆分的子任务的计算结果进行合并

基本方法

1.fork()创建异步执行的子任务
2.等待任务完成后返回计算结果
3.开始执行任务 必要时等待其执行结果

RecursiveAction

无返回结果的任务

RecursiveTask

有返回结果的类

invoke(ForkJoinTask)
提交任务并一直阻塞 直到任务执行完成返回合并结果
executr()
异步执行没有返回值
submit(ForkJoinTask)
异步执行任务 返回Task本身 通过task.get()方法获取合并后的结果

应用

  1. public static void main(String[] args) {
  2. ForkJoinPool forkJoinPool = new ForkJoinPool();
  3. ForkJoinTask<Integer> task = forkJoinPool.submit(new CalculationTask(1, 2002));
  4. try {
  5. Integer integer = task.get();
  6. System.out.println("执行结果"+integer);
  7. }catch (Exception e) {
  8. e.printStackTrace();
  9. }
  10. }
  11. private static final Integer MAX = 400;
  12. static class CalculationTask extends RecursiveTask<Integer> {
  13. private Integer startValue; //子任务开始计算的值
  14. private Integer endValue; //子任务结束计算的值
  15. public CalculationTask(Integer startValue, Integer endValue) {
  16. this.startValue = startValue;
  17. this.endValue = endValue;
  18. }
  19. @Override
  20. protected Integer compute() {
  21. if (endValue - startValue < MAX) {
  22. System.out.println("开始计算"+"startValue="+startValue+"endValue="+endValue);
  23. Integer totalValue = 0 ;
  24. for (int index = this.startValue; index <=this.endValue ; index++) {
  25. totalValue +=index;
  26. }
  27. return totalValue;
  28. }
  29. return createSubtasks();
  30. }
  31. private Integer createSubtasks() {
  32. CalculationTask task = new CalculationTask(startValue,(startValue+endValue)/2);
  33. task.fork();
  34. CalculationTask task1 = new CalculationTask((startValue+endValue)/2+1,endValue);
  35. task1.fork();
  36. return task.join()+task1.join();
  37. }
  38. }

实现原理

1.使用forkJoinPool.submit提交任务 如果是第一次提交 需要初始化 forkJoinPool中的workQueues数组

workQueue(工作队列包含属性)
ForkJoinTasks[]
用来存放通过 submit /execute方法提交的 ForkJoinTask
ForkJoinTaskWorkerThreadowner是 ForkJoinPoll中的工作线程 该线程用于执行具体的ForkJoinTask

ForkJoinPoll 指向当前ForkJoinPoll实例的引用 该引用是为了当ForkJoinTask数组中的任务处理完成之后 再次获取任务并交给ForkJoinTaskWorkerThread处理
2.通过r&m&SQMASK进行取模计算 计算WorkQueues数组的下标 把当前ForkJoinTask添加到指定位置
m表示 workQueues数组长度
r是通过 ThreadLocalRandom.getProbe得到的随机数
SQMASK = 0*007e表示任何整数和 SQMASK进行与运算后得到的一定是偶数 也就是第一次提交的任务会放到 workQueues的偶数位

3.任务提交后 需要安排线程来执行 如果工作线程数不够且没有正在等待的线程则创建一个新的ForkJoinWorkerThread

4.初始化时候 ForkJoinWorkerThread线程会调用registerWorker方法绑定一个工作线程 也就是把ForkJoinPoll中的WorkQueues数组的奇位数分配给当前线程

5.启动创建好的线程 从当前线程绑定的工作队列中获取任务来执行 由于第一次进来时存储数据在数组的偶数位 而当前线程绑定的是奇数位 所以当前线程工作队列中没有任务 所以会从其他线程窃取
6.当前线程执行完成后发现没有任务需要执行则等待

工作窃取

由于每个工作线程都从自己的工作队列中来获得任务执行 如果某个工作线程执行完自己工作队列中的任务 就会进入阻塞状态 有可能其他的工作线程还有任务没有执行完
工作窃取就是当自己的工作线程任务执行完毕 从其他工作队列中窃取任务来执行为了避免任务获取存在的竞争线程进行工作窃取时是从队列的尾部来获取任务的

  1. public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
  2. if (task == null)
  3. throw new NullPointerException();
  4. externalPush(task);
  5. return task;
  6. }
  1. final void externalPush(ForkJoinTask<?> task) {
  2. WorkQueue[] ws; WorkQueue q; int m;
  3. int r = ThreadLocalRandom.getProbe();
  4. int rs = runState;
  5. if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
  6. (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
  7. U.compareAndSwapInt(q, QLOCK, 0, 1)) {
  8. ForkJoinTask<?>[] a; int am, n, s;
  9. if ((a = q.array) != null &&
  10. (am = a.length - 1) > (n = (s = q.top) - q.base)) {
  11. int j = ((am & s) << ASHIFT) + ABASE;
  12. U.putOrderedObject(a, j, task);//将任务添加到当前WorkQueue的ForkJoinTask数组中
  13. U.putOrderedInt(q, QTOP, s + 1);//释放QTOP索引
  14. U.putIntVolatile(q, QLOCK, 0);//释放锁
  15. if (n <= 1) //当前队列的任务处理完毕 工作线程属于阻塞状态
  16. signalWork(ws, q); //唤醒或者创建线程
  17. return;
  18. }
  19. U.compareAndSwapInt(q, QLOCK, 1, 0);
  20. }
  21. //如果存在线程竞争或者WorkQueues数组没有初始化
  22. externalSubmit(task);
  23. }
  1. private void externalSubmit(ForkJoinTask<?> task) {
  2. int r; // initialize caller's probe
  3. //得到一个探针hash值
  4. if ((r = ThreadLocalRandom.getProbe()) == 0) {
  5. ThreadLocalRandom.localInit();
  6. r = ThreadLocalRandom.getProbe();
  7. }
  8. for (;;) {
  9. WorkQueue[] ws; WorkQueue q; int rs, m, k;
  10. boolean move = false;
  11. //当前线程池的状态为TERMINATE拒绝添加任务
  12. if ((rs = runState) < 0) {
  13. tryTerminate(false, false); // help terminate
  14. throw new RejectedExecutionException();
  15. }
  16. //队列为空 进行初始化
  17. else if ((rs & STARTED) == 0 || // initialize
  18. ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
  19. int ns = 0;
  20. //获得锁
  21. rs = lockRunState();
  22. try {
  23. if ((rs & STARTED) == 0) {
  24. U.compareAndSwapObject(this, STEALCOUNTER, null,
  25. new AtomicLong());
  26. // create workQueues array with size a power of two
  27. int p = config & SMASK; // ensure at least 2 slots
  28. //保证数组长度为2的N次幂
  29. int n = (p > 1) ? p - 1 : 1;
  30. n |= n >>> 1; n |= n >>> 2; n |= n >>> 4;
  31. n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
  32. workQueues = new WorkQueue[n];
  33. ns = STARTED;
  34. }
  35. } finally {
  36. unlockRunState(rs, (rs & ~RSLOCK) | ns);
  37. }
  38. }
  39. //随机从workQueues数组中找到一个偶数位下标对应的 workQueue 把任务添加到队列中
  40. else if ((q = ws[k = r & m & SQMASK]) != null) {
  41. if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
  42. ForkJoinTask<?>[] a = q.array;
  43. int s = q.top;
  44. boolean submitted = false; // initial submission or resizing
  45. try { // locked version of push
  46. if ((a != null && a.length > s + 1 - q.base) ||
  47. (a = q.growArray()) != null) {
  48. //计算存储偏移量
  49. int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
  50. //把任务存储到数组的指定位置
  51. U.putOrderedObject(a, j, task);
  52. //修改索引
  53. U.putOrderedInt(q, QTOP, s + 1);
  54. submitted = true;
  55. }
  56. } finally {
  57. U.compareAndSwapInt(q, QLOCK, 1, 0);
  58. }
  59. //任务提交成功 唤醒或者创建工作线程来执行
  60. if (submitted) {
  61. signalWork(ws, q);
  62. return;
  63. }
  64. }
  65. move = true; // move on failure
  66. }
  67. //如果指定偶数位下标的还未初始化 则构建一个新的workQueue保存到数组中该下标位置
  68. else if (((rs = runState) & RSLOCK) == 0) { // create new queue
  69. q = new WorkQueue(this, null);
  70. q.hint = r;
  71. q.config = k | SHARED_QUEUE;
  72. q.scanState = INACTIVE;
  73. rs = lockRunState(); // publish index
  74. if (rs > 0 && (ws = workQueues) != null &&
  75. k < ws.length && ws[k] == null)
  76. ws[k] = q; // else terminated
  77. unlockRunState(rs, rs & ~RSLOCK);
  78. }
  79. //不满足上面的条件 重新更新hash探针 继续寻找数组的下一个元素
  80. else
  81. move = true; // move if busy
  82. if (move)
  83. r = ThreadLocalRandom.advanceProbe(r);
  84. }
  85. }

唤醒或创建工作线程

  1. final void signalWork(WorkQueue[] ws, WorkQueue q) {
  2. long c; int sp, i; WorkQueue v; Thread p;
  3. while ((c = ctl) < 0L) {
  4. // too few active
  5. //没有空闲的工作线程
  6. if ((sp = (int)c) == 0) {
  7. // no idle workers
  8. //工作线程还没有到达阈值
  9. if ((c & ADD_WORKER) != 0L) // too few workers
  10. tryAddWorker(c); //创建工作线程
  11. break;
  12. }
  13. //队列为空 可能线程已经终止或未初始化
  14. if (ws == null) // unstarted/terminated
  15. break;
  16. if (ws.length <= (i = sp & SMASK)) // terminated
  17. break;
  18. if ((v = ws[i]) == null) // terminating
  19. break;
  20. int vs = (sp + SS_SEQ) & ~INACTIVE; // next scanState
  21. int d = sp - v.scanState; // screen CAS
  22. //设置活跃工作线程数 总工作线程数
  23. long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
  24. if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
  25. v.scanState = vs; // activate v
  26. if ((p = v.parker) != null)
  27. U.unpark(p); //唤醒工作线程
  28. break;
  29. }
  30. if (q != null && q.base == q.top) // no more work
  31. break;
  32. }
  33. }