一,TL的基本使用与原理

为线程创建独一份的副本数据。

1.基本使用

  1. /**
  2. * @author 二十
  3. * @since 2021/8/28 11:19 下午
  4. */
  5. public class TlTest {
  6. private static AtomicInteger id = new AtomicInteger(0);
  7. private static ThreadLocal<Integer> tl = ThreadLocal.withInitial(()->id.getAndIncrement());
  8. private static CountDownLatch count = new CountDownLatch(3);
  9. public static void main(String[] args)throws Exception {
  10. new Thread(()->{
  11. System.out.println(tl.get()+" "+Thread.currentThread().getName());
  12. tl.remove();
  13. count.countDown();
  14. },"A").start();
  15. new Thread(()->{
  16. System.out.println(tl.get()+" "+Thread.currentThread().getName());
  17. tl.remove();
  18. count.countDown();
  19. },"B").start();
  20. new Thread(()->{
  21. System.out.println(tl.get()+" "+Thread.currentThread().getName());
  22. tl.remove();
  23. count.countDown();
  24. },"C").start();
  25. count.await();
  26. }
  27. }

2.原理分析

  • 里面维护一个ThreadLocalMap结构,每一个元素对应一个桶位。

  • 使用ThreadLocal定义的变量,将指向当前线程本地的一个LocalMap空间。

  • ThreadLocal变量作为key,其内容作为value,保存在本地。

  • 多线程对ThreadLocal对象进行操作,实际上是对各自的本地变量进行操作,不存在线程安全问题。

image.png

假设一个类里面定义了三个threadlocal,三个线程来访问这个类,每个线程本地会维护一个threadlocalmap,每一个map里面会有三个entrykeythreadlocal对象,valuethreadlocal里面set的值。

tl源码流程.png

二,TL源码

1.属性

  1. /**
  2. * 线程获取Threadlocal.get()时,如果是第一次在某个threadlocal对象上get,会给当前线程分配一个value,
  3. * 这个value和当前的threadlocal对象被包装成一个entry,其中key=threadlocal对象,
  4. * value=threadlocal对象给当前线程生成的value。这个entry存放到哪个位置与这个value有关。
  5. */
  6. private final int threadLocalHashCode = nextHashCode();
  7. //创建threadlocal对象时会使用到,每创建一个threadlocal对象就会使用它分配一个hash值给对象。
  8. private static AtomicInteger nextHashCode = new AtomicInteger();
  9. //每创建一个threadlocal对象,这个nextHashCode就会增长0x61c88647。
  10. private static final int HASH_INCREMENT = 0x61c88647;
  11. //创建新的threadlocal对象的时候,给当前对象分配hash的时候用到。
  12. private static int nextHashCode() {
  13. return nextHashCode.getAndAdd(HASH_INCREMENT);
  14. }
  15. //留给子类重写扩展的
  16. protected T initialValue() {
  17. return null;
  18. }
  19. //带初始化值得threadlocal
  20. public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
  21. return new SuppliedThreadLocal<>(supplier);
  22. }
  23. public ThreadLocal() {
  24. }

2.get()

  1. public T get() {
  2. //获取当前线程
  3. Thread t = Thread.currentThread();
  4. //根据当前线程获取对应的map
  5. ThreadLocalMap map = getMap(t);
  6. if (map != null) { //已经初始化
  7. //根据当前threadlocal对象获取entry节点
  8. ThreadLocalMap.Entry e = map.getEntry(this);
  9. if (e != null) { //节点初始化过
  10. //获取entry的value并返回
  11. T result = (T)e.value;
  12. return result;
  13. }
  14. }
  15. //走到这里说明map尚未初始化获取entry尚未初始化
  16. return setInitialValue();
  17. }

2.1 setInitialValue()

  1. private T setInitialValue() {
  2. //获取初始值,留给子类重写
  3. T value = initialValue();
  4. //获取当前线程
  5. Thread t = Thread.currentThread();
  6. //获取当前线程对应的map
  7. ThreadLocalMap map = getMap(t);
  8. //如果map初始化过
  9. if (map != null)
  10. //map里面放入当前对象和value
  11. map.set(this, value);
  12. else //map尚未初始化过
  13. //初始化map--直接new一个并放入当前对象和value
  14. createMap(t, value);
  15. //返回value
  16. return value;
  17. }

2.2 getMap()

  1. //返回当前线程的threadLocals
  2. ThreadLocalMap getMap(Thread t) {
  3. return t.threadLocals;
  4. }

2.3 createMap()

  1. //利用构造器初始化threadLocals并将当前线程和线程对应的value设置进去
  2. void createMap(Thread t, T firstValue) {
  3. t.threadLocals = new ThreadLocalMap(this, firstValue);
  4. }

3.set()

  1. public void set(T value) {
  2. Thread t = Thread.currentThread();
  3. ThreadLocalMap map = getMap(t);
  4. if (map != null) //当前线程对应的map已经初始化
  5. map.set(this, value); //map放入值
  6. else //map未初始化
  7. createMap(t, value); //初始化map
  8. }

4.remove()

  1. public void remove() {
  2. ThreadLocalMap m = getMap(Thread.currentThread());
  3. if (m != null) //map已经初始化
  4. m.remove(this); //调用map的remove移除掉当前对象对应的entry
  5. }

5.内部类ThreadLocalMap

  1. //threadlocalmap里面的key是弱引用 ,key=threadlocal对象
  2. //value是强引用,value保存的是threadlocal对象与当前线程关联的value
  3. //这样设计的好处是为了防止内存泄漏
  4. static class Entry extends WeakReference<ThreadLocal<?>> {
  5. Object value;
  6. Entry(ThreadLocal<?> k, Object v) {
  7. super(k);
  8. value = v;
  9. }
  10. }
  11. //map的初始化容量为16
  12. private static final int INITIAL_CAPACITY = 16;
  13. //map里面的entry桶位列表
  14. private Entry[] table;
  15. //列表容量
  16. private int size = 0;
  17. /**
  18. * 扩容阈值 当前数组长度的三分之二
  19. */
  20. private int threshold; // Default to 0
  21. //将扩容阈值设置为当前数组长度的三分之二
  22. private void setThreshold(int len) {
  23. threshold = len * 2 / 3;
  24. }
  25. //获取下一个位置
  26. private static int nextIndex(int i, int len) {
  27. return ((i + 1 < len) ? i + 1 : 0);
  28. }
  29. //获取下一个位置
  30. private static int prevIndex(int i, int len) {
  31. return ((i - 1 >= 0) ? i - 1 : len - 1);
  32. }
  33. //其实从上层的api可以发现这里其实是延迟初始化,只有线程第一次调用threadlocal的
  34. //get或者set的时候才会初始化。
  35. ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
  36. //初始化散列表,长度为16
  37. table = new Entry[INITIAL_CAPACITY];
  38. //计算entry的存储位置
  39. int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
  40. //创建新的entry
  41. table[i] = new Entry(firstKey, firstValue);
  42. //占用量设置为1
  43. size = 1;
  44. //修改扩容阈值为初始化长度
  45. setThreshold(INITIAL_CAPACITY);
  46. }

5.1 getEntry()

  1. private Entry getEntry(ThreadLocal<?> key) {
  2. //根据当前线程的threadlocal对象获取entry的存储位置
  3. int i = key.threadLocalHashCode & (table.length - 1);
  4. Entry e = table[i];
  5. if (e != null && e.get() == key) //校验entry是不是已经丢了,或者已经被覆盖
  6. return e;
  7. else //执行打这里说明entry已经丢了或者被发生了hash冲突,继续向后寻找
  8. return getEntryAfterMiss(key, i, e);
  9. }

5.2 getEntryAfterMiss()

  1. private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
  2. //获取散列表
  3. Entry[] tab = table;
  4. //获取散列表的长度
  5. int len = tab.length;
  6. //如果entry不为空,那就说明entryhash冲突了
  7. while (e != null) {
  8. //获取entry对应的threadlocal对象
  9. ThreadLocal<?> k = e.get();
  10. //说明key对应的threadlocal对象已经被回收了,当前entry属于脏数据
  11. if (k == key)
  12. //直接返回
  13. return e;
  14. //如果key==null,说明key对应的threadlocal对象已经被回收了,当前entry属于脏数据
  15. if (k == null)
  16. //做一次探测式过期清理
  17. expungeStaleEntry(i);
  18. else //执行到这里说明发生了hash冲突,继续从当前位置往后寻找
  19. i = nextIndex(i, len);
  20. e = tab[i];
  21. }
  22. //说明entry过期了,直接返回null
  23. return null;
  24. }

5.3 expungeStaleEntry() 探测式过期清理

  1. private int expungeStaleEntry(int staleSlot) {
  2. //获取散列表
  3. Entry[] tab = table;
  4. //获取散列表的长度
  5. int len = tab.length;
  6. //因为此处threadlocal对象已经被回收,所以直接将value设置为null,help GC
  7. tab[staleSlot].value = null;
  8. //再讲当前桶位设置为空
  9. tab[staleSlot] = null;
  10. /**
  11. * 为什么这里要分两次设置为null?
  12. * 因为key本身是弱引用,但是value是强引用,如果直接回收桶位,value无法直接被回收
  13. */
  14. //散列表的占用长度-1
  15. size--;
  16. Entry e;
  17. int i;
  18. //从当前节点所在位置的下一个位置直到最后循环,
  19. for (i = nextIndex(staleSlot, len);
  20. //停止条件是当前索引对应桶位=null
  21. (e = tab[i]) != null;
  22. //循环条件是每次索引+1
  23. i = nextIndex(i, len)) {
  24. //获取entry的threadlocal对象
  25. ThreadLocal<?> k = e.get();
  26. if (k == null) {//如果对象为空,说明已经过期了,entry是脏数据
  27. //回收
  28. e.value = null;
  29. tab[i] = null;
  30. size--;
  31. } else {//此时说明entry不是脏数据
  32. //计算threadlocal对象在散列表的新索引,为啥重新计算?
  33. //因为当前get到了脏数据,刚刚从散列表移除,所以散列表的占用量已经发生了变化
  34. int h = k.threadLocalHashCode & (len - 1);
  35. //如果没有发生hash冲突
  36. if (h != i) {
  37. //将原来的桶位释放
  38. tab[i] = null;
  39. //寻找存放位置,直到所在桶位为空,因为可能计算出的位置发生了hash冲突,
  40. //这个时候,就要索引下推到下一桶位
  41. while (tab[h] != null)
  42. h = nextIndex(h, len);
  43. //将entry放到新的桶位
  44. tab[h] = e;
  45. }
  46. }
  47. }
  48. //返回最后处理的索引处
  49. return i;
  50. }

5.4 set()

  1. private void set(ThreadLocal<?> key, Object value) {
  2. Entry[] tab = table;
  3. int len = tab.length;
  4. int i = key.threadLocalHashCode & (len-1);
  5. for (Entry e = tab[i];//当前threadlocal对象所对应的节点
  6. e != null; //终止条件是entry为空,说明这个桶位能存放entry
  7. e = tab[i = nextIndex(i, len)]) { //桶位下推
  8. ThreadLocal<?> k = e.get();
  9. //如果当前对象所对应的桶位有值,且当前桶位的key是当前对象,
  10. //说明这是一次值重置,直接覆盖旧的值即可
  11. if (k == key) {
  12. e.value = value;
  13. return;
  14. }
  15. //如果k==null,说明当前位置对应的entry是过期的,
  16. if (k == null) {
  17. replaceStaleEntry(key, value, i);
  18. return;
  19. }
  20. }
  21. //来到这里的条件:这次操作不是一次对已经有的值得覆盖,或者已经找到了应该存放当前entry的桶位
  22. tab[i] = new Entry(key, value);
  23. int sz = ++size;
  24. //如果达到了扩容的条件,进行扩容操作
  25. if (!cleanSomeSlots(i, sz) && sz >= threshold)
  26. rehash();
  27. }

5.5 replaceStaleEntry()替换过期entry

  1. //替换过期entry
  2. private void replaceStaleEntry(ThreadLocal<?> key, Object value,
  3. int staleSlot) {
  4. Entry[] tab = table;
  5. int len = tab.length;
  6. Entry e;
  7. //进入这个方法的条件说明:当前位置的节点其实是过期的,但是还没来得及回收
  8. int slotToExpunge = staleSlot; //当前桶位的索引
  9. //从当前位置向前清理
  10. for (int i = prevIndex(staleSlot, len); //i=当前索引的前一个索引
  11. (e = tab[i]) != null; //终止条件是索引所在的桶位有数据
  12. i = prevIndex(i, len)) //循环条件是每次往前一个桶位
  13. //说明是过期的,那就继续往前清理
  14. if (e.get() == null)
  15. slotToExpunge = i;
  16. //从当前位置向后清理
  17. for (int i = nextIndex(staleSlot, len);
  18. (e = tab[i]) != null;
  19. i = nextIndex(i, len)) {
  20. ThreadLocal<?> k = e.get();
  21. //如果当前位置的key和当前threadlocal对象一致
  22. if (k == key) {
  23. //进行值覆盖操作
  24. e.value = value;
  25. //将过期数据放到当前循环到的table[i]
  26. tab[i] = tab[staleSlot];
  27. //这里的逻辑其实就是进行一下位置优化
  28. tab[staleSlot] = e;
  29. //说明上面的循环并没有找到过期数据
  30. if (slotToExpunge == staleSlot)
  31. //吧探测的开始位置改成当前位置
  32. slotToExpunge = i;
  33. //进行探测式过期清理
  34. cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
  35. return;
  36. }
  37. //当前遍历entry是一个过期数据 && 往前找过期数据没找到
  38. if (k == null && slotToExpunge == staleSlot)
  39. //更新探测位置为当前位置
  40. slotToExpunge = i;
  41. }
  42. //将新的值放入当前节点
  43. tab[staleSlot].value = null;
  44. tab[staleSlot] = new Entry(key, value);
  45. //如果两个索引不相等,就继续清理
  46. if (slotToExpunge != staleSlot)
  47. cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
  48. }

5.6 cleanSomeSlots()启发式清理工作

  1. //启发式清理工作 i 开始清理位置 n 结束条件,数组长度
  2. private boolean cleanSomeSlots(int i, int n) {
  3. boolean removed = false;
  4. Entry[] tab = table;
  5. int len = tab.length;
  6. do {
  7. //获取当前i的下一个下标
  8. i = nextIndex(i, len);
  9. //获取当前下标为I的元素
  10. Entry e = tab[i];
  11. //断定为过期元素
  12. if (e != null && e.get() == null) {
  13. n = len;//更新数组长度
  14. removed = true;
  15. //从当前过期位置开始一次谈测试清理工作
  16. i = expungeStaleEntry(i);
  17. }
  18. } while ( (n >>>= 1) != 0);//假设table.length=16
  19. return removed;
  20. }

5.7 rehash()

  1. private void rehash() {
  2. //遍历,探测式清理,干掉所有过期数据
  3. expungeStaleEntries();
  4. //仍然达到扩容条件
  5. if (size >= threshold - threshold / 4)
  6. //扩容
  7. resize();
  8. }

5.8 resize()

  1. private void resize() {
  2. Entry[] oldTab = table;
  3. int oldLen = oldTab.length;
  4. int newLen = oldLen * 2; //扩容为原来的2倍
  5. Entry[] newTab = new Entry[newLen];
  6. int count = 0;
  7. for (int j = 0; j < oldLen; ++j) {
  8. Entry e = oldTab[j]; //访问old表指定位置的data
  9. if (e != null) { //data存在
  10. ThreadLocal<?> k = e.get();
  11. if (k == null) { //过期数据
  12. e.value = null; // Help the GC
  13. } else {
  14. int h = k.threadLocalHashCode & (newLen - 1);//重新计算hash值
  15. while (newTab[h] != null) //获取到一个最近的,可以使用的位置
  16. h = nextIndex(h, newLen);
  17. newTab[h] = e; //数据迁移
  18. count++;
  19. }
  20. }
  21. }
  22. setThreshold(newLen);//设置下一次扩容的指标
  23. size = count;
  24. table = newTab;
  25. }

5.9 remove()

  1. private void remove(ThreadLocal<?> key) {
  2. Entry[] tab = table;
  3. int len = tab.length;
  4. int i = key.threadLocalHashCode & (len-1);
  5. for (Entry e = tab[i];
  6. e != null;
  7. e = tab[i = nextIndex(i, len)]) {
  8. //从当前位置开始,如果桶位是空,就去下一个
  9. //如果不为空的桶位与当前线程的threadlocal对象一致
  10. if (e.get() == key) {
  11. e.clear(); //干掉key的引用
  12. expungeStaleEntry(i); //探测式过期清理
  13. return;
  14. }
  15. }
  16. }