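1. Copying parent-thread variables to child threads
The inheritableThreadLocals are copied only when a new thread is created from inside the parent thread: the copy is made while the new Thread is being initialized (in Thread.init on JDK 8), and at no other time.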
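The copy is a one-time snapshot taken when the Thread object is constructed, not when start() is called. A minimal sketch to observe this, assuming platform threads created with the plain Thread(Runnable) constructor; the class ConstructionTimeCopyDemo and its values are made up for illustration:

```java
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical demo: InheritableThreadLocal values are copied in the Thread constructor.
final class ConstructionTimeCopyDemo {
    private static final InheritableThreadLocal<String> CONTEXT = new InheritableThreadLocal<>();

    static String childSees() throws InterruptedException {
        AtomicReference<String> seen = new AtomicReference<>();
        CONTEXT.set("before-new");
        // The parent's inheritableThreadLocals are copied here, while the Thread is constructed.
        Thread child = new Thread(() -> seen.set(CONTEXT.get()));
        // Changing the parent's value afterwards does not touch the child's copy.
        CONTEXT.set("after-new");
        child.start();
        child.join();
        CONTEXT.remove();
        return seen.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(childSees()); // prints before-new
    }
}
```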
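import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class InheritableThreadLocalTest {
    private static ThreadLocal<Integer> context = new InheritableThreadLocal<>();
    private static ExecutorService pool = Executors.newFixedThreadPool(2);

    static class MainThread extends Thread {
        private int index;

        public MainThread(int index) {
            this.index = index;
        }

        @Override
        public void run() {
            context.set(index);
            // With pooled threads, the values passed to the tasks get mixed up:
            // pool.execute(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get()));
            // A newly created thread copies context's current value when it is constructed
            new Thread(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get())).start();
        }
    }

    /**
     * When the child tasks run on a thread pool,
     * InheritableThreadLocal cannot solve the variable-passing problem.
     */
    public static void main(String[] args) throws InterruptedException {
        MainThread[] threads = new MainThread[10];
        for (int i = 0; i < 10; i++) {
            threads[i] = new MainThread(i);
            threads[i].start();
        }
        // Wait until every MainThread has submitted its work; otherwise, with the pool line
        // enabled, shutdown() could run first and execute() would throw RejectedExecutionException
        for (MainThread t : threads) {
            t.join();
        }
        pool.shutdown();
    }
}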
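Note: this only works with newly created threads, not with threads from a thread pool. Pool threads are reused, so a task sees the values its worker copied when the worker was created rather than the submitting thread's current values, and the transmitted information comes out wrong.
2. Solving the thread-pool problem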
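The reuse problem can be reproduced deterministically with a single-worker pool. A minimal sketch, assuming Executors.newSingleThreadExecutor() creates its worker lazily on the first submission (the default ThreadPoolExecutor behavior); PoolReuseDemo is a hypothetical name:

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo: a reused pool thread keeps the value it inherited when it was created.
final class PoolReuseDemo {
    private static final InheritableThreadLocal<String> CONTEXT = new InheritableThreadLocal<>();

    static String[] run() throws Exception {
        // One worker makes reuse deterministic: both tasks run on the same pooled thread.
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Callable<String> readContext = CONTEXT::get;
        try {
            CONTEXT.set("A");
            // First submission: the worker thread is created now and inherits "A".
            String first = pool.submit(readContext).get();
            CONTEXT.set("B");
            // Second submission: no new thread is constructed, so nothing is copied again.
            String second = pool.submit(readContext).get();
            return new String[] {first, second};
        } finally {
            pool.shutdown();
            CONTEXT.remove();
        }
    }

    public static void main(String[] args) throws Exception {
        String[] seen = run();
        System.out.println(seen[0] + " " + seen[1]); // prints A A
    }
}
```

The second task still reads A even though the submitter has moved on to B, which is exactly the mix-up the note describes.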
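import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * @author 二十
 * @since 2021/9/14 11:15 AM
 */
public class TransmittableTest {
    private static TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>();
    // Wrap the pool so each submitted task captures the submitting thread's TTL values
    private static ExecutorService pool = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(5));

    static class MainThread extends Thread {
        private int index;

        public MainThread(int index) {
            this.index = index;
        }

        @Override
        public void run() {
            context.set(String.valueOf(index));
            pool.execute(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get()));
            // new Thread(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get())).start();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        MainThread[] threads = new MainThread[10];
        for (int i = 0; i < 10; i++) {
            threads[i] = new MainThread(i);
            threads[i].start();
        }
        // Shut down only after every MainThread has submitted its task,
        // otherwise execute() may be rejected
        for (MainThread t : threads) {
            t.join();
        }
        pool.shutdown();
    }
}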
3. Source code
TransmittableThreadLocal extends InheritableThreadLocal, so let's first take a look at InheritableThreadLocal.
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    /**
     * Called when a new thread is created while the current thread's inheritableThreadLocals
     * is non-null: computes the value the new thread receives. By default it is the parent's value itself.
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * set/get/remove on an InheritableThreadLocal variable all operate on inheritableThreadLocals
     */
    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }

    /**
     * Creates inheritableThreadLocals
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
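In other words, InheritableThreadLocal is a thin wrapper around ThreadLocal (TL): it stores its values in a second map and adds the childValue hook used when copying them to child threads.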
The Thread class has two ThreadLocalMap fields related to ThreadLocal:
ThreadLocal.ThreadLocalMap threadLocals: used by ThreadLocal variables
ThreadLocal.ThreadLocalMap inheritableThreadLocals: used by InheritableThreadLocal variables
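A quick way to see the two maps at work: a child thread can read an InheritableThreadLocal set by its parent, but not a plain ThreadLocal. A minimal sketch (TwoMapsDemo and its values are hypothetical):

```java
// Hypothetical demo: only inheritableThreadLocals is copied into a child thread.
final class TwoMapsDemo {
    private static final ThreadLocal<String> PLAIN = new ThreadLocal<>();
    private static final InheritableThreadLocal<String> INHERITABLE = new InheritableThreadLocal<>();

    static String[] childSees() throws InterruptedException {
        PLAIN.set("plain");             // stored in the parent's threadLocals
        INHERITABLE.set("inheritable"); // stored in the parent's inheritableThreadLocals
        String[] seen = new String[2];
        Thread child = new Thread(() -> {
            seen[0] = PLAIN.get();       // null: threadLocals is never copied
            seen[1] = INHERITABLE.get(); // copied from the parent at construction
        });
        child.start();
        child.join(); // join() makes the child's writes to seen visible here
        PLAIN.remove();
        INHERITABLE.remove();
        return seen;
    }

    public static void main(String[] args) throws InterruptedException {
        String[] seen = childSees();
        System.out.println(seen[0] + " " + seen[1]); // prints null inheritable
    }
}
```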
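When a new thread is created, the current thread's inheritableThreadLocals is passed on to it. The child gets its own map, but each InheritableThreadLocal value is only shallow-copied (the reference is copied, via childValue), so the new thread reads the previous thread's data through the same InheritableThreadLocal variable, and a mutable value is shared between the two.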
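Because only the reference is copied, a mutable value such as a List is shared between parent and child; overriding childValue is the JDK hook for handing the child its own copy. A sketch under those assumptions (ShallowCopyDemo is a hypothetical name, and the override copies only the list, not its elements):

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo: default childValue shares the reference; an override can copy it.
final class ShallowCopyDemo {
    // Default childValue: the child gets the parent's List object itself.
    private static final InheritableThreadLocal<List<String>> SHARED = new InheritableThreadLocal<>();

    // Overridden childValue: the child gets a new list (its elements are still shared references).
    private static final InheritableThreadLocal<List<String>> COPIED = new InheritableThreadLocal<List<String>>() {
        @Override
        protected List<String> childValue(List<String> parentValue) {
            return new ArrayList<>(parentValue);
        }
    };

    static int[] parentSizesAfterChildAdds() throws InterruptedException {
        SHARED.set(new ArrayList<>());
        COPIED.set(new ArrayList<>());
        Thread child = new Thread(() -> {
            SHARED.get().add("from child"); // also visible to the parent
            COPIED.get().add("from child"); // only changes the child's copy
        });
        child.start();
        child.join();
        int[] sizes = {SHARED.get().size(), COPIED.get().size()};
        SHARED.remove();
        COPIED.remove();
        return sizes;
    }

    public static void main(String[] args) throws InterruptedException {
        int[] sizes = parentSizesAfterChildAdds();
        System.out.println(sizes[0] + " " + sizes[1]); // prints 1 0
    }
}
```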
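Below, TtlRunnable.get() is the starting point for analyzing how TTL is designed and implemented. TtlRunnable.get corresponds to the capture step performed when the task is wrapped, which saves a snapshot (TtlCallable follows the same flow as TtlRunnable). Its source is: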
public static TtlRunnable get(@Nullable Runnable runnable) {
    return get(runnable, false, false);
}

public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    if (runnable instanceof TtlEnhanced) {
        // Already wrapped: return it as-is when idempotent; otherwise wrapping twice would cause problems, so throw
        if (idempotent) return (TtlRunnable) runnable;
        else throw new IllegalStateException("Already TtlRunnable!");
    }
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

private TtlRunnable(@Nonnull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    this.capturedRef = new AtomicReference<Object>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

public static Object capture() {
    Map<TransmittableThreadLocal<?>, Object> captured = new HashMap<TransmittableThreadLocal<?>, Object>();
    // Take every TransmittableThreadLocal from holder and store its value in captured. This is a snapshot
    // of the current thread's holder, kept in the TtlRunnable instance and replayed when the TtlRunnable runs
    for (TransmittableThreadLocal<?> threadLocal : holder.get().keySet()) {
        captured.put(threadLocal, threadLocal.copyValue());
    }
    return captured;
}
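While a TtlRunnable is being created, the values of all TransmittableThreadLocals registered in TransmittableThreadLocal.holder are saved into captured, which is stored in the TtlRunnable instance's capturedRef field. TransmittableThreadLocal.holder is declared as: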
// Note about holder:
// 1. The value of holder is type Map<TransmittableThreadLocal<?>, ?> (WeakHashMap implementation),
//    but it is used as *set*, because there is no WeakSet class.
// 2. WeakHashMap supports null value.
private static InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>> holder =
        new InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>>() {
            @Override
            protected Map<TransmittableThreadLocal<?>, ?> initialValue() {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>();
            }

            @Override
            protected Map<TransmittableThreadLocal<?>, ?> childValue(Map<TransmittableThreadLocal<?>, ?> parentValue) {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>(parentValue);
            }
        };
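The childValue override on holder matters: each child thread starts from a copy of its parent's registry, so TransmittableThreadLocals registered later in either thread stay local to that thread. The sketch below mirrors the same shape with plain Object keys; HolderCopyDemo, PARENT_KEY and CHILD_KEY are hypothetical names, and the keys live in static fields so the WeakHashMap entries cannot be garbage-collected during the demo:

```java
import java.util.Map;
import java.util.WeakHashMap;

// Hypothetical demo mirroring the shape of TTL's holder with plain Object keys.
final class HolderCopyDemo {
    // Static fields keep the keys strongly reachable, so WeakHashMap cannot drop them mid-demo.
    private static final Object PARENT_KEY = new Object();
    private static final Object CHILD_KEY = new Object();

    // A per-thread WeakHashMap used as a set; childValue hands each child its own copy.
    private static final InheritableThreadLocal<Map<Object, ?>> HOLDER =
            new InheritableThreadLocal<Map<Object, ?>>() {
                @Override
                protected Map<Object, ?> initialValue() {
                    return new WeakHashMap<Object, Object>();
                }

                @Override
                protected Map<Object, ?> childValue(Map<Object, ?> parentValue) {
                    return new WeakHashMap<Object, Object>(parentValue);
                }
            };

    // Returns {child's holder size, parent's holder size}.
    static int[] holderSizes() throws InterruptedException {
        HOLDER.get().put(PARENT_KEY, null);
        int[] childSize = new int[1];
        Thread child = new Thread(() -> {
            HOLDER.get().put(CHILD_KEY, null); // registered only in the child's copy
            childSize[0] = HOLDER.get().size();
        });
        child.start();
        child.join();
        int parentSize = HOLDER.get().size();
        HOLDER.remove();
        return new int[] {childSize[0], parentSize};
    }

    public static void main(String[] args) throws InterruptedException {
        int[] sizes = holderSizes();
        System.out.println(sizes[0] + " " + sizes[1]); // prints 2 1
    }
}
```

Without the childValue override, the child would receive the parent's map object itself, and both threads would see every registration.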
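From the code above we know that the TransmittableThreadLocal values are already saved when the TtlRunnable is initialized. So when are they applied to the ThreadLocals of the thread that runs the task? For that we need to look at TtlRunnable.run: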
public void run() {
    Object captured = capturedRef.get();
    // captured should never be null. When releaseTtlValueReferenceAfterRun is true, capturedRef is CAS-ed
    // to null, which prevents the same Runnable from running twice
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }
    // Replay captured: apply the snapshot to the current thread
    Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        restore(backup);
    }
}
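Note that TTL's replay makes captured the current state of the thread's values: when a TtlRunnable runs, the captured snapshot taken at TtlRunnable.get time is authoritative. Entries in the worker thread's holder that are not in captured are removed first, and those that are in captured are overwritten with the captured values. After replaying captured and running runnable.run(), restore returns the worker thread's values to the state they were in before the replay.
4. Writing a lightweight version ourselves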
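Stripped down to a single variable, the capture / replay / restore cycle looks roughly like the sketch below. It is not TTL's code: SnapshotRunnable and ReplayRestoreDemo are hypothetical names, only one InheritableThreadLocal is handled, and a null backup is restored by calling remove():

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical single-variable sketch of capture / replay / restore (not TTL's actual code).
final class ReplayRestoreDemo {
    static final InheritableThreadLocal<String> CONTEXT = new InheritableThreadLocal<>();

    static final class SnapshotRunnable implements Runnable {
        private final String captured;
        private final Runnable task;

        SnapshotRunnable(Runnable task) {
            this.captured = CONTEXT.get(); // capture: runs on the submitting thread
            this.task = task;
        }

        @Override
        public void run() {
            String backup = CONTEXT.get(); // back up the worker's own value
            CONTEXT.set(captured);         // replay the snapshot
            try {
                task.run();
            } finally {
                // restore, so the snapshot does not leak into later tasks on this worker
                if (backup == null) {
                    CONTEXT.remove();
                } else {
                    CONTEXT.set(backup);
                }
            }
        }
    }

    static String[] run() throws Exception {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        String[] seen = new String[3];
        try {
            CONTEXT.set("worker-default");
            pool.submit(() -> { }).get(); // creates the worker, which inherits "worker-default"
            CONTEXT.set("request-1");
            pool.submit(new SnapshotRunnable(() -> { seen[0] = CONTEXT.get(); })).get();
            CONTEXT.set("request-2");
            pool.submit(new SnapshotRunnable(() -> { seen[1] = CONTEXT.get(); })).get();
            pool.submit(() -> { seen[2] = CONTEXT.get(); }).get(); // unwrapped task
        } finally {
            pool.shutdown();
            CONTEXT.remove();
        }
        return seen;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(String.join(" ", run())); // prints request-1 request-2 worker-default
    }
}
```

The third, unwrapped task sees worker-default rather than request-2; without the restore step in the finally block it would see the value replayed for the previous task.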
import java.util.Iterator;
import java.util.WeakHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

/**
 * @author 二十
 * @since 2021/9/14 3:13 PM
 */
public class EsThreadLocal<T> extends InheritableThreadLocal<T> {
    // Per-thread registry of the EsThreadLocals holding a value (used as a weak set; the values are always null).
    // childValue gives every new thread its own copy of the parent's registry.
    private static InheritableThreadLocal<WeakHashMap<EsThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<EsThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<EsThreadLocal<Object>, ?> childValue(WeakHashMap<EsThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<EsThreadLocal<Object>, Object>(parentValue);
                }

                @Override
                protected WeakHashMap<EsThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<>();
                }
            };

    // Identity semantics, as in ThreadLocal: each instance is a distinct key in holder
    @Override
    public int hashCode() {
        return super.hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        return super.equals(obj);
    }

    @Override
    public T get() {
        T value = super.get();
        // Register this ThreadLocal in holder once it holds a value
        if (null != value)
            addToHolder();
        return value;
    }

    @Override
    public void set(T value) {
        // Setting null is treated as remove
        if (null == value) {
            removeFromHolder();
            super.remove();
        } else {
            super.set(value);
            addToHolder();
        }
    }

    private void removeFromHolder() {
        holder.get().remove(this);
    }

    @SuppressWarnings("unchecked")
    private void addToHolder() {
        if (!holder.get().containsKey(this))
            holder.get().put((EsThreadLocal<Object>) this, null);
    }

    // Snapshot of EsThreadLocal values
    static class SnapShot {
        final WeakHashMap<EsThreadLocal<Object>, Object> ctlValue;

        private SnapShot(WeakHashMap<EsThreadLocal<Object>, Object> ctlValue) {
            this.ctlValue = ctlValue;
        }
    }

    static class Transmitter {
        // capture: snapshot all registered values on the current (submitting) thread
        public static SnapShot capture() {
            return new SnapShot(captureCtlValues());
        }

        private static WeakHashMap<EsThreadLocal<Object>, Object> captureCtlValues() {
            // A plain loop instead of Collectors.toMap, which throws NullPointerException on null values
            WeakHashMap<EsThreadLocal<Object>, Object> values = new WeakHashMap<>();
            for (EsThreadLocal<Object> threadLocal : holder.get().keySet()) {
                values.put(threadLocal, threadLocal.get());
            }
            return values;
        }

        // replay: back up the current thread's values, drop those not in the snapshot, then apply the snapshot
        public static SnapShot replay(SnapShot snapShot) {
            WeakHashMap<EsThreadLocal<Object>, Object> capture = snapShot.ctlValue;
            WeakHashMap<EsThreadLocal<Object>, Object> backValue = new WeakHashMap<>();
            // Iterate over the ThreadLocals registered in the current thread's holder and back up their values
            Iterator<EsThreadLocal<Object>> iterator = holder.get().keySet().iterator();
            while (iterator.hasNext()) {
                EsThreadLocal<Object> threadLocal = iterator.next();
                backValue.put(threadLocal, threadLocal.get());
                if (!capture.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.remove();
                }
            }
            // Apply the captured values
            setThreadLocal(capture);
            return new SnapShot(backValue);
        }

        public static void setThreadLocal(WeakHashMap<EsThreadLocal<Object>, Object> ctlValues) {
            ctlValues.forEach(EsThreadLocal::set);
        }

        // restore: drop what the task registered, then put the backed-up values back
        public static void restore(EsThreadLocal.SnapShot backUp) {
            Iterator<EsThreadLocal<Object>> iterator = holder.get().keySet().iterator();
            while (iterator.hasNext()) {
                EsThreadLocal<Object> threadLocal = iterator.next();
                if (!backUp.ctlValue.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.remove();
                }
            }
            setThreadLocal(backUp.ctlValue);
        }
    }

    static class EsRunnable implements Runnable {
        private final AtomicReference<SnapShot> captureRef;
        private final Runnable runnable;

        public EsRunnable(Runnable runnable) {
            this.runnable = runnable;
            // Captured on the thread that creates (submits) the task
            captureRef = new AtomicReference<>(Transmitter.capture());
        }

        @Override
        public void run() {
            SnapShot capture = captureRef.get();
            SnapShot backUp = Transmitter.replay(capture);
            try {
                runnable.run();
            } finally {
                Transmitter.restore(backUp);
            }
        }

        public static EsRunnable getRunnable(Runnable runnable) {
            return new EsRunnable(runnable);
        }
    }
}

class Test {
    private static ThreadLocal<String> context = new EsThreadLocal<>();
    private static ExecutorService pool = Executors.newFixedThreadPool(5);

    public static void main(String[] args) throws Exception {
        for (int i = 1; i <= 10; i++) {
            context.set(String.valueOf(i));
            // The wrapped task carries the value set just above, even on a reused pool thread
            pool.execute(new EsThreadLocal.EsRunnable(() -> System.out.println(Thread.currentThread().getName() + " : " + context.get())));
            TimeUnit.SECONDS.sleep(1);
        }
        pool.shutdown();
    }
}
Run result:
