前言 业务开发中经常使用 ThreadLocal 来存储用户信息等线程私有对象… ThreadLocal 内部构造是什么样子的?为什么可以线程私有?常说的内存泄露又是怎么回事? 公众号:liuzhihangs ,记录工作学习中的技术、开发及源码笔记;时不时分享一些生活中的见闻感悟。欢迎大佬来指导!

介绍

ThreadLocal 类提供了线程局部变量。和正常对象不同的是,每个线程都可以访问 get()、set() 方法,获取独属于自己的副本。 ThreadLocal 实例通常是类中的私有静态字段,并且其状态和线程关联。 每个线程都保持对其线程局部变量副本的隐式引用,只要线程是活动的并且 ThreadLocal 实例访问; 一个线程消失之后,所有的线程局部实例的副本都会被垃圾回收(除非存在对这些副本的其他引用)。

使用

有这么一种使用场景,收到 web 请求,先进行 token 验证,而这个 token,可以解析出用户 user 的信息。所以我这边一般是这样使用的:

  1. 自定义注解, @CheckToken , 标识该方法需要校验 token。
  2. Interceptor(拦截器)中检查,如果方法有 @CheckToken 注解则校验 token。
  3. 从Header中获取 Authorization ,请求第三方或者自己的逻辑校验 token ,并解析成 user。
  4. 将user放到ThreadLocal中。
  5. controller、service 在后续使用中, 如果需要 user 信息,可以直接从 ThreadLocal 中获取。
  6. 使用结束后进行remove。

    代码如下:

    1. public class LocalUserUtils {
    2. /**
    3. * 用户信息保存至 ThreadLocal 中
    4. */
    5. private static final ThreadLocal<User> USER_THREAD_LOCAL = new ThreadLocal<>();
    6. public static void set(User user) {
    7. USER_THREAD_LOCAL.set(user);
    8. }
    9. public static User get() {
    10. return USER_THREAD_LOCAL.get();
    11. }
    12. public static void remove() {
    13. USER_THREAD_LOCAL.remove();
    14. }
    15. }
    16. /**
    17. * 1. 加上注解 CheckToken
    18. * 只有方法, 类忽略
    19. */
    20. @CheckToken
    21. @PostMapping("/doXxx")
    22. public Result<Resp> doXxx(@RequestBody Req req) {
    23. Resp resp = xxxService.doXxx(req);
    24. return result.success(resp);
    25. }
    26. /**
    27. * 2. 3. 4.
    28. */
    29. @Component
    30. public class TokenInterceptor implements HandlerInterceptor {
    31. @Override
    32. public void afterCompletion(HttpServletRequest arg0, HttpServletResponse arg1, Object arg2, Exception arg3)
    33. throws Exception {
    34. LocalUserUtils.remove();
    35. }
    36. @Override
    37. public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
    38. // 请求方法是否存在注解
    39. boolean assignableFrom = handler.getClass().isAssignableFrom(HandlerMethod.class);
    40. if (!assignableFrom) {
    41. return true;
    42. }
    43. CheckToken checkToken = null;
    44. if (handler instanceof HandlerMethod) {
    45. checkToken = ((HandlerMethod) handler).getMethodAnnotation(CheckToken.class);
    46. }
    47. // 没有加注解 直接放过
    48. if (checkToken == null) {
    49. return true;
    50. }
    51. // 从Header中获取Authorization
    52. String authorization = request.getHeader("Authorization");
    53. log.info("header authorization : {}", authorization);
    54. if (StringUtils.isBlank(authorization)) {
    55. log.error("从Header中获取Authorization失败");
    56. throw CustomExceptionEnum.NOT_HAVE_TOKEN.throwCustomException();
    57. }
    58. User user = xxxUserService.checkAuthorization(authorization);
    59. // 放到
    60. LocalUserUtils.set(user);
    61. return true;
    62. }
    63. }
    64. /**
    65. * 5. 使用
    66. * 只有方法, 类忽略
    67. */
    68. @Override
    69. public Resp doXxx(Req req) {
    70. User user = LocalUserUtils.get();
    71. // do something ...
    72. return resp;
    73. }

    ThreadLocal - 图1

    抛出问题

  7. 为什么可以线程私有?

  8. 为什么建议声明为静态?
  9. 为什么强制使用后必须remove?

ThreadLocal - 图2
图 | 阿里巴巴 - Java开发手册(截图)
ThreadLocal - 图3
图 | 阿里巴巴 - Java开发手册(截图)

源码分析

Thread

  1. public class Thread implements Runnable {
  2. // 省略 ...
  3. ThreadLocal.ThreadLocalMap threadLocals = null;
  4. ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
  5. // 省略 ...
  6. }

可以看出 Thread 对象中声明了 ThreadLocal.ThreadLocalMap 对象,每个线程都有自己的工作内存,每个线程都有自己的 ThreadLocal. ThreadLocalMap 对象,所以在线程之间是互相隔离的。

ThreadLocal

ThreadLocal则是一个泛型类,同时提供 set()get()remove()静态方法。

  1. public class ThreadLocal<T> {
  2. // 线程本地hashCode
  3. private final int threadLocalHashCode = nextHashCode();
  4. // 获取此线程局部变量的当前线程副本中的值
  5. public T get() {...}
  6. // 设置当前线程的此线程局部变量的复制到指定的值
  7. public void set(T value) {...}
  8. // 删除当前线程的此线程局部变量的值
  9. public void remove() {...}
  10. // ThreadLocalMap只是用来维持线程本地值的定制Map
  11. static class ThreadLocalMap {...}
  12. }

set(T value)方法
  1. public void set(T value) {
  2. // 获取当前线程
  3. Thread t = Thread.currentThread();
  4. // 获取当前线程的 threadLocals 属性
  5. ThreadLocalMap map = getMap(t);
  6. if (map != null)
  7. // 存在则赋值
  8. map.set(this, value);
  9. else
  10. // 不存在则直接创建
  11. createMap(t, value);
  12. }
  13. // 根据线程获取当前线程的ThreadLocalMap
  14. ThreadLocalMap getMap(Thread t) {
  15. return t.threadLocals;
  16. }
  17. // 创建ThreadLocalMap 并赋值给当前线程的threadLocals字段
  18. void createMap(Thread t, T firstValue) {
  19. t.threadLocals = new ThreadLocalMap(this, firstValue);
  20. }

1.Thread.currentThread() 先获取到当前线程。
2. 获取当前线程的 threadLocals 属性,即 ThreadLocalMap
3. 判断 Map 是否存在,存在则赋值,不存在则创建对象。

get()方法
  1. public T get() {
  2. // 获取当前线程
  3. Thread t = Thread.currentThread();
  4. // 获取当前线程的 threadLocals 属性
  5. ThreadLocalMap map = getMap(t);
  6. // map不为空
  7. if (map != null) {
  8. // 根据当前ThreadLocal获取的ThreadLocalMap的Entry节点
  9. ThreadLocalMap.Entry e = map.getEntry(this);
  10. if (e != null) {
  11. // 获取节点的value 并返回
  12. @SuppressWarnings("unchecked")
  13. T result = (T)e.value;
  14. return result;
  15. }
  16. }
  17. // 设置初始值并返回 (null)
  18. return setInitialValue();
  19. }

1.Thread.currentThread() 先获取到当前线程。
2. 获取当前线程的 threadLocals 属性,即 ThreadLocalMap
3. 判断 Map 不为空,根据当前 ThreadLocal 对象获取 ThreadLocalMap.Entry 节点, 从节点中获取 value。
4.ThreadLocalMap 为空或者 ThreadLocalMap.Entry 为空,则初始化 ThreadLocalMap 并返回。

remove()方法
  1. public void remove() {
  2. // 获取当前线程的ThreadLocalMap
  3. ThreadLocalMap m = getMap(Thread.currentThread());
  4. // 不为空, 从ThreadLocalMap中移除该属性
  5. if (m != null)
  6. m.remove(this);
  7. }

阅读 set()get()remove() 的源码之后发现后面其实是操作的 ThreadLocalMap, 主要还是操作的 ThreadLocalMapset()getEntry()remove() 以及构造函数。下面看是看 ThreadLocalMap 的源码。

ThreadLocalMap

  1. static class ThreadLocalMap {
  2. /**
  3. * Entry节点继承WeakReference是弱引用
  4. */
  5. static class Entry extends WeakReference<ThreadLocal<?>> {
  6. /** 与此ThreadLocal关联的值。 */
  7. Object value;
  8. Entry(ThreadLocal<?> k, Object v) {
  9. super(k);
  10. value = v;
  11. }
  12. }
  13. // 初始容量-必须是2的幂
  14. private static final int INITIAL_CAPACITY = 16;
  15. // 表,根据需要调整大小. table.length必须始终为2的幂.
  16. private ThreadLocal.ThreadLocalMap.Entry[] table;
  17. // 表中的条目数。
  18. private int size = 0;
  19. // 扩容阈值
  20. private int threshold; // Default to 0
  21. // 设置阀值为长度的 2/3
  22. private void setThreshold(int len) {
  23. threshold = len * 2 / 3;
  24. }
  25. // 构造函数
  26. ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {...}
  27. // 根据ThreadLocal获取节点Entry
  28. private ThreadLocal.ThreadLocalMap.Entry getEntry(ThreadLocal<?> key) {...}
  29. // set ThreadLocalMap的k-v
  30. private void set(ThreadLocal<?> key, Object value) {...}
  31. // 移除当前值
  32. private void remove(ThreadLocal<?> key) {...}
  33. }
  1. Entry 继承了 WeakReference<ThreadLocal<?> 也就意味着, Entry 节点的 key 是弱引用
  2. Entry 对象的key弱引用,指向的是 ThreadLocal 对象。
  3. 线程对象执行完毕,线程对象内实例属性会被回收,此时线程内 ThreadLocal 对象的引用被置为 null ,即 Entry 的 keynull, key 会被垃圾回收。
  4. ThreadLocal 对象通常为私有静态变量, 生命周期不会至少不会随着线程技术而结束。
  5. ThreadLocal 对象存在,并且 Entry的 key == null && value != null ,这时就会造成内存泄漏。
  • 小补充
  1. 强引用、软引用、弱引用、虚引用

    强引用(StrongReference):最常见,直接 new Object(); 创建的即为强引用。当内存空间不足,Java虚拟机宁愿抛出 OOM,也不愿意随意回收具有强引用的对象来解决内存不足问题。 软引用(SoftReference):内存足够,垃圾回收器不会回收软引用对象;内存不足时,垃圾回收器会回收。 弱引用(WeakReference):垃圾回收器线程,发现就会回收。 虚引用(PhantomReference):任何时候都有可能被垃圾回收,必须引用队列联合使用。

  2. 内存泄露:

    内存泄漏(Memory leak)是在计算机科学中,由于疏忽或错误造成程序未能释放已经不再使用的内存。内存泄漏并非指内存在物理上的消失,而是应用程序分配某段内存后,由于设计错误,导致在释放该段内存之前就失去了对该段内存的控制,从而造成了内存的浪费。 —— 维基百科

    ThreadLocal - 图4

    构造函数及hash计算
    1. ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    2. // 初始化Entry数组, 长度为16
    3. table = new Entry[INITIAL_CAPACITY];
    4. // 获取key的hashCode,并计算出在数组中的索引,
    5. // 长度是 2的幂的情况下,取模 a % b == a & (b - 1)
    6. int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    7. table[i] = new Entry(firstKey, firstValue);
    8. // 设置数组元素数
    9. size = 1;
    10. // 设置扩容阈值
    11. setThreshold(INITIAL_CAPACITY);
    12. }

    threadLocalHashCode 是 ThreadLocal 的静态属性,通过 nextHashCode 方法获取。

    1. private final int threadLocalHashCode = nextHashCode();
    2. // 被赋予了接下来的哈希码。 原子更新。 从零开始。
    3. private static AtomicInteger nextHashCode = new AtomicInteger();
    4. private static final int HASH_INCREMENT = 0x61c88647;
    5. private static int nextHashCode() {
    6. // 返回下一个hash码,通过步长 0x61c88647 累加生成,这块注释说明是最佳哈希值
    7. return nextHashCode.getAndAdd(HASH_INCREMENT);
    8. }
  3. 初始化数组,长度16。

  4. 计算 key 的 hashCode,对2的幂取模。
  5. 设置元素,元素数及扩容阈值。

hashCode 通过步长 0x61c88647 累加生成, 并且使用了 AtomicInteger,保证原子性。

set()方法
  1. private void set(ThreadLocal<?> key, Object value) {
  2. Entry[] tab = table;
  3. int len = tab.length;
  4. // hashcode取模求数组索引
  5. int i = key.threadLocalHashCode & (len-1);
  6. // 获取数组中对应的位置, 重点关注 e = tab[i = nextIndex(i, len)]
  7. for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
  8. // 获取key
  9. ThreadLocal<?> k = e.get();
  10. // key 存在则覆盖
  11. if (k == key) {
  12. e.value = value;
  13. return;
  14. }
  15. // key 不存在则赋值
  16. if (k == null) {
  17. replaceStaleEntry(key, value, i);
  18. return;
  19. }
  20. }
  21. // 此时 e == null 直接执创建节点
  22. tab[i] = new Entry(key, value);
  23. int sz = ++size;
  24. // cleanSomeSlots 循环数组 查找全部key==null的Entry
  25. if (!cleanSomeSlots(i, sz) && sz >= threshold)
  26. rehash();
  27. }
  1. 获取循环 Entry 数组,获取 tab[i] 处的 e, e != null 继续循环
    1. 此时发现 e 的 key 不存在,并且不是 null (hash冲突了。)
    2. 那就通过 e = tab[i = nextIndex(i, len)]) 继续获取下一个 i,并获取新的 tab[i] 处的 e。
    3. 赋值替换值结束结束并返回。
  2. e == null 结束循环。

    1. // 下一个index,如果 i + 1 < len 直接返回下一个位置
    2. // 如果 i + 1 >= len 则返回 0, 从头开始。
    3. private static int nextIndex(int i, int len) {
    4. return ((i + 1 < len) ? i + 1 : 0);
    5. }
    6. private static int prevIndex(int i, int len) {
    7. return ((i - 1 >= 0) ? i - 1 : len - 1);
    8. }
  3. 这块利用环形设计,如果长度到达数组长度,则从开头开始继续查找。

  4. int i = key.threadLocalHashCode & (len-1); 求出索引,并不是从0开始的。

    1. /**
    2. * staleSlot 为当前索引位置, 并且当前索引位置的 k == null
    3. */
    4. private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
    5. Entry[] tab = table;
    6. int len = tab.length;
    7. Entry e;
    8. // 需要清除的 entry 的索引
    9. int slotToExpunge = staleSlot;
    10. // 循环获取到上一个 key==null 的节点及其索引,有可能还是自己
    11. for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
    12. if (e.get() == null)
    13. slotToExpunge = i;
    14. // 继续上一层的循环,查找下一个 k == key 的节点索引
    15. for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
    16. ThreadLocal<?> k = e.get();
    17. if (k == key) {
    18. // key 相等 则直接赋值
    19. e.value = value;
    20. // 并且将 此处的 entry替换为 tab[staleSlot]
    21. tab[i] = tab[staleSlot];
    22. tab[staleSlot] = e;
    23. // 如果发现要清除的 entry和传入的在一个位置上, 则直接赋值
    24. if (slotToExpunge == staleSlot)
    25. slotToExpunge = i;
    26. // 清除掉过期的 expungeStaleEntry(slotToExpunge) 会清除 entry的value,将其设置为null并将其设置为null, 并返回下一个需要清除的entry的索引位置
    27. // cleanSomeSlots 循环数组 查找全部key==null的Entry
    28. cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    29. return;
    30. }
    31. // 如果向后扫描没有找到,并且已经到第初始传入的索引位置处了
    32. if (k == null && slotToExpunge == staleSlot)
    33. slotToExpunge = i;
    34. }
    35. // 没找到, 直接将旧值 Entry 设置为 null 并指向新创建的Entry
    36. tab[staleSlot].value = null;
    37. tab[staleSlot] = new Entry(key, value);
    38. // 结束之后发现要清楚的 key的索引 不等于当前传入的索引, 说明还有其他需要清除。
    39. if (slotToExpunge != staleSlot)
    40. cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    41. }
  5. 这里存在三个属性 key, value,以及 staleSlot, staleSlot节点的 Entry != null 但是 k == null。

  6. 向前扫描获取到上一个 Entry != null 但是 k == null 的节点及其索引, 赋值给 slotToExpunge, 没有扫描到的话 slotToExpunge 还是等于 staleSlot。
  7. 向后扫描 Entry != null 的节点,因为在 set 方法中, 后面还有一段数组没有遍历。
    1. 发现 key 相等的Entry节点了, 直接赋值,然后清除其他 Entry != null 但是 k == null 的节点, 并返回。
    2. 没有找到key相等的节点,但是找到了下一个 Entry != null 但是 k == null, 且此时 slotToExpunge 未发生变化,还是指向 staleSlot, 则 i 赋值给 slotToExpunge。
  8. 向后扫描没有扫描到,则直接对当前节点(索引值为staleSlot)的节点的value设置为null,并指向新value。
  9. 结束之后发现 slotToExpunge 被改变了, 说明还有其他的要清除。

    getEntry()方法
    1. private Entry getEntry(ThreadLocal<?> key) {
    2. // hashcode取模求数组索引
    3. int i = key.threadLocalHashCode & (table.length - 1);
    4. Entry e = table[i];
    5. if (e != null && e.get() == key)
    6. // 存在则返回
    7. return e;
    8. else
    9. // 不存在
    10. return getEntryAfterMiss(key, i, e);
    11. }
    12. private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    13. Entry[] tab = table;
    14. int len = tab.length;
    15. while (e != null) {
    16. ThreadLocal<?> k = e.get();
    17. if (k == key)
    18. return e;
    19. if (k == null)
    20. // key 已经 == null 了 清除一下 value
    21. expungeStaleEntry(i);
    22. else
    23. // 继续获取下一个
    24. i = nextIndex(i, len);
    25. e = tab[i];
    26. }
    27. return null;
    28. }
  10. hashcode 取模求数组索引。

  11. 索引处获取到 Entry 则直接返回。
  12. 获取不到或者获取到的 Entry key 不相等时,有可能是因为 hash 冲突,被放到别的地方, 调用 getEntryAfterMiss 方法。
  13. getEntryAfterMiss 方法中。

    1. e == null 返回null。
    2. e != null 判断key, key相等返回 Entry, key == null, 那就需要清除这个节点,然后继续按照 nextIndex(i, len) 方法找下一个节点。

      remove()方法

      1. private void remove(ThreadLocal<?> key) {
      2. Entry[] tab = table;
      3. int len = tab.length;
      4. // hashcode 取模求数组索引
      5. int i = key.threadLocalHashCode & (len-1);
      6. // 清除当前节点的value
      7. for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
      8. if (e.get() == key) {
      9. // 清楚对象引用
      10. e.clear();
      11. // value 指向 null
      12. expungeStaleEntry(i);
      13. return;
      14. }
      15. }
      16. }
      17. public void clear() {
      18. this.referent = null;
      19. }
  14. hashcode 取模求数组索引。

  15. 循环查找数组,将当前 key 的 Entry 的引用,将 value 设置为 null, 后面会被垃圾回收掉。

    总结

    为什么可以线程私有?

    ThreadLocal 的 get()、set()、remove()方法中都有 Thread t = Thread.currentThread(); 操作的其实是本线程,获取本线程的ThreadLocalMap。
    每个线程都有自己的 ThreadLocal,并且是将 value 存放在一个以 ThreadLocal 为 key 的 ThreadLocalMap 中的。所以线程间隔离。

    为什么建议声明为静态?

    Java开发手册已经给出说明,还有就是,如果 ThreadLocal 设置为非静态,那就是某个线程的实例类,这样的话就会失去了线程共享的本质属性。

    为什么强制必须时候后remove()?

    这块可以和内存泄露一块说明, 通过上面的 ThreadLocalMap 处关于弱引用的讲解已经说明会产生内存泄露。至于如何解决也给出了答案:
    1.set() 时清除 Entry != null && key == null 的节点, 将其 value 设置为 null。
    2.getEntry() 时清除当前 key 到 nextIndex(i, len)==null 之间的 Entry != null && key == null 的节点, 将其 value 设置为 null。
    3.remove() 时清除指定key的 Entry != null && key == null 的节点, 将其 value 设置为 null。
    之所以使用remove(),还是为了解决内存泄露的问题。

    Last

  16. 使用时注意声明为 private static final

  17. 使用后要 remove()