「Golang」sync.WaitGroup源码讲解
sync.WaitGroup介绍
当我们在开发过程中,经常需要在开启多个goroutine后,等待全部的goroutine执行完毕后才进行下一步的业务逻辑执行。此时我们可能会采用轮询的方式去定时侦测已经开启的多个goroutine的业务是否执行完毕,但是这样性能很低,并且持续占用cpu时间片很消耗cpu的资源,此时我们就该使用sync.WaitGroup来完成此次操作。举个🌰,下列代码是开了10个goroutine后等待其各睡眠5秒之后进行后续操作的sync.WaitGroup方法实现。
func main() {// 创建对象var wait sync.WaitGroupfor i := 0; i < 10; i++ {// 为需要等待结束的goroutine数量+1wait.Add(1)go func() {time.Sleep(5*time.Second)// 结束 使得需要等待的数量-1// 等同于 wait.Add(-1)wait.Done()}()}// 等待所有执行完毕wait.Wait()fmt.Println("wait done")}
上述代码的睡眠5秒可以替换为任何需要在goroutine中执行的业务逻辑,上述代码中出现了下述几个sync.WaitGroup中提供的方法,提供方法很少很简洁,接下来就开始解析一下sync.WaitGroup的相关信息。
func (wg *WaitGroup) Add(delta int)func (wg *WaitGroup) Done()func (wg *WaitGroup) Wait()
sync.WaitGroup源代码解析
1:sync.WaitGroup结构体的解析
type WaitGroup struct {// 一个防止sync.WaitGroup被复制的标记结构体noCopy noCopy// 该数组在32为系统与64位系统中代表的用途不同// 首先说64位系统:// state1[0]代表当前sync.WaitGroup 调用Add方法增加了多少的couter// state1[1]代表调用了Wait方法等待结束的Waiter的数量// state1[2]代表Waiter的信号量// 其中 state1[0]与state1[1]作者称为计数标记// 在32位系统中:// state1[0]代表Waiter的信号量// state1[1]代表当前sync.WaitGroup 调用Add方法增加了多少的couter// state1[2]代表调用了Wait方法等待结束的Waiter的数量// 其中 state1[1]与state1[2]作者称为计数标记state1 [3]uint32}
Add(delta int)方法源代码解析
// state 方法用于根据系统是32位还是64位返回对应的state1字段的对应的计数和信号量的地址func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {// 64位系统返回state1与state1[2],由于数组是连续内存所以可以通过首地址// 取出state1[0]与state1[1]的所有二进制位return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]} else {// 32位系统返回state1与state1[2],由于数组是连续内存所以可以通过首地址// 取出state1[0]与state1[1]的所有二进制位return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]}}// Add 方法传入一个增量 可以为赋值则代表 Done// 例如 调用Done方法就是对Add方法传入了-1// 即Add(-1)func (wg *WaitGroup) Add(delta int) {// 根据系统位数返回计数标记和信号量标记statep, semap := wg.state()// 做了race检查和异常的检查if race.Enabled {_ = *statep // 如果信号量是个空指针,则报错if delta < 0 {race.ReleaseMerge(unsafe.Pointer(wg))}race.Disable()defer race.Enable()}// 将计数标记的高32位的值+delta// 如果是64位系统则代表state1[0]+deltastate := atomic.AddUint64(statep, uint64(delta)<<32)// 转换成正常的数字// 例如 以64为系统为例 当原state1[0]为0时// atomic.AddUint64(statep, uint64(delta)<<32)// 当delta==1时// 结果为 2^32-1// 该操作就是将此数值转变为1v := int32(state >> 32)// w代表需要等待的数量即 64位系统中的state1[1]为statew := uint32(state)// 做race判断,并判断delta是否为负数 如果是的话并且v与delta相等则做一些race的同步// fixme 此处解释存疑if race.Enabled && delta > 0 && v == int32(delta) {// The first increment must be synchronized with Wait.// Need to model this as a read, because there can be// several concurrent wg.counter transitions from 0.race.Read(unsafe.Pointer(semap))}// 如果v<0则代表delta的传入的为负值,并且该负值与原couter相减后小于0// 说白了一点就说Add传入的负值超出了原有couter的数量if v < 0 {panic("sync: negative WaitGroup counter")}// 如果等待数量不是0 并且delta>0 且v==delta 则代表出现了// 同时并发调用了Add和Waitif w != 0 && delta > 0 && v == int32(delta) {panic("sync: WaitGroup misuse: Add called concurrently with Wait")}// 正常情况 直接返回if v > 0 || w == 0 {return}// 也是同时调用Add和Waitif *statep != state {panic("sync: WaitGroup misuse: Add called concurrently with Wait")}// 如果计数值v为0并且waiter的数量w不为0// 则代表delta传入的值使得couter变为了0,但是还是有waiter在等待的话// 就把statep即state1[0]与state1[1]设置为0// 并唤醒所有的正在等待的阻塞goroutine*statep = 0for ; w != 0; w-- {runtime_Semrelease(semap, false, 0)}}
Add(delta int)方法的总体执行过程大概如下:
1.根据64位还是32位系统获取对应的标记位2.对计数标记位做delta的Add3.判断Add之后的state是否为一些非法情况,比如v<0等4.如果v和w分别为大于0和等于0则正常返回,代表Add成功5.否则判断当v为0时则代表没有counter了,但是还有waiter那么就把state整体设置为0,随后唤醒调用了 Wait()的阻塞的goroutine。
Done() 方法解析
// Done 没什么可说的 就是调用了一下Add(-1)func (wg *WaitGroup) Done() {wg.Add(-1)}
Wait() 方法解析
func (wg *WaitGroup) Wait() {// 根据系统位数返回计数标记和信号量标记statep, semap := wg.state()// race检测if race.Enabled {_ = *statep // trigger nil deref earlyrace.Disable()}// 循环校验是否所有的goroutine都调用了Donefor {//原子获取值state := atomic.LoadUint64(statep)// v 代表 couter数量 ,即高32位v := int32(state >> 32)// w 代表waiter数量 ,即低32位w := uint32(state)// 如果v==0 代表没有 couter了则不用等待了直接返回if v == 0 {// Counter is 0, no need to wait.if race.Enabled {race.Enable()race.Acquire(unsafe.Pointer(wg))}return}// 如果statep的值与state相等 还有需要等待完成的goroutine 此时则waiter+1if atomic.CompareAndSwapUint64(statep, state, state+1) {// race检测if race.Enabled && w == 0 {race.Write(unsafe.Pointer(semap))}// 阻塞等待直到被Add唤醒runtime_Semacquire(semap)// 如果被唤醒了 但是发现地址中的值不是0 代表唤醒错误 panicif *statep != 0 {panic("sync: WaitGroup is reused before previous Wait has returned")}// race检测if race.Enabled {race.Enable()race.Acquire(unsafe.Pointer(wg))}return}}}
Wait()方法的总体执行过程大概如下:
1.获取标记位的地址2.获取值3.判断v是否为0,如果是则无须等待。4.如果v不是0并且statep的值与state相等,则代表还有需要等待完成的goroutine,此时则waiter+1,然后阻塞等待Add方法唤醒
总结
sync.WaitGroup 使用的规范
Add(delta int)方法可以设置为负值,但是必须要确保这个负值delta加上当前计数器的数量的结果大于0。否则会panic。🌰如下
func main() {var wait sync.WaitGroupwait.Add(1) // 没问题 计数器为1wait.Add(-1)// 没问题 计数器为0wait.Add(-3) // panic 此时计数器为0-3 出错}
func main() { var wait sync.WaitGroup wait.Add(1) // 没问题 计数器为1 wait.Add(1)// 没问题 计数器为2
wait.Done() // 没问题 计数器为1wait.Done()// 没问题 计数器为0wait.Done()// panic 此时计数器为-1
}
2.`Add(delta int)`方法必须在 `Wait()` 方法调用之前全部调用完毕,否则会出现panic。举个🌰,本实例的设想是进入goroutine后`Add(delta int)`但是`Wait()` 方法调用早于goroutine中`Add(delta int)`,所以此时`Wait()`计数器为0,则不等待直接跳过。代码如下:```gofunc main() {var wait sync.WaitGroupgo func() {// 故意sleep 代表执行逻辑time.Sleep(1*time.Millisecond)fmt.Println("Add")wait.Add(1)wait.Done()}()go func() {// 故意sleep 代表执行逻辑time.Sleep(1*time.Millisecond)fmt.Println("Add")wait.Add(1)wait.Done()}()wait.Wait()fmt.Println("Done")}
3.不可以在前一个Wait()还未结束时,复用sync.WaitGroup ,举个例子代码:
func main() {var wait sync.WaitGroupwait.Add(1)go func() {// 故意sleep 代表执行逻辑time.Sleep(1*time.Millisecond)wait.Done() // 正常结束wait.Add(1) // 此处panic 因为wait还未结束就再次复用}()wait.Wait()fmt.Println("Done")}
sync.WaitGroup虽然可以重用,但是是有一个前提的,那就是必须等到上一轮的Wait()完成之后,才能重用sync.WaitGroup执行下一轮的Add(delta int)/Wait(),如果你在Wait()还没执行完的时候就调用下一轮 Add 方法,就有可能出现 panic。
至此sync.WaitGroup的源码解析就解析完了,可能有些地方有些理解上的错误,请各位谅解并且帮忙指出修改意见,如果这篇文章能帮到你,这是我的荣幸。
