为什么要有ConcurrentHashMap
- 普通的HashMap有线程安全问题;
- HashTable虽然线程安全,但是会锁上整个HashTable,效率太低
ConcurrentHashMap采用分段锁,只锁一个Segment,多线程访问不同的Segment不会有线程安全问题,从而提高了效率。
结构
ConcurrentHashMap有一个Segment
[]数组: final Segment<K,V>[] segments;
Segment里又有一个HashEntry
[]的数组: Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
this.loadFactor = lf;
this.threshold = threshold;
this.table = tab;
}
因此结构大概如下图:
构造函数
// initialCapacity: 初始时有多少个HashEntry链表
// loadFactor:加载因子
// concurrencyLevel: 并发级别,用来决定有多少个Segment
// 之所以叫并发级别,是因为上锁是单独对segment上锁,如果有16个Segment,就能允许最高16条线程并发
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
// 参数不合理,抛异常
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
// 如果并发级别超过了预定的最大值,则设置成最大值
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
// Find power-of-two sizes best matching arguments
int sshift = 0;
int ssize = 1;
// 找到比concurrencyLevel大的最小二次幂
// 采用左移效率更高
// 例如concurrencyLevel = 12, 那么ssize就是16
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
this.segmentShift = 32 - sshift;
// 掩码,未来要用来计算下标的,类似于HashMap中的h & (length - 1)
this.segmentMask = ssize - 1;
// 如果初始容量大于最大值,则设置成最大值
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
// c用来表示每个Segment下能有多少组链表
// 例如initialCapacity = 32, ssize = 16, 那么c = 2
// 例如initialCapacity = 40, ssize = 16, 那么c = 3
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
// 保证c的大小不小于2
// 即每个segment下至少有两组链表
int cap = MIN_SEGMENT_TABLE_CAPACITY;
// 真正的每个Segment下的链表的组的数量
// 为2的幂次方,例如在上面c = 3, 那么cap就等于4
while (cap < c)
cap <<= 1;
// create segments and segments[0]
// 初始化一个Segment0,
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
// 初始化Segment数组
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
// native方法,TODO
UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
this.segments = ss;
}
- 和HashMap不同的是,有个参数是并发级别:
该参数决定了segment数组的大小,并不是传入16就是等于16,而是去找到比传入数值大的最小二进制数,比如传入17,segment数组就是32。 ```java int ssize = 1; while (ssize < concurrencyLevel) { ++sshift; ssize <<= 1; }static final int DEFAULT_CONCURRENCY_LEVEL = 16;
…
Segment
HashEntry数组的数量由initialCapacity和concurrencyLevel共同决定。
- 每个Segment下最少有两个HashEntry
```java
// MIN_SEGMENT_TABLE_CAPACITY是2
int cap = MIN_SEGMENT_TABLE_CAPACITY;
每个Segment下的HashEntry数量只能是2的幂次方:
while (cap < c) cap <<= 1;
在上面的过程中,假设我们传入的concurrencyLevel是17,那么我们知道segment数组大小是32。
- 用segment数组大小 每个segment下的HashEntry数量 > initialCapacity, 假设传入的initalCapacity为127, 那么么个segment下的HashEntry数量为4,因为32 4 = 128 > 127,然而 32 * 2 = 64 < 127。(之所以从2跳到4,是因为必须是2的幂次方) ```java // c = 127 / 32 = 3 int c = initialCapacity / ssize;
// 3 32 = 96 < 127 // ++c => c = 4 if (c ssize < initialCapacity) ++c;
// cap = 2 int cap = MIN_SEGMENT_TABLE_CAPACITY;
// cap < 4 => cap <<= 1
// cap = 4
while (cap < c)
cap <<= 1;
// create segments and segments[0]
Segment
<a name="IHWS0"></a>
# Unsafe操作
我们要用Unsafe解决线程不安全的问题。出现线程不安全的原因是,不同的线程会使用自己缓存的值,而不是直接用内存里的值,如下:<br />![image.png](https://cdn.nlark.com/yuque/0/2021/png/12689050/1629026150006-76fc8835-161f-47e5-a59e-a38a6b1a7025.png#clientId=ue7b3e1bd-6943-4&from=paste&height=684&id=ue64d923f&margin=%5Bobject%20Object%5D&name=image.png&originHeight=809&originWidth=1188&originalType=binary&ratio=1&size=54434&status=done&style=none&taskId=ue973d53c-777a-4dd8-9be3-291e66d16d4&width=1004)
Unsafe使用的操作是CAS,即campareAndSwap,就是说,在修改之前,去比对本线程中的值是否与内存中的值相等,只有相等的情况下,才去做操作,该操作为原子操作。
<a name="zoH8L"></a>
## 自己的类无法直接使用Unsafe类
如果我们仿照ConcurrentHashmap那样初始化Unsafe,会报错:
```java
private static final sun.misc.Unsafe UNSAFE;
static {
UNSAFE = sun.misc.Unsafe.getUnsafe();
}
这是因为,我们的类加载器是APP类加载器,而ConcurrentHashmap是Bootstrap类加载器,因此getClassLoader方法才会返回null, 才不会抛出异常:
public static Unsafe getUnsafe() {
Class var0 = Reflection.getCallerClass();
if (var0.getClassLoader() != null) {
throw new SecurityException("Unsafe");
} else {
return theUnsafe;
}
}
解决办法就是通过反射的方式:
Put方法
public V put(K key, V value) {
Segment<K,V> s;
// value不能为Null
if (value == null)
throw new NullPointerException();
// 得到哈希值
int hash = hash(key);
// 取高位,举个例子,这个segmentShift如果等于28,那么就会把高4位移动到低4位
// 和segmentMask相与,得到下标,也就是原高4位的下标
// 之所以取高4位,是因为将来算另一个下标的时候,要用到低4位
int j = (hash >>> segmentShift) & segmentMask;
// 如果该下标位置上的segment还没初始化,就去初始化一个Segment放在这里
if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
(segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
s = ensureSegment(j);
// 将key-value放入
return s.put(key, hash, value, false);
}
- 首先对key取hash;
- 然后取得下标 hash >>> segmentShift) & segmentMask;, 此时取的是高位,例如高四位;
- 看下标的位置是否已存在一个Segment对象,如果不存在就生成一个;
此时就将出现线程安全问题, 例如两个线程同时想要生成Segment对象放在对应的索引位置上。
解决办法:用Unsafe,如果该对象已经存在,直接返回,否则生成一个。
EnsureSegment方法
// 该方法考虑了多线程的问题,多次判断取出来的对象是否为null,
// 因为在这个过程中,随时有可能有其他线程生成segment,使得它不再为null
// 当不为null的时候,返回该对象即可
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
Segment<K,V> proto = ss[0]; // use segment 0 as prototype
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
segment.put方法
调用Segment的put方法,即把key-value对放到HashEntry上去,此后的线程安全问题通过锁来解决;
- 先获取对应的index;
判断对应的index上是否有值;
- 有值就往后遍历,或者覆盖
- 没值就生成一个结点,插入;
```java
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
// 尝试上锁,tryLock()是ReentrantLock类里的方法
// Segment继承了ReentrantLock,因此是锁Segment自己,因此是分段锁
HashEntry
node = tryLock() ? null : scanAndLockForPut(key, hash, value); V oldValue;
try { // 获得HashEntry数组 HashEntry
[] tab = table; // 计算下标,这里就是取低四位了 // 如果在之前计算Segment下标的时候也取低4位, // 那么就会导致某个Segment下的数据都放在同一个链表下了 int index = (tab.length - 1) & hash;
// 获取链表的第一个结点, 一定要一边回想ConcurrentHashMap的结构一边看代码 HashEntry
first = entryAt(tab, index); // 遍历HashEntry链表 for (HashEntry
e = first;;) { // 如果该处已经存在链表了 if (e != null) { K k; // 如果找到相同的key就覆盖 if ((k = e.key) == key || (e.hash == hash && key.equals(k))) { oldValue = e.value; if (!onlyIfAbsent) { e.value = value; ++modCount; } break; } // 否则接着往下遍历 e = e.next; } // 该处无链表 else { // 此处是应对其他线程加锁成功后,node结点被new出来了,参考scanAndLockForPut方法 if (node != null) node.setNext(first); else node = new HashEntry<K,V>(hash, key, value, first); int c = count + 1; // 如果结点数量过多,对当前Segment下的HashEntry数组进行扩容 // 不明白为什么要叫rehash而不叫resize if (c > threshold && tab.length < MAXIMUM_CAPACITY) rehash(node); else // 否则就插入 setEntryAt(tab, index, node); ++modCount; count = c; oldValue = null; break; }
} } finally { // 释放锁 unlock(); } return oldValue; } ```
Rehash扩容(TODO)
扩容是局部扩容,不会改变Segment数组对象的长度,而是改变单个Segment下的HashEntry数组的长度。因此,Segment数组一旦初始化,其长度不会再变化。
private void rehash(HashEntry<K,V> node) {
/*
* Reclassify nodes in each list to new table. Because we
* are using power-of-two expansion, the elements from
* each bin must either stay at same index, or move with a
* power of two offset. We eliminate unnecessary node
* creation by catching cases where old nodes can be
* reused because their next fields won't change.
* Statistically, at the default threshold, only about
* one-sixth of them need cloning when a table
* doubles. The nodes they replace will be garbage
* collectable as soon as they are no longer referenced by
* any reader thread that may be in the midst of
* concurrently traversing table. Entry accesses use plain
* array indexing because they are followed by volatile
* table write.
*/
// 旧的数组
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
// 新数组容量为旧数组两倍
// 下面的和HashMap类似
int newCapacity = oldCapacity << 1;
threshold = (int)(newCapacity * loadFactor);
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity];
int sizeMask = newCapacity - 1;
// 遍历旧数组
for (int i = 0; i < oldCapacity ; i++) {
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
// 类似HashMap,要么在原地,要么为原来的index + oldLengh;
int idx = e.hash & sizeMask;
// 下一个结点是空,就表示是单个结点,直接挪过去即可
if (next == null) // Single node on list
newTable[idx] = e;
else { // Reuse consecutive sequence at same slot
// 下面是去找连续的新下标相同的结点,例如1->2->3->4->5
// 假设1,2,3的新下标是相同的,就把1->2->3一次性移过去
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
for (HashEntry<K,V> last = next;
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
newTable[lastIdx] = lastRun;
// Clone remaining nodes
// 将剩下的结点移过去
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
// 将要添加的新结点添加进来
int nodeIndex = node.hash & sizeMask; // add the new node
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
table = newTable;
}
scanAndLockForPut方法
之前我们在加锁的时候,不一定会成功,就会去调用该方法:
HashEntry<K,V> node = tryLock() ? null :
scanAndLockForPut(key, hash, value);
加锁失败并不会阻塞,而是去做一些准备工作,例如将该HashEntry结点生成出来
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
// 拿到要放置位置的第一个结点
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
// 重试次数
int retries = -1; // negative while locating node
// 尝试加锁,如果失败就去循环
while (!tryLock()) {
HashEntry<K,V> f; // to recheck first below
// 第一次必然会进来
if (retries < 0) {
// 如果第一个结点不为空
if (e == null) {
// 如果node为空,就去创建一个node
if (node == null) // speculatively create node
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
// 如果要第一个结点的key和当前key相同,说明是要覆盖的值,就不需要去生成结点
else if (key.equals(e.key))
retries = 0;
// 否则就遍历下去,直到找到相同的key或者到末尾
else
e = e.next;
}
// 如果达到最大重试次数,加锁
else if (++retries > MAX_SCAN_RETRIES) {
lock();
break;
}
// 每间隔一次循环,判断一次第一个结点是否变化了
// 如果变化了,重新来一遍
else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) {
e = first = f; // re-traverse if entry changed
retries = -1;
}
}
return node;
}