1. Copying parent-thread variables to child threads
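The inheritableThreadLocals are copied only when the parent thread creates a new thread. The copy is made while the Thread object is being constructed (in JDK 8, inside Thread.init), and it is never repeated afterwards.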
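Here is a minimal sketch of that one-time copy, using only the JDK; the class name InheritOnceDemo is hypothetical. The child is constructed while the parent holds "before". The parent then changes its value to "after". The child still reads "before", because its map was filled at construction time.

```java
import java.util.concurrent.atomic.AtomicReference;

class InheritOnceDemo {
    static final InheritableThreadLocal<String> CTX = new InheritableThreadLocal<>();

    // Sets "before", constructs a child, sets "after", then reports what the child saw.
    static String childView() {
        CTX.set("before");
        AtomicReference<String> seen = new AtomicReference<>();
        // The parent's inheritable values are copied while this Thread object is constructed
        Thread child = new Thread(() -> seen.set(CTX.get()));
        CTX.set("after"); // changes only the parent's own entry, not the child's copy
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen.get();
    }

    public static void main(String[] args) {
        System.out.println(childView()); // prints "before"
    }
}
```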
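import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class InheritableThreadLocalTest {
    private static final ThreadLocal<Integer> context = new InheritableThreadLocal<>();
    private static final ExecutorService pool = Executors.newFixedThreadPool(2);

    static class MainThread extends Thread {
        private final int index;

        public MainThread(int index) {
            this.index = index;
        }

        @Override
        public void run() {
            context.set(index);
            // Using pool threads instead mixes up the passed values, because the pool reuses its threads:
            // pool.execute(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get()));
            // A freshly created thread copies this thread's inheritableThreadLocals, so it prints this index
            new Thread(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get())).start();
        }
    }

    /**
     * When a thread pool runs the child threads' tasks,
     * InheritableThreadLocal cannot solve the value-passing problem.
     */
    public static void main(String[] args) throws InterruptedException {
        MainThread[] threads = new MainThread[10];
        for (int i = 0; i < 10; i++) {
            threads[i] = new MainThread(i);
            threads[i].start();
        }
        // Wait until every MainThread has submitted its work; otherwise pool.execute(...)
        // could be rejected because shutdown() already ran
        for (MainThread t : threads) {
            t.join();
        }
        pool.shutdown();
    }
}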
Note: this only works with newly created threads, not thread-pool threads. Pool threads are reused, so the values they see are wrong.
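Here is a small deterministic sketch of that failure mode, using a single-thread pool; PoolStaleDemo is a hypothetical name. The worker thread is created by the first submit, so it inherits the value from that moment. The later task runs on the reused worker and sees the same stale value.

```java
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class PoolStaleDemo {
    static final InheritableThreadLocal<String> CTX = new InheritableThreadLocal<>();

    // Returns the values seen by two tasks that both run on the same reused worker thread.
    static String[] run() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        String[] seen = new String[2];
        try {
            CTX.set("task-1");
            // The only worker is created during this first submit, from this thread,
            // so its inheritableThreadLocals are copied now and hold "task-1"
            pool.submit(() -> { seen[0] = CTX.get(); }).get();
            CTX.set("task-2");
            // The worker is reused and nothing is copied again, so it still sees "task-1"
            pool.submit(() -> { seen[1] = CTX.get(); }).get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(run())); // [task-1, task-1]
    }
}
```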
2. Solving the thread-pool problem
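import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * @author 二十
 * @since 2021/9/14 11:15 AM
 */
public class TransmittableTest {
    private static final TransmittableThreadLocal<String> context = new TransmittableThreadLocal<>();
    // Wrap the pool so every submitted task is decorated with TTL's capture/replay logic
    private static final ExecutorService pool = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(5));

    static class MainThread extends Thread {
        private final int index;

        public MainThread(int index) {
            this.index = index;
        }

        @Override
        public void run() {
            context.set(String.valueOf(index));
            pool.execute(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get()));
            // new Thread(() -> System.out.println(Thread.currentThread().getName() + ":" + context.get())).start();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        MainThread[] threads = new MainThread[10];
        for (int i = 0; i < 10; i++) {
            threads[i] = new MainThread(i);
            threads[i].start();
        }
        // Shut the pool down only after every MainThread has submitted its task
        for (MainThread t : threads) {
            t.join();
        }
        pool.shutdown();
    }
}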
3. Source code
TransmittableThreadLocal extends InheritableThreadLocal, so let's first take a quick look at InheritableThreadLocal.
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    /**
     * When a new thread is created and the current thread's inheritableThreadLocals is non-null,
     * those values are passed to the new thread. childValue computes the child's value from the
     * parent's; by default it returns the parent's value itself.
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * set/get/remove on an InheritableThreadLocal all operate on inheritableThreadLocals
     */
    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }

    /**
     * Creates inheritableThreadLocals
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
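InheritableThreadLocal is therefore just a thin layer of wrapping and enhancement on top of ThreadLocal (TL). It overrides these three methods so that its values live in a separate map.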
The Thread class has two ThreadLocalMap fields related to ThreadLocal:
ThreadLocal.ThreadLocalMap threadLocals: used by ThreadLocal variables
ThreadLocal.ThreadLocalMap inheritableThreadLocals: used by InheritableThreadLocal variables
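When a new thread is created, the current thread's inheritableThreadLocals are passed to it. This passing is a shallow copy (reference copy) of the InheritableThreadLocal data. The child gets its own map, but each value is the same reference the parent holds, unless childValue is overridden. Through the same InheritableThreadLocal variable, the new thread therefore sees the data of the thread that created it.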
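Here is a short sketch of what "shallow copy" means in practice; the class and field names are hypothetical.

- With the default childValue, the child shares the parent's List object, so the child's append is visible to the parent.
- Overriding childValue to return a new ArrayList gives the child a private copy instead.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ShallowCopyDemo {
    // Default childValue: the child receives the parent's reference (shallow copy)
    static final InheritableThreadLocal<List<String>> SHARED = new InheritableThreadLocal<>();

    // Overridden childValue: each child starts with its own copy of the list
    static final InheritableThreadLocal<List<String>> COPIED = new InheritableThreadLocal<List<String>>() {
        @Override
        protected List<String> childValue(List<String> parentValue) {
            return new ArrayList<>(parentValue);
        }
    };

    // Returns {parent's SHARED size, parent's COPIED size} after a child appended to both.
    static int[] run() {
        SHARED.set(new ArrayList<>());
        COPIED.set(new ArrayList<>());
        Thread child = new Thread(() -> {
            SHARED.get().add("from-child"); // mutates the very list the parent holds
            COPIED.get().add("from-child"); // mutates only the child's copy
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return new int[] {SHARED.get().size(), COPIED.get().size()};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(run())); // [1, 0]
    }
}
```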
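The analysis of TTL's design starts from TtlRunnable.get(). TtlRunnable.get corresponds to the capture step performed at initialization, which saves a snapshot. TtlCallable follows the same flow as TtlRunnable. The source of TtlRunnable.get is: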
public static TtlRunnable get(@Nullable Runnable runnable) {
    return get(runnable, false, false);
}

public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    if (runnable instanceof TtlEnhanced) {
        // Already wrapped: return it as-is when idempotent; otherwise wrapping twice
        // would cause problems at run time, so throw right away
        if (idempotent) return (TtlRunnable) runnable;
        else throw new IllegalStateException("Already TtlRunnable!");
    }
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

private TtlRunnable(@Nonnull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    this.capturedRef = new AtomicReference<Object>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

public static Object capture() {
    Map<TransmittableThreadLocal<?>, Object> captured = new HashMap<TransmittableThreadLocal<?>, Object>();
    // Take every TransmittableThreadLocal from holder and store its value in captured. This is
    // effectively a snapshot of the current thread's holder, kept as a field of the TtlRunnable
    // instance and replayed when the TtlRunnable runs
    for (TransmittableThreadLocal<?> threadLocal : holder.get().keySet()) {
        captured.put(threadLocal, threadLocal.copyValue());
    }
    return captured;
}
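While a TtlRunnable is being created, the values of the TransmittableThreadLocals registered in TransmittableThreadLocal.holder are saved into captured. captured is then stored in the TtlRunnable instance's capturedRef field. TransmittableThreadLocal.holder is declared as: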
// Note about holder:
// 1. The value of holder is type Map<TransmittableThreadLocal<?>, ?> (WeakHashMap implementation),
//    but it is used as *set*, because there is no WeakSet.
// 2. WeakHashMap support null value.
private static InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>> holder =
        new InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>>() {
            @Override
            protected Map<TransmittableThreadLocal<?>, ?> initialValue() {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>();
            }

            @Override
            protected Map<TransmittableThreadLocal<?>, ?> childValue(Map<TransmittableThreadLocal<?>, ?> parentValue) {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>(parentValue);
            }
        };
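The childValue override in holder gives every new thread its own copy of the holder map. As a result, TTLs registered in a child do not leak back into the parent's holder.

Below is a minimal sketch with the same shape. It uses plain Object keys instead of TransmittableThreadLocal, and HolderCopyDemo and its keys are hypothetical.

```java
import java.util.Map;
import java.util.WeakHashMap;

class HolderCopyDemo {
    // Strongly reachable keys, so the weak map cannot drop them during the demo
    static final Object PARENT_KEY = new Object();
    static final Object CHILD_KEY = new Object();

    // Same shape as holder: an inheritable WeakHashMap used as a set,
    // whose childValue hands every new thread its own copy
    static final InheritableThreadLocal<Map<Object, Object>> HOLDER =
            new InheritableThreadLocal<Map<Object, Object>>() {
                @Override
                protected Map<Object, Object> initialValue() {
                    return new WeakHashMap<>();
                }

                @Override
                protected Map<Object, Object> childValue(Map<Object, Object> parentValue) {
                    return new WeakHashMap<>(parentValue);
                }
            };

    // Returns {parent's size, child's size} after the child registers one extra key.
    static int[] run() {
        HOLDER.get().put(PARENT_KEY, null);
        int[] sizes = new int[2];
        Thread child = new Thread(() -> {
            HOLDER.get().put(CHILD_KEY, null); // lands in the child's copy only
            sizes[1] = HOLDER.get().size();
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        sizes[0] = HOLDER.get().size();
        return sizes;
    }

    public static void main(String[] args) {
        int[] sizes = run();
        System.out.println("parent=" + sizes[0] + ", child=" + sizes[1]); // parent=1, child=2
    }
}
```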
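The TtlRunnable construction code above shows that the TransmittableThreadLocal values are already saved when the TtlRunnable is initialized. When are they applied to the ThreadLocals of the thread that executes it? That happens in TtlRunnable.run: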
public void run() {
    Object captured = capturedRef.get();
    // captured must never be null. When releaseTtlValueReferenceAfterRun is true, capturedRef is
    // CAS'ed to null so that this Runnable cannot run a second time
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }
    // Replay captured, i.e. apply it to the current thread
    Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        restore(backup);
    }
}
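The null check plus compareAndSet at the top of run() turns releaseTtlValueReferenceAfterRun into a one-shot guard. Only the first caller can swap the snapshot reference to null, so any later or concurrent run fails.

Here is a minimal sketch of just that guard. It lives in a hypothetical class, OneShotTask, which is not part of TTL.

```java
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical class (not TTL's API) mirroring the guard at the top of TtlRunnable.run
class OneShotTask implements Runnable {
    private final AtomicReference<Object> snapshotRef;
    private final boolean releaseAfterRun;
    private final Runnable task;

    OneShotTask(Object snapshot, boolean releaseAfterRun, Runnable task) {
        this.snapshotRef = new AtomicReference<>(snapshot);
        this.releaseAfterRun = releaseAfterRun;
        this.task = task;
    }

    @Override
    public void run() {
        Object snapshot = snapshotRef.get();
        // compareAndSet lets exactly one caller clear the snapshot, even if two threads race here
        if (snapshot == null || releaseAfterRun && !snapshotRef.compareAndSet(snapshot, null)) {
            throw new IllegalStateException("snapshot already released");
        }
        task.run();
    }

    // Runs the same wrapper twice and returns how many times the inner task actually ran.
    static int countRuns(boolean releaseAfterRun) {
        int[] runs = {0};
        OneShotTask wrapper = new OneShotTask(new Object(), releaseAfterRun, () -> runs[0]++);
        for (int i = 0; i < 2; i++) {
            try {
                wrapper.run();
            } catch (IllegalStateException released) {
                // the snapshot was already released by the first run, so this run is rejected
            }
        }
        return runs[0];
    }

    public static void main(String[] args) {
        System.out.println(countRuns(true) + " " + countRuns(false)); // prints "1 2"
    }
}
```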
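Note that TTL's replay treats captured as the authoritative state of the current inheritableThreadLocals. When the TtlRunnable runs, the captured snapshot taken at TtlRunnable.get time (think of it as a TTL snapshot) takes precedence:

- TTLs in holder that are not in captured are removed first.
- TTLs that are in captured get their values overwritten.

After captured has been replayed and runnable.run has finished, restore returns the inheritableThreadLocals to their original state.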
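The capture, replay, run, restore sequence can be sketched without TTL, assuming Java 9+ (for List.of). SnapshotRunnable and ReplayDemo are hypothetical names.

The sketch simplifies TTL in two ways:

- It tracks a fixed list of ThreadLocals instead of a dynamic holder.
- A null snapshot value stands for "not in captured", so that variable is removed for the duration of the task.

On a single reused worker, the wrapped task sees the submitter's value. Afterwards, the worker's own inherited value is back.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical wrapper (not TTL's API): capture on construction, replay before run, restore after
class SnapshotRunnable implements Runnable {
    private final List<ThreadLocal<Object>> vars;
    private final Map<ThreadLocal<Object>, Object> captured = new HashMap<>();
    private final Runnable task;

    SnapshotRunnable(List<ThreadLocal<Object>> vars, Runnable task) {
        this.vars = vars;
        this.task = task;
        for (ThreadLocal<Object> v : vars) {
            captured.put(v, v.get()); // capture: runs in the submitting thread
        }
    }

    @Override
    public void run() {
        Map<ThreadLocal<Object>, Object> backup = new HashMap<>();
        for (ThreadLocal<Object> v : vars) {
            backup.put(v, v.get());    // back up the worker's own value
            apply(v, captured.get(v)); // replay: the snapshot wins
        }
        try {
            task.run();
        } finally {
            for (ThreadLocal<Object> v : vars) {
                apply(v, backup.get(v)); // restore the worker's previous state
            }
        }
    }

    // null stands for "absent": the variable is removed rather than set
    private static void apply(ThreadLocal<Object> v, Object value) {
        if (value == null) {
            v.remove();
        } else {
            v.set(value);
        }
    }
}

class ReplayDemo {
    static final ThreadLocal<Object> CTX = new InheritableThreadLocal<>();

    // Returns what three tasks saw on the same reused worker thread.
    static Object[] run() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        List<ThreadLocal<Object>> vars = List.of(CTX);
        Object[] seen = new Object[3];
        try {
            CTX.set("A");
            pool.submit(() -> { }).get(); // the worker is created here and inherits "A"
            CTX.set("B");
            pool.submit(new SnapshotRunnable(vars, () -> seen[0] = CTX.get())).get();
            pool.submit(() -> { seen[1] = CTX.get(); }).get(); // unwrapped: the worker's own value
            CTX.remove();
            pool.submit(new SnapshotRunnable(vars, () -> seen[2] = CTX.get())).get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(run())); // [B, A, null]
    }
}
```

The second, unwrapped task reading "A" is the restore step at work. Without restore, the worker would keep "B" after the first wrapped task finished.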
4. Writing a lightweight version yourself
import java.util.Iterator;
import java.util.WeakHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

/**
 * @author 二十
 * @since 2021/9/14 3:13 PM
 */
public class EsThreadLocal<T> extends InheritableThreadLocal<T> {
    // Registry of the EsThreadLocals that hold a value in the current thread (WeakHashMap used as a weak set)
    private static final InheritableThreadLocal<WeakHashMap<EsThreadLocal<Object>, ?>> holder = new InheritableThreadLocal<WeakHashMap<EsThreadLocal<Object>, ?>>() {
        @Override
        protected WeakHashMap<EsThreadLocal<Object>, ?> childValue(WeakHashMap<EsThreadLocal<Object>, ?> parentValue) {
            // Each child thread gets its own copy, so its registrations don't leak back to the parent
            return new WeakHashMap<>(parentValue);
        }

        @Override
        protected WeakHashMap<EsThreadLocal<Object>, ?> initialValue() {
            return new WeakHashMap<>();
        }
    };

    // Identity-based equals/hashCode (Object's): each instance must be a distinct key in holder
    @Override
    public final int hashCode() {
        return super.hashCode();
    }

    @Override
    public final boolean equals(Object obj) {
        return super.equals(obj);
    }

    @Override
    public T get() {
        T value = super.get();
        if (null != value) {
            addToHolder();
        }
        return value;
    }

    @Override
    public void set(T value) {
        if (null == value) {
            // set(null) behaves like remove(), so holder never tracks an empty variable
            remove();
        } else {
            super.set(value);
            addToHolder();
        }
    }

    @Override
    public void remove() {
        // Keep holder in sync; otherwise a removed variable would stay registered
        removeFromHolder();
        super.remove();
    }

    // Clears only the value and leaves holder alone; used while iterating over holder
    private void superRemove() {
        super.remove();
    }

    private void removeFromHolder() {
        holder.get().remove(this);
    }

    @SuppressWarnings("unchecked")
    private void addToHolder() {
        if (!holder.get().containsKey(this)) {
            holder.get().put((EsThreadLocal<Object>) this, null);
        }
    }

    // Snapshot of EsThreadLocal values, keyed by variable
    static class SnapShot {
        final WeakHashMap<EsThreadLocal<Object>, Object> ctlValue;

        private SnapShot(WeakHashMap<EsThreadLocal<Object>, Object> ctlValue) {
            this.ctlValue = ctlValue;
        }
    }

    static class Transmitter {
        // Runs in the submitting thread: snapshot every registered EsThreadLocal
        public static SnapShot capture() {
            return new SnapShot(captureCtlValues());
        }

        private static WeakHashMap<EsThreadLocal<Object>, Object> captureCtlValues() {
            WeakHashMap<EsThreadLocal<Object>, Object> captured = new WeakHashMap<>();
            // A plain loop instead of Collectors.toMap, which throws NullPointerException on null values
            for (EsThreadLocal<Object> threadLocal : holder.get().keySet()) {
                captured.put(threadLocal, threadLocal.get());
            }
            return captured;
        }

        // Runs in the executing thread: back up its values, then apply the snapshot
        public static SnapShot replay(SnapShot snapShot) {
            WeakHashMap<EsThreadLocal<Object>, Object> capture = snapShot.ctlValue;
            WeakHashMap<EsThreadLocal<Object>, Object> backValue = new WeakHashMap<>();
            // Iterate over the EsThreadLocals the current thread holds (via holder) and back them up
            Iterator<EsThreadLocal<Object>> iterator = holder.get().keySet().iterator();
            while (iterator.hasNext()) {
                EsThreadLocal<Object> threadLocal = iterator.next();
                backValue.put(threadLocal, threadLocal.get());
                // Variables that are not in the snapshot are cleared for the duration of the task
                if (!capture.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }
            // Apply the captured values
            setThreadLocal(capture);
            return new SnapShot(backValue);
        }

        public static void setThreadLocal(WeakHashMap<EsThreadLocal<Object>, Object> ctlValues) {
            ctlValues.forEach(EsThreadLocal::set);
        }

        // After the task: drop whatever the task added, then put the backed-up values back
        public static void restore(EsThreadLocal.SnapShot backUp) {
            Iterator<EsThreadLocal<Object>> iterator = holder.get().keySet().iterator();
            while (iterator.hasNext()) {
                EsThreadLocal<Object> threadLocal = iterator.next();
                if (!backUp.ctlValue.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }
            setThreadLocal(backUp.ctlValue);
        }
    }

    static class EsRunnable implements Runnable {
        private final AtomicReference<SnapShot> captureRef;
        private final Runnable runnable;

        public EsRunnable(Runnable runnable) {
            this.runnable = runnable;
            // Snapshot the submitting thread's values at construction time
            this.captureRef = new AtomicReference<>(Transmitter.capture());
        }

        @Override
        public void run() {
            SnapShot capture = captureRef.get();
            SnapShot backUp = Transmitter.replay(capture);
            try {
                runnable.run();
            } finally {
                Transmitter.restore(backUp);
            }
        }

        public static EsRunnable getRunnable(Runnable runnable) {
            return new EsRunnable(runnable);
        }
    }
}

class Test {
    private static final ThreadLocal<String> context = new EsThreadLocal<>();
    private static final ExecutorService pool = Executors.newFixedThreadPool(5);

    public static void main(String[] args) throws Exception {
        for (int i = 1; i <= 10; i++) {
            context.set(String.valueOf(i));
            // The wrapper captures context now and replays it on whichever pool thread runs the task
            pool.execute(new EsThreadLocal.EsRunnable(() -> System.out.println(Thread.currentThread().getName() + " : " + context.get())));
            TimeUnit.SECONDS.sleep(1);
        }
        pool.shutdown();
    }
}
Output: