HashTable 和 ConcurrentHashMap 的 key,value 都不能为 null。HashMap 可以
JDK 1.7 前使用的是分段锁,每次锁住一个 segment,一个 segment 包含多个 bucket。
JDK 1.8 后使用数组 + 链表/红黑树,每次锁住一个 bucket 上的链表/红黑树,从而提高并发度。
扩容 JDK1.8:https://juejin.im/post/6844903607901356046
主要利用 Unsafe 操作 + ReentrantLock + 分段思想。
- Unsafe:CAS、获取数组的元素、设置数组的元素。直接操作主存
- 分段:Segment 提高并发度
数据结构:
// 掩码,segments长度-1。即bit全为1。与HashMap中的indexFor类似
final int segmentMask;
final int segmentShift;
final Segment<K,V>[] segments;
static final class Segment<K,V> extends ReentrantLock implements Serializable {
transient volatile HashEntry<K,V>[] table;
}
参数:
// 默认容量 16。全部 HashEntry[] 数组的数量
static final int DEFAULT_INITIAL_CAPACITY = 16;
// 负载因子 0.75
static final float DEFAULT_LOAD_FACTOR = 0.75f;
// 并发级别 16。segments 数组的大小
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
// Integer.MAX_VALUE = 2^31-1。所以最大容量只能是 2^30(因为要满足2的幂次方)
static final int MAXIMUM_CAPACITY = 1 << 30;
// per-segment tables 大小的最小值,2的幂次方
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
虽然默认容量16/默认并发度16=1,但是默认的 per-segment tables 大小为 2
初始化:
int sshift = 0;
int ssize = 1;
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
this.segmentShift = 32 - sshift;
this.segmentMask = ssize - 1;
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize;
// c 要向上取整
if (c * ssize < initialCapacity)
++c;
int cap = MIN_SEGMENT_TABLE_CAPACITY;
while (cap < c)
cap <<= 1;
// create segments and segments[0]
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
this.segments = ss;
实际的并发度(segments 数组的大小) ssize 需要是 2 的幂次方,所以要根据 concurrencyLevel 计算
- c 表示每个 segment 中 table 的大小,cap 则取大于等于 c 的 2 幂次方的最小值
- 阈值是对某个 segment 对象里面的 tables 的,而不是对于 segments 数组的
- 会先创建 segments[0] 这个 s0 对象,是为了保存 cap 和 threshold 等属性,以便于后续其他对象的创建
- sshift 为 log2(ssize),即 2^sshift=ssize
- this.segmentShift = 32 - sshift:32减去并发度
扩容:
- 扩容只会增大某个 segment 对象里面的 tables 数组的大小,而不会去增大 segments 数组的大小
线程安全:
- Segment 继承了 ReentrantLock,直接调用内部的 tryLock()、unlock()
- 并且使用 scanAndLockForPut()
put:
public V put(K key, V value) {
Segment<K,V> s;
if (value == null)
throw new NullPointerException();
int hash = hash(key);
int j = (hash >>> segmentShift) & segmentMask;
// SSHIFT 数组中一个元素的大小的位数
if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
(segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
// 如果对应的位置为 null,则生成一个 Segment 对象
s = ensureSegment(j);
return s.put(key, hash, value, false);
}
- int j = (hash >>> segmentShift) & segmentMask;
- hash >>> segmentShift:保留与 segments 数组同样长度的最高位
- segment.put 里面就是使用 (tab.length - 1) & hash
ensureSegment 生成 Segment 对象,注意并发安全问题
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
// 进来后再检查一次,k 位置的元素为 null
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
// 使用 segments 的第一个对象作为原型,拿到 cap 和 loadFactor
Segment<K,V> proto = ss[0]; // use segment 0 as prototype
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
// 创建 table
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
// 创建 segment 对象
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
// 如果有其他线程设置成功,获取
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
进来后再检查一次,k 位置的元素为 null
- 根据数组第一个元素为原型,创建 Segment 对象和里面的 HashEntry[] 数组
- 使用 CAS 自旋去 set
- 如果有其他线程设置成功了,则获取并返回 seg
- 用 while 而不用 false,是为了防止 ABA,第三个线程把元素 remove 了就变成 null ```java // 获取数组对象中,每个元素的大小 ss = UNSAFE.arrayIndexScale(sc); // ss对应二进制的长度。比如16为4,8为3,4为2。 // 16=10000b=5位。32-5+1=31-5 // 31 - 最高位1前面0的数量 SSHIFT = 31 - Integer.numberOfLeadingZeros(ss);
// 比如元素大小为16,则返回17,31-27=4。则得到了16的2的4次方 (j * ss) + SBASE (j << SSHIFT) + SBASE // 左移1位等于乘以2,左移4位等于乘以2^4=16 // 所以这里是通过左移来实现乘法,因为元素大小ss总是2的幂次方,所以可以等同
<a name="GxGUH"></a>
### Segment 的 put:
```java
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
HashEntry<K,V> node = tryLock() ? null :
scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table;
// 计算下标
int index = (tab.length - 1) & hash;
// 使用 UNSAFE 得到数组中的元素。链表的头节点
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
if (e != null) {
K k;
// 遍历链表,替换值
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
// 不是不存在的时候才插入,则插入
if (!onlyIfAbsent) {
e.value = value;
++modCount;
}
break;
}
e = e.next;
}
else {
// 当前数组的元素为 null,或者遍历到链表末尾
// 如果 scanAndLockForPut 的时候,创建了节点,则直接赋值
if (node != null)
node.setNext(first);
else
// 新的node指向头节点
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
// 扩容
rehash(node);
else
// set元素,替换头节点
// UNSAFE直接修改主存的值。而[]操作只能修改线程副本的值
setEntryAt(tab, index, node);
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}
获取锁:
static final int MAX_SCAN_RETRIES = Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
// 取出数组中对应链表的头节点
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1; // negative while locating node
while (!tryLock()) {
HashEntry<K,V> f; // to recheck first below
if (retries < 0) {
// 如果当前节点为null,或者到了链表末尾
if (e == null) {
// 只创建一次
if (node == null) // speculatively create node
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
// 如果 key 相同
else if (key.equals(e.key))
retries = 0;
else
// 遍历链表
e = e.next;
}
else if (++retries > MAX_SCAN_RETRIES) {
// 如果超过最大次数,则阻塞获取锁
lock();
break;
}
// retries为偶数,tables[n] 不等于头节点,即另一个线程更新了链表
else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) {
// 重新开始遍历链表,并重置次数
e = first = f; // re-traverse if entry changed
retries = -1;
}
// 如果在 MAX_SCAN_RETRIES 之内获取到锁,就退出循环
}
return node;
}
- 先遍历链表,直到到链表的末尾,或者遇到相同节点,就 retries = 0。跳出这个 if 分支
- 如果到了链表末尾 e=null,并且 node=null,则创建 node
- 如果 key 相同,则 node=null
- 在遍历链表过程中,也会 tryLock()
- ++retries,判断是否超过最大次数
- 超过的话,lock() 阻塞,等获取到锁后,break
- 偶数次时,判断头节点是否被其他线程改变了
- 如果是的话,则重新赋值遍历变量和 retries,回到第一个 if 重新遍历链表
Segment table 的扩容:
private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
// 新容量为原来的两倍
int newCapacity = oldCapacity << 1;
threshold = (int)(newCapacity * loadFactor);
// 用新容量创建新的 table
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity];
int sizeMask = newCapacity - 1;
// 遍历旧数组
for (int i = 0; i < oldCapacity ; i++) {
// 数组当前位置的头节点
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
// 计算元素在新数组的位置
int idx = e.hash & sizeMask;
// 如果只有一个元素
if (next == null) // Single node on list
newTable[idx] = e;
else { // Reuse consecutive sequence at same slot
// 记录最后一段连续的序列
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
for (HashEntry<K,V> last = next;
last != null;
last = last.next) {
int k = last.hash & sizeMask;
// 每次计算的新位置与上次保留的不同,即记录下来
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
// 直接将最后一段的元素移动到新数组
newTable[lastIdx] = lastRun;
// Clone remaining nodes
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
// 头插法
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
// 插入新节点
int nodeIndex = node.hash & sizeMask; // add the new node
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
// 重新指向新数组
table = newTable;
}
get:
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
int h = hash(key);
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
// 获取对应位置的 segment 对象,以及内部的 table 数组
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
// 获取 table 中对应的链表,并遍历元素
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}
remove:
size:
尝试非加锁统计,sum += seg.modCount,size += seg.count
如果两次统计的 sum 都一样,则返回 size,否则。加锁
弱一致性:
因为 get 的时候是无锁的
put() 后并不能马上 get() 最新数据