「Golang」sync.WaitGroup源码讲解

sync.WaitGroup介绍

当我们在开发过程中,经常需要在开启多个goroutine后,等待全部的goroutine执行完毕后才进行下一步的业务逻辑执行。此时我们可能会采用轮询的方式去定时侦测已经开启的多个goroutine的业务是否执行完毕,但是这样性能很低,并且持续占用cpu时间片很消耗cpu的资源,此时我们就该使用sync.WaitGroup来完成此次操作。举个🌰,下列代码是开了10个goroutine后等待其各睡眠5秒之后进行后续操作的sync.WaitGroup方法实现。

  1. func main() {
  2. // 创建对象
  3. var wait sync.WaitGroup
  4. for i := 0; i < 10; i++ {
  5. // 为需要等待结束的goroutine数量+1
  6. wait.Add(1)
  7. go func() {
  8. time.Sleep(5*time.Second)
  9. // 结束 使得需要等待的数量-1
  10. // 等同于 wait.Add(-1)
  11. wait.Done()
  12. }()
  13. }
  14. // 等待所有执行完毕
  15. wait.Wait()
  16. fmt.Println("wait done")
  17. }

上述代码的睡眠5秒可以替换为任何需要在goroutine中执行的业务逻辑,上述代码中出现了下述几个sync.WaitGroup中提供的方法,提供方法很少很简洁,接下来就开始解析一下sync.WaitGroup的相关信息。

  1. func (wg *WaitGroup) Add(delta int)
  2. func (wg *WaitGroup) Done()
  3. func (wg *WaitGroup) Wait()

sync.WaitGroup源代码解析

1:sync.WaitGroup结构体的解析

  1. type WaitGroup struct {
  2. // 一个防止sync.WaitGroup被复制的标记结构体
  3. noCopy noCopy
  4. // 该数组在32为系统与64位系统中代表的用途不同
  5. // 首先说64位系统:
  6. // state1[0]代表当前sync.WaitGroup 调用Add方法增加了多少的couter
  7. // state1[1]代表调用了Wait方法等待结束的Waiter的数量
  8. // state1[2]代表Waiter的信号量
  9. // 其中 state1[0]与state1[1]作者称为计数标记
  10. // 在32位系统中:
  11. // state1[0]代表Waiter的信号量
  12. // state1[1]代表当前sync.WaitGroup 调用Add方法增加了多少的couter
  13. // state1[2]代表调用了Wait方法等待结束的Waiter的数量
  14. // 其中 state1[1]与state1[2]作者称为计数标记
  15. state1 [3]uint32
  16. }

Add(delta int)方法源代码解析

  1. // state 方法用于根据系统是32位还是64位返回对应的state1字段的对应的计数和信号量的地址
  2. func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
  3. if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
  4. // 64位系统返回state1与state1[2],由于数组是连续内存所以可以通过首地址
  5. // 取出state1[0]与state1[1]的所有二进制位
  6. return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
  7. } else {
  8. // 32位系统返回state1与state1[2],由于数组是连续内存所以可以通过首地址
  9. // 取出state1[0]与state1[1]的所有二进制位
  10. return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
  11. }
  12. }
  13. // Add 方法传入一个增量 可以为赋值则代表 Done
  14. // 例如 调用Done方法就是对Add方法传入了-1
  15. // 即Add(-1)
  16. func (wg *WaitGroup) Add(delta int) {
  17. // 根据系统位数返回计数标记和信号量标记
  18. statep, semap := wg.state()
  19. // 做了race检查和异常的检查
  20. if race.Enabled {
  21. _ = *statep // 如果信号量是个空指针,则报错
  22. if delta < 0 {
  23. race.ReleaseMerge(unsafe.Pointer(wg))
  24. }
  25. race.Disable()
  26. defer race.Enable()
  27. }
  28. // 将计数标记的高32位的值+delta
  29. // 如果是64位系统则代表state1[0]+delta
  30. state := atomic.AddUint64(statep, uint64(delta)<<32)
  31. // 转换成正常的数字
  32. // 例如 以64为系统为例 当原state1[0]为0时
  33. // atomic.AddUint64(statep, uint64(delta)<<32)
  34. // 当delta==1时
  35. // 结果为 2^32-1
  36. // 该操作就是将此数值转变为1
  37. v := int32(state >> 32)
  38. // w代表需要等待的数量即 64位系统中的state1[1]为state
  39. w := uint32(state)
  40. // 做race判断,并判断delta是否为负数 如果是的话并且v与delta相等则做一些race的同步
  41. // fixme 此处解释存疑
  42. if race.Enabled && delta > 0 && v == int32(delta) {
  43. // The first increment must be synchronized with Wait.
  44. // Need to model this as a read, because there can be
  45. // several concurrent wg.counter transitions from 0.
  46. race.Read(unsafe.Pointer(semap))
  47. }
  48. // 如果v<0则代表delta的传入的为负值,并且该负值与原couter相减后小于0
  49. // 说白了一点就说Add传入的负值超出了原有couter的数量
  50. if v < 0 {
  51. panic("sync: negative WaitGroup counter")
  52. }
  53. // 如果等待数量不是0 并且delta>0 且v==delta 则代表出现了
  54. // 同时并发调用了Add和Wait
  55. if w != 0 && delta > 0 && v == int32(delta) {
  56. panic("sync: WaitGroup misuse: Add called concurrently with Wait")
  57. }
  58. // 正常情况 直接返回
  59. if v > 0 || w == 0 {
  60. return
  61. }
  62. // 也是同时调用Add和Wait
  63. if *statep != state {
  64. panic("sync: WaitGroup misuse: Add called concurrently with Wait")
  65. }
  66. // 如果计数值v为0并且waiter的数量w不为0
  67. // 则代表delta传入的值使得couter变为了0,但是还是有waiter在等待的话
  68. // 就把statep即state1[0]与state1[1]设置为0
  69. // 并唤醒所有的正在等待的阻塞goroutine
  70. *statep = 0
  71. for ; w != 0; w-- {
  72. runtime_Semrelease(semap, false, 0)
  73. }
  74. }

Add(delta int)方法的总体执行过程大概如下:

1.根据64位还是32位系统获取对应的标记位2.对计数标记位做delta的Add3.判断Add之后的state是否为一些非法情况,比如v<0等4.如果v和w分别为大于0和等于0则正常返回,代表Add成功5.否则判断当v为0时则代表没有counter了,但是还有waiter那么就把state整体设置为0,随后唤醒调用了 Wait()的阻塞的goroutine。

Done() 方法解析

  1. // Done 没什么可说的 就是调用了一下Add(-1)
  2. func (wg *WaitGroup) Done() {
  3. wg.Add(-1)
  4. }

Wait() 方法解析

  1. func (wg *WaitGroup) Wait() {
  2. // 根据系统位数返回计数标记和信号量标记
  3. statep, semap := wg.state()
  4. // race检测
  5. if race.Enabled {
  6. _ = *statep // trigger nil deref early
  7. race.Disable()
  8. }
  9. // 循环校验是否所有的goroutine都调用了Done
  10. for {
  11. //原子获取值
  12. state := atomic.LoadUint64(statep)
  13. // v 代表 couter数量 ,即高32位
  14. v := int32(state >> 32)
  15. // w 代表waiter数量 ,即低32位
  16. w := uint32(state)
  17. // 如果v==0 代表没有 couter了则不用等待了直接返回
  18. if v == 0 {
  19. // Counter is 0, no need to wait.
  20. if race.Enabled {
  21. race.Enable()
  22. race.Acquire(unsafe.Pointer(wg))
  23. }
  24. return
  25. }
  26. // 如果statep的值与state相等 还有需要等待完成的goroutine 此时则waiter+1
  27. if atomic.CompareAndSwapUint64(statep, state, state+1) {
  28. // race检测
  29. if race.Enabled && w == 0 {
  30. race.Write(unsafe.Pointer(semap))
  31. }
  32. // 阻塞等待直到被Add唤醒
  33. runtime_Semacquire(semap)
  34. // 如果被唤醒了 但是发现地址中的值不是0 代表唤醒错误 panic
  35. if *statep != 0 {
  36. panic("sync: WaitGroup is reused before previous Wait has returned")
  37. }
  38. // race检测
  39. if race.Enabled {
  40. race.Enable()
  41. race.Acquire(unsafe.Pointer(wg))
  42. }
  43. return
  44. }
  45. }
  46. }

Wait()方法的总体执行过程大概如下:

1.获取标记位的地址2.获取值3.判断v是否为0,如果是则无须等待。4.如果v不是0并且statep的值与state相等,则代表还有需要等待完成的goroutine,此时则waiter+1,然后阻塞等待Add方法唤醒

总结

sync.WaitGroup 使用的规范

  1. Add(delta int) 方法可以设置为负值,但是必须要确保这个负值delta加上当前计数器的数量的结果大于0。否则会panic。🌰如下

  1. func main() {
  2. var wait sync.WaitGroup
  3. wait.Add(1) // 没问题 计数器为1
  4. wait.Add(-1)// 没问题 计数器为0
  5. wait.Add(-3) // panic 此时计数器为0-3 出错
  6. }

func main() { var wait sync.WaitGroup wait.Add(1) // 没问题 计数器为1 wait.Add(1)// 没问题 计数器为2

  1. wait.Done() // 没问题 计数器为1
  2. wait.Done()// 没问题 计数器为0
  3. wait.Done()// panic 此时计数器为-1

}

  1. 2.`Add(delta int)`方法必须在 `Wait()` 方法调用之前全部调用完毕,否则会出现panic。举个🌰,本实例的设想是进入goroutine`Add(delta int)`但是`Wait()` 方法调用早于goroutine`Add(delta int)`,所以此时`Wait()`计数器为0,则不等待直接跳过。代码如下:
  2. ```go
  3. func main() {
  4. var wait sync.WaitGroup
  5. go func() {
  6. // 故意sleep 代表执行逻辑
  7. time.Sleep(1*time.Millisecond)
  8. fmt.Println("Add")
  9. wait.Add(1)
  10. wait.Done()
  11. }()
  12. go func() {
  13. // 故意sleep 代表执行逻辑
  14. time.Sleep(1*time.Millisecond)
  15. fmt.Println("Add")
  16. wait.Add(1)
  17. wait.Done()
  18. }()
  19. wait.Wait()
  20. fmt.Println("Done")
  21. }

3.不可以在前一个Wait()还未结束时,复用sync.WaitGroup ,举个例子代码:

  1. func main() {
  2. var wait sync.WaitGroup
  3. wait.Add(1)
  4. go func() {
  5. // 故意sleep 代表执行逻辑
  6. time.Sleep(1*time.Millisecond)
  7. wait.Done() // 正常结束
  8. wait.Add(1) // 此处panic 因为wait还未结束就再次复用
  9. }()
  10. wait.Wait()
  11. fmt.Println("Done")
  12. }

sync.WaitGroup 虽然可以重用,但是是有一个前提的,那就是必须等到上一轮的 Wait() 完成之后,才能重用 sync.WaitGroup 执行下一轮的 Add(delta int)/Wait() ,如果你在 Wait() 还没执行完的时候就调用下一轮 Add 方法,就有可能出现 panic。

至此sync.WaitGroup的源码解析就解析完了,可能有些地方有些理解上的错误,请各位谅解并且帮忙指出修改意见,如果这篇文章能帮到你,这是我的荣幸。