1. Copying parent-thread variables to the child thread

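The copy of inheritableThreadLocals happens only when a new thread is created from inside the parent thread: constructing the thread calls init (Thread.init in the JDK 8 source), and that is where the copy is made.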

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    public class InheriableThreadLocalTest {
        private static ThreadLocal<Integer> context = new InheritableThreadLocal<>();
        private static ExecutorService pool = Executors.newFixedThreadPool(2);

        static class MainThread extends Thread {
            private int index;

            public MainThread(int index) {
                this.index = index;
            }

            @Override
            public void run() {
                context.set(index);
                // Using a pooled thread here mixes up the values that get passed along
                // pool.execute(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get()));
                new Thread(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get())).start();
            }
        }

        /**
         * When a thread pool runs the child tasks,
         * InheritableThreadLocal cannot solve the problem of passing variables along.
         */
        public static void main(String[] args) {
            for (int i = 0; i < 10; i++) {
                new MainThread(i).start();
            }
            pool.shutdown();
        }
    }

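Note: the child must be a newly created thread (new Thread), not a pooled thread. Pooled threads are reused, so the values they inherited when they were created go stale and the wrong data is passed along.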
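The reuse problem can be reproduced deterministically with a single-thread pool. The following is a minimal sketch, not code from the original post: the class name StaleInheritedValueDemo and the helper observe() are made up for illustration, and it assumes the pool's default thread factory, which creates the worker lazily on the first submission.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class StaleInheritedValueDemo {
    private static final ThreadLocal<String> context = new InheritableThreadLocal<>();

    // Runs two tasks on one pooled worker and returns what each task saw.
    static String[] observe() throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Callable<String> readContext = context::get;
        try {
            context.set("A");
            // The worker thread is created for this first task, so it inherits "A".
            String seenFirst = pool.submit(readContext).get();

            context.set("B");
            // The same worker is reused: no Thread is constructed, so nothing is copied.
            String seenSecond = pool.submit(readContext).get();

            return new String[] {seenFirst, seenSecond};
        } finally {
            pool.shutdown();
            context.remove();
        }
    }

    public static void main(String[] args) throws Exception {
        String[] seen = observe();
        System.out.println("first task saw:  " + seen[0]); // A
        System.out.println("second task saw: " + seen[1]); // A, although the parent now holds B
    }
}
```

The second task still sees "A" because inheritance is a one-time copy at thread construction, not a live link to the parent.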

2. Making it work with a thread pool

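    import com.alibaba.ttl.TransmittableThreadLocal;
    import com.alibaba.ttl.threadpool.TtlExecutors;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    /**
     * @author 二十
     * @since 2021/9/14 11:15 AM
     */
    public class TransmittableTest {
        private static TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>();
        private static ExecutorService pool = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(5));

        static class MainThread extends Thread {
            private int index;

            public MainThread(int index) {
                this.index = index;
            }

            @Override
            public void run() {
                context.set(String.valueOf(index));
                pool.execute(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get()));
                // new Thread(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get())).start();
            }
        }

        public static void main(String[] args) throws InterruptedException {
            List<Thread> threads = new ArrayList<>();
            for (int i = 0; i < 10; i++) {
                Thread thread = new MainThread(i);
                threads.add(thread);
                thread.start();
            }
            // Wait until every MainThread has submitted its task; shutting the pool down
            // earlier could make pool.execute throw RejectedExecutionException
            for (Thread thread : threads) {
                thread.join();
            }
            pool.shutdown();
        }
    }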

3. Source code

TransmittableThreadLocal extends InheritableThreadLocal, so let's look at InheritableThreadLocal first.

    public class InheritableThreadLocal<T> extends ThreadLocal<T> {
        /**
         * When a thread is created and the current thread's inheritableThreadLocals is non-null,
         * its contents are passed to the new thread
         */
        protected T childValue(T parentValue) {
            return parentValue;
        }

        /**
         * set/get/remove on an InheritableThreadLocal all operate on inheritableThreadLocals
         */
        ThreadLocalMap getMap(Thread t) {
            return t.inheritableThreadLocals;
        }

        /**
         * Creates inheritableThreadLocals
         */
        void createMap(Thread t, T firstValue) {
            t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
        }
    }

It is a thin wrapper that enhances ThreadLocal (TL).

The Thread class has two ThreadLocalMap fields related to ThreadLocal:

    ThreadLocal.ThreadLocalMap threadLocals;            // used by ThreadLocal variables
    ThreadLocal.ThreadLocalMap inheritableThreadLocals; // used by InheritableThreadLocal variables

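When a new thread is created, the current thread's inheritableThreadLocals is passed to it. The data of each InheritableThreadLocal variable is shallow-copied (only the reference is copied), so the new thread can read the parent thread's data through the same InheritableThreadLocal variable.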
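Shallow copy has a visible consequence: parent and child share the same value object. The sketch below illustrates this with a mutable list; the class name ShallowCopyDemo and the helper run() are hypothetical and exist only for this example.

```java
import java.util.ArrayList;
import java.util.List;

public class ShallowCopyDemo {
    private static final InheritableThreadLocal<List<String>> context = new InheritableThreadLocal<>();

    // Returns the parent's list after a child thread has appended to its inherited value.
    static List<String> run() throws InterruptedException {
        List<String> parentList = new ArrayList<>();
        parentList.add("from-parent");
        context.set(parentList);

        // The child inherits a reference to the same list, not a copy of it.
        Thread child = new Thread(() -> context.get().add("from-child"));
        child.start();
        child.join();

        List<String> result = context.get();
        context.remove();
        return result;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run()); // [from-parent, from-child]
    }
}
```

If the child should get its own copy, one option is to override childValue and return a copy there, for example new ArrayList<>(parentValue).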

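The analysis of TTL's design starts from TtlRunnable.get(), whose source is shown below. TtlRunnable.get corresponds to the capture step at initialization, which saves a snapshot. TtlCallable follows a similar flow to TtlRunnable.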

    public static TtlRunnable get(@Nullable Runnable runnable) {
        return get(runnable, false, false);
    }

    public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
        if (runnable instanceof TtlEnhanced) {
            // Return it directly when idempotent; otherwise running it would cause problems, so throw
            if (idempotent) return (TtlRunnable) runnable;
            else throw new IllegalStateException("Already TtlRunnable!");
        }
        return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
    }

    private TtlRunnable(@Nonnull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        this.capturedRef = new AtomicReference<Object>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

    public static Object capture() {
        Map<TransmittableThreadLocal<?>, Object> captured = new HashMap<TransmittableThreadLocal<?>, Object>();
        // Read every threadLocal from holder and store it in captured. This is a snapshot of the
        // current thread's holder, kept in a field of the TtlRunnable instance and replayed when
        // the TtlRunnable runs
        for (TransmittableThreadLocal<?> threadLocal : holder.get().keySet()) {
            captured.put(threadLocal, threadLocal.copyValue());
        }
        return captured;
    }
While a TtlRunnable is being constructed, the contents of TransmittableThreadLocal.holder are saved into captured and recorded in the capturedRef field of the TtlRunnable instance. TransmittableThreadLocal.holder is declared as:

    // Note about holder:
    // 1. The value of holder is type Map<TransmittableThreadLocal<?>, ?> (WeakHashMap implementation),
    //    but it is used as *set*, because there is no WeakSet.
    // 2. WeakHashMap support null value.
    private static InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>> holder =
            new InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>>() {
                @Override
                protected Map<TransmittableThreadLocal<?>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<?>, Object>();
                }

                @Override
                protected Map<TransmittableThreadLocal<?>, ?> childValue(Map<TransmittableThreadLocal<?>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<?>, Object>(parentValue);
                }
            };
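The holder comment says the map is used as a set because there is no WeakSet. As a side illustration only (this is not what TTL does), the JDK can build a weak set view with Collections.newSetFromMap. The class WeakSetSketch and the helper newWeakSet() below are hypothetical names.

```java
import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;

public class WeakSetSketch {
    // A Set view backed by a WeakHashMap: an element can be dropped by the GC
    // once nothing else holds a strong reference to it.
    static <T> Set<T> newWeakSet() {
        return Collections.newSetFromMap(new WeakHashMap<T, Boolean>());
    }

    public static void main(String[] args) {
        Set<Object> registry = newWeakSet();
        Object key = new Object();
        registry.add(key);
        registry.add(key); // adding the same element twice has no further effect
        System.out.println(registry.contains(key)); // true while key is strongly reachable
    }
}
```

Using a map with ignored values, as holder does, gives the same weak-key behaviour without the extra wrapper.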

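The code above shows that the TransmittableThreadLocal values are saved when the TtlRunnable is initialized. So when are they applied to the ThreadLocals of the thread that runs the task? For that, look at TtlRunnable.run: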

    public void run() {
        Object captured = capturedRef.get();
        // captured should not be null. When releaseTtlValueReferenceAfterRun is true, capturedRef is
        // set to null so that the same Runnable cannot be run twice
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        // Replay captured, applying it to the current thread
        Object backup = replay(captured);
        try {
            runnable.run();
        } finally {
            restore(backup);
        }
    }

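Note that replay in TTL treats captured as the current inheritableThreadLocals. When the TtlRunnable runs, the snapshot taken at TtlRunnable.get time (captured, effectively a TTL snapshot) wins: entries in holder that are not in captured are removed first, and entries that are in captured are overwritten. After replaying captured and executing runnable.run, restore puts the thread back into its original inheritableThreadLocals state.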
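The capture, replay, restore cycle can be reduced to a few lines for a single variable. The sketch below is not TTL's code: it uses one plain ThreadLocal, and ContextRunnable, setOrRemove and demo() are hypothetical names chosen for this example.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ReplayRestoreSketch {
    static final ThreadLocal<String> context = new ThreadLocal<>();

    // Hypothetical wrapper (not TTL's class): snapshots `context` when it is constructed.
    static final class ContextRunnable implements Runnable {
        private final String captured = context.get(); // capture, in the submitting thread
        private final Runnable task;

        ContextRunnable(Runnable task) {
            this.task = task;
        }

        @Override
        public void run() {
            String backup = context.get(); // the worker's own value before the task
            setOrRemove(captured);         // replay the snapshot
            try {
                task.run();
            } finally {
                setOrRemove(backup);       // restore the worker's original state
            }
        }
    }

    static void setOrRemove(String value) {
        if (value == null) context.remove();
        else context.set(value);
    }

    // Returns {value seen inside the wrapped task, worker's value after the task}.
    static String[] demo() throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        String[] seen = new String[2];
        try {
            pool.submit(() -> { context.set("worker-own"); }).get();
            context.set("request-42");
            pool.submit(new ContextRunnable(() -> { seen[0] = context.get(); })).get();
            pool.submit(() -> { seen[1] = context.get(); }).get();
            return seen;
        } finally {
            pool.shutdown();
            context.remove();
        }
    }

    public static void main(String[] args) throws Exception {
        String[] seen = demo();
        System.out.println(seen[0]); // request-42
        System.out.println(seen[1]); // worker-own
    }
}
```

The wrapped task sees the submitter's value, and the worker gets its own value back afterwards, which is why a reused pool thread does not leak one task's context into the next.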

[Figure: ttl.png]

4. Writing a lightweight version yourself

    import java.util.Iterator;
    import java.util.WeakHashMap;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicReference;
    import java.util.stream.Collectors;

    /**
     * @author 二十
     * @since 2021/9/14 3:13 PM
     */
    public class EsThreadLocal<T> extends InheritableThreadLocal<T> {
        // Per-thread registry of every EsThreadLocal that currently holds a value (map used as a set)
        private static InheritableThreadLocal<WeakHashMap<EsThreadLocal<Object>, ?>> holder = new InheritableThreadLocal<WeakHashMap<EsThreadLocal<Object>, ?>>() {
            @Override
            protected WeakHashMap<EsThreadLocal<Object>, ?> childValue(WeakHashMap<EsThreadLocal<Object>, ?> parentValue) {
                return new WeakHashMap<>(parentValue);
            }

            @Override
            protected WeakHashMap<EsThreadLocal<Object>, ?> initialValue() {
                return new WeakHashMap<>();
            }
        };

        // Identity-based hashCode/equals, so each instance is its own key in holder
        @Override
        public int hashCode() {
            return super.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            return super.equals(obj);
        }

        @Override
        public T get() {
            T value = super.get();
            if (null != value)
                addToHolder();
            return value;
        }

        @Override
        public void set(T value) {
            if (null == value) {
                remove();
            } else {
                super.set(value);
                addToHolder();
            }
        }

        // Also unregister from holder; otherwise a later capture() would read a null value
        // for this key and Collectors.toMap would throw NullPointerException
        @Override
        public void remove() {
            removeFromHolder();
            super.remove();
        }

        private void removeFromHolder() {
            holder.get().remove(this);
        }

        @SuppressWarnings("unchecked")
        private void addToHolder() {
            if (!holder.get().containsKey(this))
                holder.get().put((EsThreadLocal<Object>) this, null);
        }

        static class SnapShot {
            final WeakHashMap<EsThreadLocal<Object>, Object> ctlValue;

            private SnapShot(WeakHashMap<EsThreadLocal<Object>, Object> ctlValue) {
                this.ctlValue = ctlValue;
            }
        }

        static class Transmitter {
            // Snapshot the values of all registered EsThreadLocals in the current thread
            public static SnapShot capture() {
                return new SnapShot(captureCtlValues());
            }

            private static WeakHashMap<EsThreadLocal<Object>, Object> captureCtlValues() {
                return holder.get().keySet().stream().collect(Collectors.toMap(ctlItem -> ctlItem, EsThreadLocal::get, (a, b) -> b, WeakHashMap::new));
            }

            public static SnapShot replay(SnapShot snapShot) {
                WeakHashMap<EsThreadLocal<Object>, Object> capture = snapShot.ctlValue;
                WeakHashMap<EsThreadLocal<Object>, Object> backValue = new WeakHashMap<>();
                /*
                 * Iterate over the threadLocals the current thread holds (from holder) and back them up
                 */
                Iterator<EsThreadLocal<Object>> iterator = holder.get().keySet().iterator();
                while (iterator.hasNext()) {
                    EsThreadLocal<Object> threadLocal = iterator.next();
                    backValue.put(threadLocal, threadLocal.get());
                    // Anything that is not part of the snapshot is cleared for the duration of the task
                    if (!capture.containsKey(threadLocal)) {
                        iterator.remove();
                        threadLocal.remove();
                    }
                }
                /*
                 * Apply the captured values
                 */
                setThreadLocal(capture);
                return new SnapShot(backValue);
            }

            public static void setThreadLocal(WeakHashMap<EsThreadLocal<Object>, Object> ctlValues) {
                ctlValues.forEach(EsThreadLocal::set);
            }

            public static void restore(EsThreadLocal.SnapShot backUp) {
                // Drop whatever the task added that was not there before, then put the backup back
                Iterator<EsThreadLocal<Object>> iterator = holder.get().keySet().iterator();
                while (iterator.hasNext()) {
                    EsThreadLocal<Object> threadLocal = iterator.next();
                    if (!backUp.ctlValue.containsKey(threadLocal)) {
                        iterator.remove();
                        threadLocal.remove();
                    }
                }
                setThreadLocal(backUp.ctlValue);
            }
        }

        static class EsRunnable implements Runnable {
            private AtomicReference<SnapShot> captureRef;
            private Runnable runnable;

            public EsRunnable(Runnable runnable) {
                this.runnable = runnable;
                // Captured in the thread that creates the EsRunnable (the submitter)
                captureRef = new AtomicReference<>(Transmitter.capture());
            }

            @Override
            public void run() {
                SnapShot capture = captureRef.get();
                SnapShot backUp = Transmitter.replay(capture);
                try {
                    runnable.run();
                } finally {
                    Transmitter.restore(backUp);
                }
            }

            public static EsRunnable getRunnable(Runnable runnable) {
                return new EsRunnable(runnable);
            }
        }
    }

    class Test {
        private static ThreadLocal<String> context = new EsThreadLocal<>();
        private static ExecutorService pool = Executors.newFixedThreadPool(5);

        public static void main(String[] args) throws Exception {
            for (int i = 1; i <= 10; i++) {
                context.set(String.valueOf(i));
                pool.execute(new EsThreadLocal.EsRunnable(() -> System.out.println(Thread.currentThread().getName() + " : " + context.get())));
                TimeUnit.SECONDS.sleep(1);
            }
            pool.shutdown();
        }
    }

Output:
[Figure: image.png]