先了解以下 Hashtable,故所周知,Hashtable 是线程安全的

  1. public synchronized V put(K key, V value);
  2. public synchronized V get(Object key);
  3. public synchronized V remove(Object key);

所以在多线程操作下不会出现问题,那么问题来了,由于是锁当前的对象,那么效率就会变得很低,有没有一种既能保证效率,又能线程安全的 HashMap 呢?就是 ConcurrentHashMap 了

分段锁

ConcurrentHashMap 使用了分段锁的设计思路,它使用了 Segment[] 表示数组,而 Segment 中又定义了了一个 HashEntry[] 来表示一段小的数组,HashEntry 和 HashMap 中的 Entry 是一样的,这样就实现了分段的概念。

ConcurrentHashMap 的默认构造函数如下:

  1. public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel);
  • initialCapacity:表示 HashEntry 的数量默认 16
  • loadFactor:加载因子,默认 0.75
  • concurrencyLevel:并发等级默认 16,表示有多少个 Segment

如果 initialCapacity = 32,concurrencyLevel = 16,那么每个 Segment 中的 HashEntry[] 的长度就为 2 ,
即:每个 Segment 中 HashEntry 的数量 = initialCapacity / concurrencyLevel
**

如果是按默认 initialCapacity = 16,concurrencyLevel = 16,的话,HashEntry[] 就等于 2,这是默认的,也就是说 HashEntry[] 最小等于 2

扩容

先扩容 Segment 中的 HashEntry[]

UNSAFE

  1. private static sun.misc.Unsafe UNSAFE;
  2. private static int i;
  3. static {
  4. try {
  5. Field field = Unsafe.class.getDeclaredField("theUnsafe");
  6. field.setAccessible(true);
  7. UNFAFE = (Unsafe) field.get(null);
  8. } catch(Exception ex) {
  9. ex.printStackTrace();
  10. }
  11. }
  12. i = UNSAFE.objectFieldOffset(Person.class.getDeclaredField("i"));
  13. UNSAFE.compareAndSwapInt(person, I_OFFSET, person.i, person.i++);
  14. UNSAFE.getIntVolatile(person, I_OFFSET);

源码解析

  1. // initialCapacity表示ConcurrentHashMap初始化的数组的容量
  2. // loadFactor表示加载因子
  3. // concurrencyLevel表示并发数,意思是一个ConcurrentHashMap对象支持几个线程同时操作
  4. public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
  5. // loadFactor不能大于0,initialCapacity不能小于0,concurrencyLevel不能小于或等于0
  6. if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
  7. throw new IllegalArgumentException();
  8. // 并发数最大为1 << 16
  9. if (concurrencyLevel > MAX_SEGMENTS)
  10. concurrencyLevel = MAX_SEGMENTS;
  11. // Find power-of-two sizes best matching arguments
  12. // 找到一个大于等于concurrencyLevel的2的幂次方数
  13. // 如果concurrencyLevel为17,那么ssize=32, sshift=5
  14. int sshift = 0;
  15. int ssize = 1;
  16. while (ssize < concurrencyLevel) {
  17. ++sshift;
  18. ssize <<= 1;
  19. }
  20. // 以下两个参数与计算segment数组下标有关系
  21. // segmentShift表示hash值要右移的位数,int类型为32位,一个hash值右移了segmentShift位后
  22. // 就表示只剩下了hash值的高sshift位
  23. // segmentMask就是用来进行&操作的
  24. this.segmentShift = 32 - sshift;
  25. this.segmentMask = ssize - 1;
  26. // 数组容量限制
  27. if (initialCapacity > MAXIMUM_CAPACITY)
  28. initialCapacity = MAXIMUM_CAPACITY;
  29. // 数组容量除以ssize,表示指定的每个segment的容量,然后向上取整
  30. int c = initialCapacity / ssize;
  31. if (c * ssize < initialCapacity)
  32. ++c;
  33. // segment的容量最小为2,然后取大于等于c的2的幂次方数
  34. int cap = MIN_SEGMENT_TABLE_CAPACITY;
  35. while (cap < c)
  36. cap <<= 1;
  37. // create segments and segments[0]
  38. // 创建s0,并把它设置到segments数组中的第0个位置
  39. // Segment就是一个小型的HashMap,所以也有加载因子,阈值,Entry数组
  40. Segment<K,V> s0 =
  41. new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
  42. (HashEntry<K,V>[])new HashEntry[cap]);
  43. Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
  44. // 通过UNSAFE将ss也就是segments数组的第0个位置设置为s0;
  45. UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
  46. this.segments = ss;
  47. }
  1. public V put(K key, V value) {
  2. Segment<K,V> s;
  3. // value不能为空
  4. if (value == null)
  5. throw new NullPointerException();
  6. // 根据key计算hash值
  7. int hash = hash(key);
  8. // 根据hash值计算segments数组的下标,hash值只保留高sshift位,可是为什么要右移?
  9. int j = (hash >>> segmentShift) & segmentMask;
  10. // 判断segments数组的第j个位置是否为null,如果为null则生成一个segment对象
  11. if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
  12. (segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
  13. s = ensureSegment(j);
  14. // 有了segment对象之后,就向segment对象中设置key,value
  15. return s.put(key, hash, value, false);
  16. }
  1. // 确保segments数组的第k个位置是否为null,如果为null则在该位置生成一个segment对象
  2. private Segment<K,V> ensureSegment(int k) {
  3. final Segment<K,V>[] ss = this.segments;
  4. long u = (k << SSHIFT) + SBASE; // raw offset
  5. Segment<K,V> seg;
  6. // 判断segments数组的第k个位置是否为null,如果为null则生成一个segment对象
  7. if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
  8. // 以segments数组的中的0个位置的segment对象作为原型,接下来将以原型作为参照来生成新的segment对象
  9. Segment<K,V> proto = ss[0]; // use segment 0 as prototype
  10. // 通过从原型中获取参数
  11. int cap = proto.table.length;
  12. float lf = proto.loadFactor;
  13. int threshold = (int)(cap * lf);
  14. HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
  15. // 再次获取segments数组中的第k个位置的segment对象是否为null
  16. // 如果为null则先生成一个新的Segment对象
  17. // 然后使用CAS来将新的Segment对象设置到segments数组中去
  18. if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
  19. == null) { // recheck
  20. //
  21. Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
  22. while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
  23. == null) {
  24. if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
  25. break;
  26. }
  27. }
  28. }
  29. return seg;
  30. }
  1. // 如果第一次tryLock的时候就成功了获取到锁了,那么node为null
  2. // 如果第一次tryLock的如果没有成功,那么先判断hash值对应的数组下标位置是否有值,
  3. // 如果没值则会生成一个HashEntry对象,next属性为null并赋值给node
  4. // 如果有值则先遍历当前链表,找到key相同的元素,如果没有相同的key,最终那么也会生成一个HashEntry对象,并赋值给node
  5. // 以上步骤相当于寻找当前Segment对象中是否存在key相同的元素,这里是基于效率来考虑:
  6. // 反正要加锁,如果一开始没加上,就先不阻塞,先去遍历看是否要生成一个HashEntry对象,如果要则先生成
  7. // 先完成了以上步骤之后,再重试加锁:
  8. // 最大重试次数跟cpu核数相关,先理解为有多次重试
  9. // 在重试的过程中,偶数次重试的时候会去检查一下链表头元素,因为如果链表头元素发生了改变就表示当前位置
  10. // 有元素添加进来了,则重试次数重新开始计数
  11. // 当超过了最大重试次数之后就开始加锁(加不到则阻塞)
  12. // 总结,该方法的目的是使用自旋的方式去加锁,自旋的次数是不确定的,自旋分为遍历链表慢速自旋和快速自旋,自旋规则是这样的:
  13. // 如果当前位置上没有元素,则生成新node,然后快速自旋MAX_SCAN_RETRIES次
  14. // 如果当前位置上有元素,则遍历链表进行慢速自旋,如果发现链表上有key相同的,则快速自旋MAX_SCAN_RETRIES次
  15. // 如果当前位置上有元素,则遍历链表进行慢速自旋,遍历完尾结点之后,则快速自旋MAX_SCAN_RETRIES次
  16. // 在快速自旋的过程中,如果发现当前位置上的头结点发生了改变,则重新进入自旋的过程中,为什么要这样呢?
  17. // 自旋的时候遍历链表的目的就是看要不要生成一个新的node,那既然都已经生成了,为什么链表头发生了变化还要重新开始自旋呢?
  18. // 只能理解为,链表发生了变化就要重新开始自旋,也就是延长自旋时间,尽量不直接使用lock加锁,因为使用lock加锁后就只能等其他线程释放后继续往下走了
  19. private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
  20. // 根据key的hash值找到Segment对象内部的Entry数组中对应位置的元素,也就是链表的头结点
  21. HashEntry<K,V> first = entryForHash(this, hash);
  22. HashEntry<K,V> e = first;
  23. HashEntry<K,V> node = null;
  24. int retries = -1; // negative while locating node
  25. while (!tryLock()) {
  26. HashEntry<K,V> f; // to recheck first below
  27. if (retries < 0) {
  28. if (e == null) {
  29. if (node == null) // speculatively create node
  30. node = new HashEntry<K,V>(hash, key, value, null);
  31. retries = 0;
  32. }
  33. else if (key.equals(e.key))
  34. retries = 0;
  35. else
  36. e = e.next;
  37. }
  38. else if (++retries > MAX_SCAN_RETRIES) {
  39. lock();
  40. break;
  41. }
  42. else if ((retries & 1) == 0 &&
  43. (f = entryForHash(this, hash)) != first) {
  44. e = first = f; // re-traverse if entry changed
  45. retries = -1;
  46. }
  47. }
  48. return node;
  49. }
  1. // 这是Segment对象中的put方法
  2. final V put(K key, int hash, V value, boolean onlyIfAbsent) {
  3. // 先尝试获取一下锁,如果获取到了则node为null,如果第一次没有获取到,则开始加锁
  4. // node表示新的hashentry对象
  5. HashEntry<K,V> node = tryLock() ? null :
  6. scanAndLockForPut(key, hash, value);
  7. V oldValue;
  8. try {
  9. HashEntry<K,V>[] tab = table;
  10. int index = (tab.length - 1) & hash;
  11. // 链表头结点
  12. HashEntry<K,V> first = entryAt(tab, index);
  13. // 从链表头结点开始遍历
  14. for (HashEntry<K,V> e = first;;) {
  15. // 如果当前遍历到的entry对象的key不为空,则比较key,hash值是否相等,等等则覆盖
  16. // 如果不相等则e = e.next
  17. if (e != null) {
  18. K k;
  19. if ((k = e.key) == key ||
  20. (e.hash == hash && key.equals(k))) {
  21. oldValue = e.value;
  22. if (!onlyIfAbsent) {
  23. e.value = value;
  24. ++modCount;
  25. }
  26. break;
  27. }
  28. e = e.next;
  29. }
  30. else {
  31. // 如果尾结点都遍历完了
  32. // 设置node的next属性
  33. if (node != null)
  34. node.setNext(first);
  35. else
  36. node = new HashEntry<K,V>(hash, key, value, first);
  37. // ConcurrentHash的元素数量+1,如果超过了阈值则扩容
  38. // 如果没有超过则将node结点设置到数组上,相当于下移
  39. int c = count + 1;
  40. if (c > threshold && tab.length < MAXIMUM_CAPACITY)
  41. rehash(node);
  42. else
  43. setEntryAt(tab, index, node); // 将node添加到数组上,向下移动
  44. ++modCount;
  45. count = c;
  46. oldValue = null;
  47. break;
  48. }
  49. }
  50. } finally {
  51. // 解锁
  52. unlock();
  53. }
  54. return oldValue;
  55. }
  1. // Segment内扩容
  2. private void rehash(HashEntry<K,V> node) {
  3. HashEntry<K,V>[] oldTable = table;
  4. int oldCapacity = oldTable.length;
  5. // 数组大小翻倍扩容
  6. int newCapacity = oldCapacity << 1;
  7. threshold = (int)(newCapacity * loadFactor);
  8. HashEntry<K,V>[] newTable =
  9. (HashEntry<K,V>[]) new HashEntry[newCapacity];
  10. int sizeMask = newCapacity - 1;
  11. // 遍历HashEntry老数组上的每个链表
  12. for (int i = 0; i < oldCapacity ; i++) {
  13. HashEntry<K,V> e = oldTable[i]; // 链表头结点
  14. if (e != null) {
  15. HashEntry<K,V> next = e.next;
  16. // 链表头结点的在新数组上的下表
  17. int idx = e.hash & sizeMask;
  18. // 如果链表上只有一个结点,那么直接将当前entry移动到新数组上
  19. if (next == null) // Single node on list
  20. newTable[idx] = e;
  21. else { // Reuse consecutive sequence at same slot
  22. HashEntry<K,V> lastRun = e;
  23. int lastIdx = idx;
  24. // 从链表的第二个节点开始遍历链表,找到当前链表上和链表头节点新位置相同的最后一个结点
  25. for (HashEntry<K,V> last = next;
  26. last != null;
  27. last = last.next) {
  28. int k = last.hash & sizeMask;
  29. // 当前遍历到的entry的新下标不等于上一个entry的新下标,则记录一下
  30. // 如果链表的最后3个节点的新下标都一样的话,那么lastRun就是倒数第三个Entry
  31. if (k != lastIdx) {
  32. lastIdx = k;
  33. lastRun = last;
  34. }
  35. }
  36. // 将链表上的最后一个节点(不准确,参照上面的说法)转移到新数组上
  37. // 可能会把一个子链表转移过去
  38. newTable[lastIdx] = lastRun;
  39. // Clone remaining nodes
  40. // 从链表头部开始遍历,知道lastRun结束,一个个entry进行转移
  41. for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
  42. V v = p.value;
  43. int h = p.hash;
  44. int k = h & sizeMask;
  45. // 头插法
  46. HashEntry<K,V> n = newTable[k];
  47. newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
  48. }
  49. }
  50. }
  51. }
  52. // 把这个新node加到新数组上
  53. int nodeIndex = node.hash & sizeMask; // add the new node
  54. node.setNext(newTable[nodeIndex]);
  55. newTable[nodeIndex] = node;
  56. table = newTable;
  57. }