概述

从 ThreadLocal 源码学习可知,ThreadLocalMap 是 ThreadLocal 的静态内部类。

而每个 Thread 都有自己的 ThreadLocalMap,Map 中的 key 为 ThreadLocal 变量,value 为 ThreadLocal 的值。不同线程的 ThreadLocalMap 互相不可见,也就保证了线程安全。

  1. public class ThreadLocal<T> {
  2. static class ThreadLocalMap {
  3. }
  4. }

ThreadLocalMap 的成员变量

ThreadLocalMap 用 Entry 数组保存键值对,键为当前的 ThreadLocal 对象,值为 ThreadLocal 对象对应的值

Entry 数组默认的长度为 16,其中阈值 threshold 默认为 INITIAL_CAPACITY *2/3

  1. static class ThreadLocalMap {
  2. static class Entry extends WeakReference<ThreadLocal<?>> {
  3. /** The value associated with this ThreadLocal. */
  4. Object value;
  5. Entry(ThreadLocal<?> k, Object v) {
  6. super(k);
  7. value = v;
  8. }
  9. }
  10. private static final int INITIAL_CAPACITY = 16;
  11. private Entry[] table;
  12. private int size = 0;
  13. private int threshold; // Default to 0
  14. private void setThreshold(int len) {
  15. threshold = len * 2 / 3;
  16. }
  17. }

ThreadLocalMap 构造方法

ThreadLocalMap 默认的构造方法,传入的 key 为当前 ThreadLocal 对象,firstValue 为 ThreadLocal 初始设置的值。

hashcode & (len - 1) 其实就是取模运算,定位元素在数组中的位置

  1. void createMap(Thread t, T firstValue) {
  2. t.threadLocals = new ThreadLocalMap(this, firstValue);
  3. }
  4. ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
  5. table = new Entry[INITIAL_CAPACITY];
  6. int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); // 对长度取模,得到索引位置
  7. table[i] = new Entry(firstKey, firstValue);
  8. size = 1;
  9. setThreshold(INITIAL_CAPACITY); // 初始化阈值,超过这个长度需要扩容
  10. }
  11. private void setThreshold(int len) {
  12. threshold = len * 2 / 3; // 默认为数组长度的 2/3
  13. }
  14. // ThreadLocal 中的方法
  15. private final int threadLocalHashCode = nextHashCode();
  16. private static int nextHashCode() {
  17. return nextHashCode.getAndAdd(HASH_INCREMENT);
  18. }

其次,ThreadLocalMap 的构造方法还支持传入其它的 ThreadLocalMap 进行初始化

大致原理为:新建一个和原来长度一样的 Entry 数组,遍历原数组,排除 key = null 的节点,放到新的 Entry 数组中。

  1. private ThreadLocalMap(ThreadLocalMap parentMap) {
  2. Entry[] parentTable = parentMap.table; // 传入进来的数组
  3. int len = parentTable.length;
  4. setThreshold(len);
  5. table = new Entry[len];
  6. for (int j = 0; j < len; j++) {
  7. Entry e = parentTable[j]; // 对原数组进行遍历
  8. if (e != null) { // 节点不为空
  9. @SuppressWarnings("unchecked")
  10. ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
  11. if (key != null) { // entry 元素的 key 不为空时
  12. Object value = key.childValue(e.value); // 其实是返回 ThreadLocal 对应的值
  13. Entry c = new Entry(key, value); // 新建 Entry 节点
  14. int h = key.threadLocalHashCode & (len - 1); // 取模
  15. while (table[h] != null) // 数组已有元素,产生冲突
  16. h = nextIndex(h, len);
  17. table[h] = c;
  18. size++;
  19. }
  20. }
  21. }
  22. }
  23. // 返回当前 i 的后一位,若 i 已经是最后一位,返回第 0 位,依次循环
  24. private static int nextIndex(int i, int len) {
  25. return ((i + 1 < len) ? i + 1 : 0);
  26. }

成员方法

getEntry

根据 ThreadLocal 获取 entry 节点,若 entry 为 null,那么直接返回 null;

若 entry 中保存的 threadlocal 变量和传进来的 threadlocal 变量不一致,那么向后遍历下一个节点,直到找到相同的节点;

若没有相同节点,找到 key = null 的节点那么返回 null。

  1. private Entry getEntry(ThreadLocal<?> key) {
  2. int i = key.threadLocalHashCode & (table.length - 1); // 对长度取模定位元素在数组的位置
  3. Entry e = table[i];
  4. if (e != null && e.get() == key)
  5. return e; // 节点是同一个,直接返回
  6. else
  7. // 没找到节点返回 null
  8. // 数组中的 entry 和 threadLocal 不一致
  9. return getEntryAfterMiss(key, i, e);
  10. }
  11. // 顺序向后遍历,直到找到满足的点
  12. private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
  13. Entry[] tab = table;
  14. int len = tab.length;
  15. while (e != null) { // 节点为 null 直接返回 null
  16. ThreadLocal<?> k = e.get(); // k 为 entry 数组中的 entry 对应的 key
  17. if (k == key)
  18. return e;
  19. if (k == null) // 若 key 被置为空
  20. expungeStaleEntry(i); // 清除 key 为 null 的节点
  21. else
  22. i = nextIndex(i, len); // 找到下一个节点
  23. e = tab[i];
  24. }
  25. return null;
  26. }

expungeStaleEntry

staleSlot 表示过时的插槽,方法的意思在下面有注释,主要包含:

  1. 清除 key = null 的 entry,便于垃圾回收
    2. 从 staleSlot 位置开始遍历 entry 数组,对数组中元素进行 rehash 操作,若遍历到的数组的 key = null,结束遍历。
  1. private int expungeStaleEntry(int staleSlot) {
  2. Entry[] tab = table;
  3. int len = tab.length;
  4. tab[staleSlot].value = null; // 将 staleSlot 处的 key,value 都置为 null,方便垃圾回收
  5. tab[staleSlot] = null;
  6. size--;
  7. Entry e;
  8. int i;
  9. // 开始进行 rehash 操作,即遍历数组直至遇到 key = null
  10. for (i = nextIndex(staleSlot, len);
  11. (e = tab[i]) != null;
  12. i = nextIndex(i, len)) {
  13. // 遍历 staleSlot 后面的节点
  14. ThreadLocal<?> k = e.get();
  15. if (k == null) { // 删除 key = null 的节点
  16. e.value = null;
  17. tab[i] = null;
  18. size--;
  19. } else {
  20. // 重新计算该 key 对应数组中的位置
  21. int h = k.threadLocalHashCode & (len - 1);
  22. if (h != i) { // 节点不一致,表示需要迁移
  23. tab[i] = null; // 原来的位置置为 null
  24. while (tab[h] != null) // 找到第一个为 null 的位置,插入元素
  25. h = nextIndex(h, len);
  26. tab[h] = e;
  27. }
  28. }
  29. }
  30. return i;
  31. }

set

将键值对 set 到 entry 数组中

  1. private void set(ThreadLocal<?> key, Object value) {
  2. Entry[] tab = table;
  3. int len = tab.length;
  4. // 获取 key 对应数组的位置
  5. int i = key.threadLocalHashCode & (len-1);
  6. // 遍历数组
  7. for (Entry e = tab[i];
  8. e != null;
  9. e = tab[i = nextIndex(i, len)]) {
  10. // 获取数组中的 entry 的 ThreadLocal 引用
  11. ThreadLocal<?> k = e.get();
  12. // 如果 key 一样,覆盖 value
  13. if (k == key) {
  14. e.value = value;
  15. return;
  16. }
  17. // 传进来的 key 为 null,将 i 处的节点设为 key-value,并清除后面 key = null 的节点
  18. if (k == null) {
  19. replaceStaleEntry(key, value, i);
  20. return;
  21. }
  22. }
  23. tab[i] = new Entry(key, value);
  24. int sz = ++size;
  25. // 超过阈值,进行 rehash
  26. if (!cleanSomeSlots(i, sz) && sz >= threshold)
  27. rehash();
  28. }

replaceStaleEntry

  1. private void replaceStaleEntry(ThreadLocal<?> key, Object value,
  2. int staleSlot) {
  3. Entry[] tab = table;
  4. int len = tab.length;
  5. Entry e;
  6. // 当前需要清除的节点位置
  7. int slotToExpunge = staleSlot;
  8. // 从当前节点往前遍历,找到 value 为 null 的位置
  9. for (int i = prevIndex(staleSlot, len);
  10. (e = tab[i]) != null;
  11. i = prevIndex(i, len))
  12. if (e.get() == null)
  13. slotToExpunge = i;
  14. // 从 staleSlot 位置开始往后遍历
  15. for (int i = nextIndex(staleSlot, len);
  16. (e = tab[i]) != null;
  17. i = nextIndex(i, len)) {
  18. ThreadLocal<?> k = e.get(); // 获取 threadlocal 的引用
  19. if (k == key) {
  20. e.value = value;
  21. tab[i] = tab[staleSlot];
  22. tab[staleSlot] = e;
  23. // Start expunge at preceding stale entry if it exists
  24. if (slotToExpunge == staleSlot)
  25. slotToExpunge = i;
  26. cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
  27. return;
  28. }
  29. // If we didn't find stale entry on backward scan, the
  30. // first stale entry seen while scanning for key is the
  31. // first still present in the run.
  32. if (k == null && slotToExpunge == staleSlot)
  33. slotToExpunge = i;
  34. }
  35. // If key not found, put new entry in stale slot
  36. tab[staleSlot].value = null;
  37. tab[staleSlot] = new Entry(key, value);
  38. // If there are any other stale entries in run, expunge them
  39. if (slotToExpunge != staleSlot)
  40. cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
  41. }

ThreadLocal继承性解决方案https://blog.csdn.net/weixin_42200859/article/details/105396338