并发学习(十二) ThreadLocal源码分析

ThreadLocal

ThreadLocal,即线程本地变量,每个线程往这个ThreadLocal中读写是线程隔离,互相之间不会影响的。它提供了一种将可变数据通过每个线程有自己的独立副本从而实现线程封闭的机制

实现思路

Thread类有一个类型为ThreadLocal.ThreadLocalMap的实例变量threadLocals,也就是说每个线程有一个自己的ThreadLocalMap。ThreadLocalMap有自己的独立实现,可以简单地将它的key视作ThreadLocal,value为代码中放入的值(实际上key并不是ThreadLocal本身,而是它的一个弱引用)。每个线程在往某个ThreadLocal里塞值的时候,都会往自己的ThreadLocalMap里存,读也是以某个ThreadLocal作为引用,在自己的map里找对应的key,从而实现了线程隔离。

源码分析

ThreadLocal中的核心数据结构就是ThreadLocalMap,所有的操作都是围绕ThreadLocalMap来实现的,这里先分析ThreadLocalMap的源码,后面分析ThreadLocal会简单很多

ThreadLocalMap

1
2
3
4
5
6
7
8
9
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

这个map不要与HashMap中的map混淆,这个map以ThreadLocal为key,value为实际放入的值,但是注意,这个地方传入的是ThreadLocal的弱引用。

之所以使用弱引用,是因为如果使用正常的key-value形式来定义存储结构,会造成结点生命周期与线程强绑定,只要线程没有销毁,那么结点与GC ROOT会一直处于可达状态,则不会被回收,而程序本身也无法判断是否可以清理节点。在Java的垃圾收集中,引用分为强引用、软引用、弱引用、虚引用。而被弱引用关联的对象一定会被回收,也就是说它只能存活到下一次垃圾回收发生之前。当某个ThreadLocal已经没有强引用可达,则随着它被垃圾回收,在ThreadLocalMap里对应的Entry的键值会失效,这为ThreadLocalMap本身的垃圾清理提供了便利。

类变量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

//初始容量16
private static final int INITIAL_CAPACITY = 16;

//散列表,大小为2的n次幂
private Entry[] table;

//大小默认为0
private int size = 0;

//重分配阈值默认为0
private int threshold; // Default to 0

/**
* 设置重分配阈值维持最坏的2/3的装载因子
*/
private void setThreshold(int len) {
threshold = len * 2 / 3;
}

/**
* 相当于环形取下一结点并与长度取模运算(取模效率低)
*/
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

/**
* 相当于环形取上一结点并与长度取模运算(取模效率低)
*/
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

由于ThreadLocalMap使用线性探测法来解决散列冲突,所以实际上Entry[]数组在程序逻辑上是作为一个环形存在的。所以,我们基本了解了ThreadLocalMap的内部存储结构。

threadlocal.png

构造函数

1
2
3
4
5
6
7
8
9
10
11
12
13
/**
* 构造一个新的初始化的map保持(firstKey, firstValue)
* ThreadLocalMaps是惰性构造的,所以我们当我们至少往里put一个entry的时候才会建立
*/
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
//hash & (length -1) 这个式子肯定不陌生,就是在hashmap也是用了这种策略,这里不赘述
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
//将扩容阈值设置为初始容量16
setThreshold(INITIAL_CAPACITY);
}

可以看出,它是在上一个被构造出的ThreadLocal的ID/threadLocalHashCode的基础上加上一个魔数0x61c88647的。这个魔数的选取与斐波那契散列有关,0x61c88647对应的十进制为1640531527。斐波那契散列的乘数可以用(long) ((1L << 31) * (Math.sqrt(5) - 1))可以得到2654435769,如果把这个值给转为带符号的int,则会得到-1640531527。换句话说
(1L << 32) - (long) ((1L << 31) * (Math.sqrt(5) - 1))得到的结果就是1640531527也就是0x61c88647。通过理论与实践,当我们用0x61c88647作为魔数累加为每个ThreadLocal分配各自的ID也就是threadLocalHashCode再与2的幂取模,得到的结果分布很均匀。

至于为什么要用2的n次幂,在HashMap已经讨论过了,不再赘述。

getEntry()方法

getEntry()方法会被ThreadLocal的get()方法直接调用,上面也说过,get()方法内部就是先拿到当前线程的ThreadLocalMap,然后将自己this作为参数调用其getEntry()方法。这里要提前说明一点的是,每个索引(slot)上的状态有三种:有效(ThreadLocal未回收),失效(ThreadLocal已回收),空(null)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
private Entry getEntry(ThreadLocal<?> key) {
//获取key在map中的下标
int i = key.threadLocalHashCode & (table.length - 1);
//取得下标对应的entry
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
//如果未能成功取得entry,执行未命中方法 (由于是采用线性探测,此时要么entry已失效或者是出现了哈希冲突)
return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) {
//获取entry的key
ThreadLocal<?> k = e.get();
//找到目标
if (k == key)
return e;
if (k == null)
// 该entry对应的ThreadLocal已经被回收,调用expungeStaleEntry来清理无效的entry
expungeStaleEntry(i);
else
// 获取环形的下一个entry
i = nextIndex(i, len);
e = tab[i];
}
return null;
}

// ThreadLocal的核心清理方法
private int expungeStaleEntry(int staleSlot) {
//获取entry表以及长度
Entry[] tab = table;
int len = tab.length;

//这个方法的上一步,由于该位的ThreadLocal已经被回收,此时将value设置为空,并且与表断开方便回收
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// Rehash until we encounter null
Entry e;
int i;
//获取下一个槽,清理后面的ThreadLocal的结点
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
//获得对应的ThreadLocal,如果ThreadLocal的key为空,说明已经被回收,此时将value设为空
//并且与map断开来,帮助回收,最后长度自减
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
//如果不为空,说明该结点的ThreadLocal没有被回收,此时进行rehash
int h = k.threadLocalHashCode & (len - 1);
//如果rehash取得的位置与原来的位置i不同
if (h != i) {
//将该位置位空
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
//往后找到一个空位,将该结点放入
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

set()方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
private void set(ThreadLocal<?> key, Object value) {

// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.

Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

//在下标i的结点处往后遍历(之所以遍历是因为使用线性探测法来解决哈希冲突所以要往后遍历)
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

// 找到相应的entry
if (k == key) {
e.value = value;
return;
}
// 替换失效的entry
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// ThreadLocal对应的key实例不存在也没有陈旧元素,new 一个
tab[i] = new Entry(key, value);
int sz = ++size;
// cleanSomeSlots 清除陈旧的Entry(key == null)
// 如果没有清理陈旧的 Entry 并且数组中的元素大于了阈值,则进行 rehash
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

这里可以知道,ThreadLocalMap与HashMap的处理哈希冲突方式并不一样,前者使用的是线性探测法,而后者使用的是拉链法。并且,set()操作除了存储元素外,还有一个很重要的作用,就是replaceStaleEntry()和cleanSomeSlots(),这两个方法可以清除掉key == null 的实例,防止内存泄漏。

下面看看rehash

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
private void rehash() {
expungeStaleEntries();// 清理一次陈旧数据

// 如果清理完后大于3/4阈值,进行扩容
if (size >= threshold - threshold / 4)
resize();
}

/**
* 两倍容量的扩容
*/
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
//两倍原大小的新表
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
//数据迁移
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
//如果此时key为空,将value也置位空帮助GC
if (k == null) {
e.value = null; // Help the GC
} else {
//如果不为空则将数据迁移到新表
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
//如果存在哈希冲突,进行线性探测
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
//设置阈值
setThreshold(newLen);
size = count;
table = newTab;
}

/**
* Expunge all stale entries in the table.
*/
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}

关于ThreadLocalMap的源码就分析到这里,下面分析ThreadLocal的源码。

ThreadLocal的get方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
public T get() {
//获取当前线程
Thread t = Thread.currentThread();
//获取ThreadLocalMap
ThreadLocalMap map = getMap(t);
//如果map不为空
if (map != null) {
//找到当前ThreadLocal变量实例对应的Entry
ThreadLocalMap.Entry e = map.getEntry(this);
//如果entry存在返回value
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
//如果map为null,说明没有初始化
return setInitialValue();
}

private T setInitialValue() {
T value = initialValue();//默认返回null,可以自定义
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)//如果map不为null,把初始化value设置进去
map.set(this, value);
else//如果为空,则执行createMap
createMap(t, value);
return value;
}

void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

set方法

1
2
3
4
5
6
7
8
9
10
public void set(T value) {
Thread t = Thread.currentThread();
//找到当前线程的threadLocals
ThreadLocalMap map = getMap(t);
//如果map不为空则返回value,否则调用createMap
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

remove方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
//调用ThreadLocalMap的remove方法
m.remove(this);
}

private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}

public void clear() {
this.referent = null;
}

ThreadLocal与内存泄漏

关于ThreadLocal是否会引起内存泄漏也是一个比较有争议性的问题,其实就是要看对内存泄漏的准确定义是什么。
认为ThreadLocal会引起内存泄漏的说法是因为如果一个ThreadLocal对象被回收了,我们往里面放的value对于【当前线程->当前线程的threadLocals(ThreadLocal.ThreadLocalMap对象)->Entry数组->某个entry.value】这样一条强引用链是可达的,因此value不会被回收。
认为ThreadLocal不会引起内存泄漏的说法是因为ThreadLocal.ThreadLocalMap源码实现中自带一套自我清理的机制。

之所以有关于内存泄露的讨论是因为在有线程复用如线程池的场景中,一个线程的寿命很长,大对象长期不被回收影响系统运行效率与安全。如果线程不会复用,用完即销毁了也不会有ThreadLocal引发内存泄露的问题。《Effective Java》一书中的第6条对这种内存泄露称为unintentional object retention(无意识的对象保留)。

当我们仔细读过ThreadLocalMap的源码,我们可以推断,如果在使用的ThreadLocal的过程中,显式地进行remove是个很好的编码习惯,这样是不会引起内存泄漏。
那么如果没有显式地进行remove呢?只能说如果对应线程之后调用ThreadLocal的get和set方法都有很高的概率会顺便清理掉无效对象,断开value强引用,从而大对象被收集器回收。

但无论如何,我们应该考虑到何时调用ThreadLocal的remove方法。一个比较熟悉的场景就是对于一个请求一个线程的server如tomcat,在代码中对web api作一个切面,存放一些如用户名等用户信息,在连接点方法结束后,再显式调用remove。

参考资料