HashTable ConcurrentHashMap key,value 都不能为 null。HashMap 可以

JDK 1.7 前使用的是分段锁,每次锁住一个 segment,一个 segment 包含多个 bucket。
JDK 1.8 后使用数组 + 链表/红黑树,每次锁住一个 bucket 上的链表/红黑树,从而提高并发度。

扩容 JDK1.8:https://juejin.im/post/6844903607901356046


主要利用 Unsafe 操作 + ReentrantLock + 分段思想。

  • Unsafe:CAS、获取数组的元素、设置数组的元素。直接操作主存
  • 分段:Segment 提高并发度

1.5-1.7
image.png

数据结构:

  1. // 掩码,segments长度-1。即bit全为1。与HashMap中的indexFor类似
  2. final int segmentMask;
  3. final int segmentShift;
  4. final Segment<K,V>[] segments;
  5. static final class Segment<K,V> extends ReentrantLock implements Serializable {
  6. transient volatile HashEntry<K,V>[] table;
  7. }

参数:

  1. // 默认容量 16。全部 HashEntry[] 数组的数量
  2. static final int DEFAULT_INITIAL_CAPACITY = 16;
  3. // 负载因子 0.75
  4. static final float DEFAULT_LOAD_FACTOR = 0.75f;
  5. // 并发级别 16。segments 数组的大小
  6. static final int DEFAULT_CONCURRENCY_LEVEL = 16;
  7. // Integer.MAX_VALUE = 2^31-1。所以最大容量只能是 2^30(因为要满足2的幂次方)
  8. static final int MAXIMUM_CAPACITY = 1 << 30;
  9. // per-segment tables 大小的最小值,2的幂次方
  10. static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
  • 虽然默认容量16/默认并发度16=1,但是默认的 per-segment tables 大小为 2

    初始化:

    1. int sshift = 0;
    2. int ssize = 1;
    3. while (ssize < concurrencyLevel) {
    4. ++sshift;
    5. ssize <<= 1;
    6. }
    7. this.segmentShift = 32 - sshift;
    8. this.segmentMask = ssize - 1;
    9. if (initialCapacity > MAXIMUM_CAPACITY)
    10. initialCapacity = MAXIMUM_CAPACITY;
    11. int c = initialCapacity / ssize;
    12. // c 要向上取整
    13. if (c * ssize < initialCapacity)
    14. ++c;
    15. int cap = MIN_SEGMENT_TABLE_CAPACITY;
    16. while (cap < c)
    17. cap <<= 1;
    18. // create segments and segments[0]
    19. Segment<K,V> s0 =
    20. new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
    21. (HashEntry<K,V>[])new HashEntry[cap]);
    22. Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
    23. UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    24. this.segments = ss;
  • 实际的并发度(segments 数组的大小) ssize 需要是 2 的幂次方,所以要根据 concurrencyLevel 计算

  • c 表示每个 segment 中 table 的大小,cap 则取大于等于 c 的 2 幂次方的最小值
    • 阈值是对某个 segment 对象里面的 tables 的,而不是对于 segments 数组的
  • 会先创建 segments[0] 这个 s0 对象,是为了保存 cap 和 threshold 等属性,以便于后续其他对象的创建
  • sshift 为 log2(ssize),即 2^sshift=ssize
    • this.segmentShift = 32 - sshift:32减去并发度

扩容:

  • 扩容只会增大某个 segment 对象里面的 tables 数组的大小,而不会去增大 segments 数组的大小

线程安全:

  • Segment 继承了 ReentrantLock,直接调用内部的 tryLock()、unlock()
  • 并且使用 scanAndLockForPut()

put:

  1. public V put(K key, V value) {
  2. Segment<K,V> s;
  3. if (value == null)
  4. throw new NullPointerException();
  5. int hash = hash(key);
  6. int j = (hash >>> segmentShift) & segmentMask;
  7. // SSHIFT 数组中一个元素的大小的位数
  8. if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
  9. (segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
  10. // 如果对应的位置为 null,则生成一个 Segment 对象
  11. s = ensureSegment(j);
  12. return s.put(key, hash, value, false);
  13. }
  • int j = (hash >>> segmentShift) & segmentMask;
  • hash >>> segmentShift:保留与 segments 数组同样长度的最高位
    • segment.put 里面就是使用 (tab.length - 1) & hash
  • ensureSegment 生成 Segment 对象,注意并发安全问题

    1. private Segment<K,V> ensureSegment(int k) {
    2. final Segment<K,V>[] ss = this.segments;
    3. long u = (k << SSHIFT) + SBASE; // raw offset
    4. Segment<K,V> seg;
    5. // 进来后再检查一次,k 位置的元素为 null
    6. if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
    7. // 使用 segments 的第一个对象作为原型,拿到 cap 和 loadFactor
    8. Segment<K,V> proto = ss[0]; // use segment 0 as prototype
    9. int cap = proto.table.length;
    10. float lf = proto.loadFactor;
    11. int threshold = (int)(cap * lf);
    12. // 创建 table
    13. HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
    14. if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
    15. == null) { // recheck
    16. // 创建 segment 对象
    17. Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
    18. // 如果有其他线程设置成功,获取
    19. while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
    20. == null) {
    21. if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
    22. break;
    23. }
    24. }
    25. }
    26. return seg;
    27. }
  • 进来后再检查一次,k 位置的元素为 null

  • 根据数组第一个元素为原型,创建 Segment 对象和里面的 HashEntry[] 数组
  • 使用 CAS 自旋去 set
    • 如果有其他线程设置成功了,则获取并返回 seg
    • 用 while 而不用 false,是为了防止 ABA,第三个线程把元素 remove 了就变成 null ```java // 获取数组对象中,每个元素的大小 ss = UNSAFE.arrayIndexScale(sc); // ss对应二进制的长度。比如16为4,8为3,4为2。 // 16=10000b=5位。32-5+1=31-5 // 31 - 最高位1前面0的数量 SSHIFT = 31 - Integer.numberOfLeadingZeros(ss);

// 比如元素大小为16,则返回17,31-27=4。则得到了16的2的4次方 (j * ss) + SBASE (j << SSHIFT) + SBASE // 左移1位等于乘以2,左移4位等于乘以2^4=16 // 所以这里是通过左移来实现乘法,因为元素大小ss总是2的幂次方,所以可以等同


<a name="GxGUH"></a>
### Segment 的 put:
```java
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    HashEntry<K,V> node = tryLock() ? null :
    scanAndLockForPut(key, hash, value);
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table;
        // 计算下标
        int index = (tab.length - 1) & hash;
        // 使用 UNSAFE 得到数组中的元素。链表的头节点
        HashEntry<K,V> first = entryAt(tab, index);
        for (HashEntry<K,V> e = first;;) {
            if (e != null) {
                K k;
                // 遍历链表,替换值
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    // 不是不存在的时候才插入,则插入
                    if (!onlyIfAbsent) {
                        e.value = value;
                        ++modCount;
                    }
                    break;
                }
                e = e.next;
            }
            else {
                // 当前数组的元素为 null,或者遍历到链表末尾
                // 如果 scanAndLockForPut 的时候,创建了节点,则直接赋值
                if (node != null)
                    node.setNext(first);
                else
                    // 新的node指向头节点
                    node = new HashEntry<K,V>(hash, key, value, first);
                int c = count + 1;
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                    // 扩容
                    rehash(node);
                else
                    // set元素,替换头节点
                    // UNSAFE直接修改主存的值。而[]操作只能修改线程副本的值
                    setEntryAt(tab, index, node);
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        unlock();
    }
    return oldValue;
}

获取锁:

static final int MAX_SCAN_RETRIES = Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    // 取出数组中对应链表的头节点
    HashEntry<K,V> first = entryForHash(this, hash);
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    int retries = -1; // negative while locating node
    while (!tryLock()) {
        HashEntry<K,V> f; // to recheck first below
        if (retries < 0) {
            // 如果当前节点为null,或者到了链表末尾
            if (e == null) {
                // 只创建一次
                if (node == null) // speculatively create node
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            // 如果 key 相同
            else if (key.equals(e.key))
                retries = 0;
            else
                // 遍历链表
                e = e.next;
        }
        else if (++retries > MAX_SCAN_RETRIES) {
            // 如果超过最大次数,则阻塞获取锁
            lock();
            break;
        }
        // retries为偶数,tables[n] 不等于头节点,即另一个线程更新了链表
        else if ((retries & 1) == 0 &&
                 (f = entryForHash(this, hash)) != first) {
            // 重新开始遍历链表,并重置次数
            e = first = f; // re-traverse if entry changed
            retries = -1;
        }
        // 如果在 MAX_SCAN_RETRIES 之内获取到锁,就退出循环
    }
    return node;
}
  • 先遍历链表,直到到链表的末尾,或者遇到相同节点,就 retries = 0。跳出这个 if 分支
    • 如果到了链表末尾 e=null,并且 node=null,则创建 node
    • 如果 key 相同,则 node=null
    • 在遍历链表过程中,也会 tryLock()
  • ++retries,判断是否超过最大次数
    • 超过的话,lock() 阻塞,等获取到锁后,break
  • 偶数次时,判断头节点是否被其他线程改变了
    • 如果是的话,则重新赋值遍历变量和 retries,回到第一个 if 重新遍历链表

Segment table 的扩容:

private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    // 新容量为原来的两倍
    int newCapacity = oldCapacity << 1;
    threshold = (int)(newCapacity * loadFactor);
    // 用新容量创建新的 table
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    int sizeMask = newCapacity - 1;
    // 遍历旧数组
    for (int i = 0; i < oldCapacity ; i++) {
        // 数组当前位置的头节点
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            // 计算元素在新数组的位置
            int idx = e.hash & sizeMask;
            // 如果只有一个元素
            if (next == null)   //  Single node on list
                newTable[idx] = e;
            else { // Reuse consecutive sequence at same slot
                // 记录最后一段连续的序列
                HashEntry<K,V> lastRun = e;
                int lastIdx = idx;
                for (HashEntry<K,V> last = next;
                     last != null; 
                     last = last.next) {
                    int k = last.hash & sizeMask;
                    // 每次计算的新位置与上次保留的不同,即记录下来
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                // 直接将最后一段的元素移动到新数组
                newTable[lastIdx] = lastRun;
                // Clone remaining nodes
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    // 头插法
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    // 插入新节点
    int nodeIndex = node.hash & sizeMask; // add the new node
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    // 重新指向新数组
    table = newTable;
}

get:

public V get(Object key) {
    Segment<K,V> s; // manually integrate access methods to reduce overhead
    HashEntry<K,V>[] tab;
    int h = hash(key);
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    // 获取对应位置的 segment 对象,以及内部的 table 数组
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
        // 获取 table 中对应的链表,并遍历元素
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
             (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
             e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}

remove:

size:

尝试非加锁统计,sum += seg.modCount,size += seg.count
如果两次统计的 sum 都一样,则返回 size,否则。加锁



弱一致性:

因为 get 的时候是无锁的

put() 后并不能马上 get() 最新数据