并发学习(十一) 并发同步组件 CountDownLatch 与 CyclicBarrier 源码分析

并发同步组件 CountDownLatch 与 CyclicBarrier

等待多线程完成的CountDownLatch

CountDownLatch允许一个或多个线程等待其他线程完成操作。 它继承了AbstractQueuedSynchronizer(AQS)抽象类,线程可通过调用await方法进入等待状态,在其他线程调用countDown方法将计数器减为0后,处于等待状态的线程即可恢复运行。

循环同步屏障CyclicBarrier

CyclicBarrier则是让一个线程到达屏障时被阻塞,直到最后一个线程到达屏障时,屏障才会开门,所有被屏障拦截的线程才会继续运行。CyclicBarrier的默认构造方法是CyclicBarrier(int parties),其参数表示屏障拦截的线程数量,每个线程带哦用await方法告诉CyclicBarrier已经到达了屏障,然后当前线程被阻塞。

两者的区别

CountDownLatch的计数器只能使用一次,而CyclicBarrier的计数器可以使用reset()方法重置,所以CyclicBarrier能够处理更为复杂的业务场景。

原理

CountDownLatch 的实现原理

CountDownLatch是基于AQS来实现的,CountDownLatch 使用 AQS 中的 state 成员变量作为计数器,在 state 不为0的情况下,凡是调用 await 方法的线程将会被阻塞,并被放入 AQS 所维护的同步队列中进行等待。

阻塞的线程会封装成一个AQS的结点Node形成同步队列(前驱为prev,后继为next),初始情况下,队列的头结点是一个虚拟节点(dummy node)。当线程调用countDown方法时,计数器通过循环CAS进行state自减,当最后一个线程执行countDown方法时(count为0的时候),同步队列中结点按照先进先出唤醒(从头结点到后唤醒线程)。

CyclicBarrier 的实现原理

CyclicBarrier 并没有直接继承AQS,而是基于重入锁 ReentrantLock 的基础上实现的。在 CyclicBarrier 中,线程访问 await 方法需先进行lock方法获取锁才能访问,判断是否是最后一个线程,如果是,则判断是否设置了回调,同时,最后一个进入 await 的线程还会重置 CyclicBarrier 的状态,使其可以重复使用, 如果回调不为空,运行回调,如果不是最后一个线程,则进行等待。

源码分析

CountDownLatch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
public class CountDownLatch {
//内部基于AQS的内部类来实现
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;

Sync(int count) {
//设置AQS的state
setState(count);
}

int getCount() {
//获得当前计数
return getState();
}

//重写了AQS的tryAcquireShared方法
protected int tryAcquireShared(int acquires) {
//当计数器为0,说明所有线程执行了countDown获得同步状态(返回1)
return (getState() == 0) ? 1 : -1;
}

//重写了AQS的tryReleaseShared方法
protected boolean tryReleaseShared(int releases) {
//循环 + CAS设置计数器,是考虑到多线程竞争的情况
//当nextc为0并且CAS成功后,调用 await 等待的线程会被唤醒。
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}

private final Sync sync;

//构造方法,构造一个计数器为count的sync内部类
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

//线程进入等待状态,直至计数器为0,如果计数器为0则不会等待直接返回。
public void await() throws InterruptedException {
//调用AQS的(中断版acquireShared)
sync.acquireSharedInterruptibly(1);
}

//计数器减方法
public void countDown() {
//调用releaseShared方法
sync.releaseShared(1);
}

public final boolean releaseShared(int arg) {
//调用sync的tryReleaseShared方法,即CAS设置计数器自减,如果返回true(自减至0)
if (tryReleaseShared(arg)) {
//唤醒同步队列的线程并返回true
doReleaseShared();
return true;
}
//否则返回false
return false;
}


public long getCount() {
return sync.getCount();
}

public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
//若线程中断抛出中断异常
if (Thread.interrupted())
throw new InterruptedException();
//尝试获得同步状态,小于0(为-1)说明计数器未到0,等于0则说明计数器到0
//如果返回小于0则插入同步队列进行等待
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}


}

CyclicBarrier

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151

//定义一个Generation内部类
private static class Generation {
//broken参数表示屏障是否被破坏
boolean broken = false;
}

private final ReentrantLock lock = new ReentrantLock();

private final Condition trip = lock.newCondition();

private final int parties;

private final Runnable barrierCommand;

private Generation generation = new Generation();

private int count;

public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe);
}
}
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
//加锁
lock.lock();
try {
final Generation g = generation;

//如果屏障被破坏了抛出异常
if (g.broken)
throw new BrokenBarrierException();

// 如果线程中断,则调用 breakBarrier 破坏屏障
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}

//获取cont自减后的值
int index = --count;
//如果index为0时
if (index == 0) { // tripped
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
// 如果回调对象不为 null,则执行回调
if (command != null)
command.run();
ranAction = true;
//重置generation,保证重复使用,并且唤醒所有线程
nextGeneration();
//返回0,正常情况结束,别忘了最底下finally的unlock方法
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}

//代码运行到这里,说明并不是这个线程并不是最后一个线程,进行等待
// loop until tripped, broken, interrupted, or timed out
for (;;) {
try {
//超时与非超时两种情况
if (!timed)
trip.await();
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
/* 捕获到中断异常
* 若下面的条件成立,则表明本轮运行还未结束。此时调用 breakBarrier
* 破坏屏障,唤醒其他线程,并抛出异常
*/
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
/*
* 若上面的条件不成立,则有两种可能:
* 1. g != generation
* 此种情况下,表明循环屏障的第 g 轮次的运行已经结束,屏障已经
* 进入了新的一轮运行轮次中。当前线程在稍后返回 到达屏障 的顺序即可
*
* 2. g = generation 但 g.broken = true
* 此种情况下,表明已经有线程执行过 breakBarrier 方法了,当前
* 线程则会在稍后抛出 BrokenBarrierException
*/
Thread.currentThread().interrupt();
}
}

//抛出 BrokenBarrierException
if (g.broken)
throw new BrokenBarrierException();

// 屏障进入新的运行轮次,此时返回线程在上一轮次到达屏障的顺序
if (g != generation)
return index;

// 超时判断
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
//解锁
lock.unlock();
}
}


private void nextGeneration() {
//唤醒所有等待的线程
trip.signalAll();
//重置 count
count = parties;
//进入下一轮
generation = new Generation();
}

private void breakBarrier() {
//打破屏障
generation.broken = true;
//重置count
count = parties;
//唤醒所有等待线程
trip.signalAll();
}

//重置屏障
public void reset() {
final ReentrantLock lock = this.lock;
//加锁
lock.lock();
try {
//破坏屏障并且进入下一轮
breakBarrier(); // break the current generation
nextGeneration(); // start a new generation
} finally {
//解锁
lock.unlock();
}
}

参考资料