并发学习(六) Fork/Join框架

Fork/Join框架

Fork/Join框架是Java7提供了的一个用于并行执行任务的框架, 是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。

工作窃取算法

工作窃取算法(work-stealing)是指当一个队列所对应的线程先执行完队列中的所有任务后,从其他线程的队列里窃取一个任务来执行。为了减少竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。

Fork/Join框架

Fork/Join框架的工作分为两步:

第一步分割任务。首先我们需要有一个fork类来把大任务分割成子任务,有可能子任务还是很大,所以还需要不停的分割,直到分割出的子任务足够小。

第二步执行任务并合并结果。分割的子任务分别放在双端队列里,然后几个启动线程分别从双端队列里获取任务执行。子任务执行完的结果都统一放在一个队列里,启动一个线程从队列里拿数据,然后合并这些数据。

Fork/Join使用两个类来完成以上两件事情:

ForkJoinTask:我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join()操作的机制,通常情况下我们不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork/Join框架提供了以下两个子类:

  • RecursiveAction:用于没有返回结果的任务。
  • RecursiveTask :用于有返回结果的任务。

ForkJoinPool :ForkJoinTask需要通过ForkJoinPool来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。

使用Fork/Join框架

通过Fork/Join框架计算1 + 2 + … + 100000 的值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
package forkjoin;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;

public class ForkJoinTest {

static class SumTask extends RecursiveTask<Long>{
private int threshold; //分配阈值
private int start, end;
private int[] array;

public SumTask(int threshold, int start, int end, int[] array) {
this.threshold = threshold;
this.start = start;
this.end = end;
this.array = array;
}

@Override
protected Long compute() {
boolean canComplate = (end-start) <= threshold;
long sum = 0L;
if (canComplate){
for (int i = start; i <= end; i++) {
sum += array[i];
}
}
else {
//如果任务大于阈值,就分裂成两个子任务计算
int middle = start + (end - start) / 2;
SumTask leftTask = new SumTask(threshold, start, middle, array);
SumTask rightTask = new SumTask(threshold, middle+1, end, array);
//执行子任务
leftTask.fork();
rightTask.fork();
//等待子任务执行完,并得到结果
long leftResult = leftTask.join();
long rightResult = rightTask.join();
sum = leftResult + rightResult;
}

return sum;
}
}

public static void main(String[] args) {
final int COUNT = 100000;

int[] array = new int[COUNT];
for (int i = 0; i < array.length; i++) {
array[i] = i + 1;
}
ForkJoinPool forkJoinPool = new ForkJoinPool();
long begintime = System.currentTimeMillis();
SumTask sumTask = new SumTask(1000, 0, array.length-1, array);
Future<Long> future = forkJoinPool.submit(sumTask);
try {
System.out.println("sum = " + future.get());
long endtime=System.currentTimeMillis();
System.out.println("takes " + (endtime - begintime) + "ms");
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}
}

输出结果:

1
2
sum = 5000050000
takes 6ms

注意,阈值的选择需要有一个平衡点,如果阈值过小,会导致线程之间存在竞争增大,如果阈值过大,会退化成单线程。

参考资料

  • Java并发编程的艺术[M].机械工业出版社