InheritableThreadLocal可以做什么

我们知道ThreadLocal解决的是让每个线程读取的ThreadLocal变量是相互独立的。通俗的讲就是,比如我再线程1中set了ThreadLocal的值,那我在线程2中是get不到线程1设置的值的,只能读到线程2自己set的值。
ThreadLocal有一个需求不能满足:就是子线程无法直接复用父线程的ThreadLocal变量里的内容。

InheritableThreadLocal 使用

InheritableThreadLocal 是ThreadLocal 的子类,在ThreadLocal 基础上可以让子线程从父线程中取得值。但是有一点需要注意:如果子线程创建后,主线程将值进行更改,那么子线程取得的值还是旧值。

  1. package com.hanliukui.example.threadlocaltest;
  2. /**
  3. * @Author hanliukui
  4. * @Date 2022/4/4 15:22
  5. * @Description xxx
  6. */
  7. public class Main {
  8. static ThreadLocal<String> local = new InheritableThreadLocal<>();
  9. public static void main(String[] args) {
  10. local.set("MainA");
  11. Thread threadA = new Thread(new Runnable() {
  12. @Override
  13. public void run() {
  14. for (int i = 0; i < 10; i++) {
  15. System.out.println("ThreadA获取值:"+local.get());
  16. }
  17. }
  18. });
  19. Thread threadB = new Thread(new Runnable() {
  20. @Override
  21. public void run() {
  22. for (int i = 0; i < 10; i++) {
  23. System.out.println("ThreadB获取值:"+local.get());
  24. }
  25. }
  26. });
  27. threadA.start();
  28. threadB.start();
  29. // 父线程改变值
  30. local.set("MainB");
  31. System.out.println("Main获取值:"+local.get());
  32. try {
  33. Thread.sleep(1000);
  34. } catch (InterruptedException e) {
  35. e.printStackTrace();
  36. }
  37. }
  38. }

Main获取值:MainB ThreadB获取值:MainA ThreadA获取值:MainA ThreadB获取值:MainA ThreadA获取值:MainA ThreadB获取值:MainA ThreadA获取值:MainA ThreadB获取值:MainA ThreadA获取值:MainA ThreadA获取值:MainA ThreadA获取值:MainA ThreadA获取值:MainA ThreadA获取值:MainA ThreadB获取值:MainA ThreadA获取值:MainA ThreadB获取值:MainA ThreadA获取值:MainA ThreadB获取值:MainA ThreadB获取值:MainA ThreadB获取值:MainA ThreadB获取值:MainA

Process finished with exit code 0

InheritableThreadLocal 原理

InheritableThreadLocal 其实就是重写ThreadLocal 的3个方法。

  1. public class InheritableThreadLocal<T> extends ThreadLocal<T> {
  2. protected T childValue(T parentValue) {
  3. return parentValue;
  4. }
  5. ThreadLocalMap getMap(Thread t) {
  6. return t.inheritableThreadLocals;
  7. }
  8. void createMap(Thread t, T firstValue) {
  9. t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
  10. }
  11. }

首先,当我们调用 get 方法的时候,由于子类没有重写,所以我们调用了父类的 get 方法:

  1. public T get() {
  2. Thread t = Thread.currentThread();
  3. ThreadLocalMap map = getMap(t);
  4. if (map != null) {
  5. ThreadLocalMap.Entry e = map.getEntry(this);
  6. if (e != null) {
  7. @SuppressWarnings("unchecked")
  8. T result = (T)e.value;
  9. return result;
  10. }
  11. }
  12. return setInitialValue();
  13. }

这里会有一个getMap(t) 方法,所以就会得到这个线程 threadlocals。 但是,由于子类 InheritableThreadLocal 重写了 getMap()方法,再看上述代码,我们可以看到:其实不是得到 threadlocals,而是得到 inheritableThreadLocals。

inheritableThreadLocals 之前一直没提及过,其实它也是 Thread 类的一个 ThreadLocalMap 类型的 属性,如下 Thread 类的部分代码:

  1. ThreadLocal.ThreadLocalMap threadLocals = null;
  2. ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

那么,这里看 InheritableThreadLocal 重写的方法,感觉 inheritableThreadLocals 和 threadLocals 几乎是一模一样的作用,只是换了个名字而且,那么究竟为什么在新的线程中通过 threadlocal.get()方法还能得到值呢?

当 我们 new 一个 线程的时候:

  1. public Thread() {
  2. init(null, null, "Thread-" + nextThreadNum(), 0);
  3. }

然后:

  1. private void init(ThreadGroup g, Runnable target, String name,
  2. long stackSize) {
  3. init(g, target, name, stackSize, null);
  4. }

然后:

  1. private void init(ThreadGroup g, Runnable target, String name,
  2. long stackSize, AccessControlContext acc) {
  3. ......
  4. Thread parent = currentThread();
  5. ......
  6. if (parent.inheritableThreadLocals != null)
  7. this.inheritableThreadLocals =
  8. ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
  9. ......
  10. }

这时候有一句 ‘ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);’ ,然后:

  1. static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
  2. return new ThreadLocalMap(parentMap);
  3. }
  4. private ThreadLocalMap(ThreadLocalMap parentMap) {
  5. Entry[] parentTable = parentMap.table;
  6. int len = parentTable.length;
  7. setThreshold(len);
  8. table = new Entry[len];
  9. for (int j = 0; j < len; j++) {
  10. Entry e = parentTable[j];
  11. if (e != null) {
  12. @SuppressWarnings("unchecked")
  13. ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
  14. if (key != null) {
  15. Object value = key.childValue(e.value);
  16. Entry c = new Entry(key, value);
  17. int h = key.threadLocalHashCode & (len - 1);
  18. while (table[h] != null)
  19. h = nextIndex(h, len);
  20. table[h] = c;
  21. size++;
  22. }
  23. }
  24. }
  25. }

当我们创建一个新的线程的时候X,X线程就会有 ThreadLocalMap 类型的 inheritableThreadLocals ,因为它是 Thread 类的一个属性。然后先得到当前线程存储的这些值,例如 Entry[] parentTable = parentMap.table;。再通过一个 for 循环,不断的把当前线程的这些值复制到我们新创建的线程X 的inheritableThreadLocals 中。就这样,就ok了。