我们可以在Go语言中十分便捷地开启goroutine去并发地执行任务,但是如何有效的处理并发过程中的错误则是一个很棘手的问题,本文介绍了一些处理并发错误的方法。

recover goroutine中的panic

我们知道可以在代码中使用 recover 来会恢复程序中意想不到的 panic,而 panic 只会触发当前 goroutine 中的 defer 操作。

例如在下面的示例代码中,无法在 main 函数中 recover 另一个goroutine中引发的 panic。

  1. func f1() {
  2. defer func() {
  3. if e := recover(); e != nil {
  4. fmt.Printf("recover panic:%v\n", e)
  5. }
  6. }()
  7. // 开启一个goroutine执行任务
  8. go func() {
  9. fmt.Println("in goroutine....")
  10. // 只能触发当前goroutine中的defer
  11. panic("panic in goroutine")
  12. }()
  13. time.Sleep(time.Second)
  14. fmt.Println("exit")
  15. }

执行上面的 f1 函数会得到如下结果:

  1. in goroutine....
  2. panic: panic in goroutine
  3. goroutine 6 [running]:
  4. main.f1.func2()
  5. /Users/liwenzhou/workspace/github/the-road-to-learn-golang/ch12/goroutine_recover.go:20 +0x65
  6. created by main.f1
  7. /Users/liwenzhou/workspace/github/the-road-to-learn-golang/ch12/goroutine_recover.go:17 +0x48
  8. Process finished with exit code 2

从输出结果可以看到程序并没有正常退出,而是由于 panic 异常退出了(exit code 2)。

正如上面示例演示的那样,在启用 goroutine 去执行任务的场景下,如果想要 recover goroutine中可能出现的 panic 就需要在 goroutine 中使用 recover。就像下面的 f2 函数那样。

  1. func f2() {
  2. defer func() {
  3. if r := recover(); r != nil {
  4. fmt.Printf("recover outer panic:%v\n", r)
  5. }
  6. }()
  7. // 开启一个goroutine执行任务
  8. go func() {
  9. defer func() {
  10. if r := recover(); r != nil {
  11. fmt.Printf("recover inner panic:%v\n", r)
  12. }
  13. }()
  14. fmt.Println("in goroutine....")
  15. // 只能触发当前goroutine中的defer
  16. panic("panic in goroutine")
  17. }()
  18. time.Sleep(time.Second)
  19. fmt.Println("exit")
  20. }

执行 f2 函数会得到如下输出结果。

  1. in goroutine....
  2. recover inner panic:panic in goroutine
  3. exit

程序中的 panic 被 recover 成功捕获,程序最终正常退出。

errgroup

在以往演示的并发示例中,我们通常像下面的示例代码那样在 go 关键字后,调用一个函数或匿名函数。

  1. go func(){
  2. // ...
  3. }
  4. go foo()

在之前讲解并发的代码示例中我们默认被并发的那些函数都不会返回错误,但真实的情况往往是事与愿违。

当我们想要将一个任务拆分成多个子任务交给多个 goroutine 去运行,这时我们该如何获取到子任务可能返回的错误呢?

假设我们有多个网址需要并发去获取它们的内容,这时候我们会写出类似下面的代码。

  1. // fetchUrlDemo 并发获取url内容
  2. func fetchUrlDemo() {
  3. wg := sync.WaitGroup{}
  4. var urls = []string{
  5. "http://pkg.go.dev",
  6. "http://www.liwenzhou.com",
  7. "http://www.yixieqitawangzhi.com",
  8. }
  9. for _, url := range urls {
  10. wg.Add(1)
  11. go func(url string) {
  12. defer wg.Done()
  13. resp, err := http.Get(url)
  14. if err == nil {
  15. fmt.Printf("获取%s成功\n", url)
  16. resp.Body.Close()
  17. }
  18. return // 如何将错误返回呢?
  19. }(url)
  20. }
  21. wg.Wait()
  22. // 如何获取goroutine中可能出现的错误呢?
  23. }

执行上述fetchUrlDemo函数得到如下输出结果,由于 http://www.yixieqitawangzhi.com 是我随意编造的一个并不真实存在的 url,所以对它的 HTTP 请求会返回错误。

  1. 获取http://pkg.go.dev成功
  2. 获取http://www.liwenzhou.com成功

在上面的示例代码中,我们开启了 3 个 goroutine 分别去获取3个 url 的内容。类似这种将任务分为若干个子任务的场景会有很多,那么我们如何获取子任务中可能出现的错误呢?

errgroup 包就是为了解决这类问题而开发的,它能为处理公共任务的子任务而开启的一组 goroutine 提供同步、error 传播和基于context 的取消功能。

errgroup 包中定义了一个 Group 类型,它包含了若干个不可导出的字段。

  1. type Group struct {
  2. cancel func()
  3. wg sync.WaitGroup
  4. errOnce sync.Once
  5. err error
  6. }

errgroup.Group 提供了GoWait两个方法。

  1. func (g *Group) Go(f func() error)
  • Go 函数会在新的 goroutine 中调用传入的函数f。
  • 第一个返回非零错误的调用将取消该Group;下面的Wait方法会返回该错误
  1. func (g *Group) Wait() error
  • Wait 会阻塞直至由上述 Go 方法调用的所有函数都返回,然后从它们返回第一个非nil的错误(如果有)。

下面的示例代码演示了如何使用 errgroup 包来处理多个子任务 goroutine 中可能返回的 error。

  1. // fetchUrlDemo2 使用errgroup并发获取url内容
  2. func fetchUrlDemo2() error {
  3. g := new(errgroup.Group) // 创建等待组(类似sync.WaitGroup)
  4. var urls = []string{
  5. "http://pkg.go.dev",
  6. "http://www.liwenzhou.com",
  7. "http://www.yixieqitawangzhi.com",
  8. }
  9. for _, url := range urls {
  10. url := url // 注意此处声明新的变量
  11. // 启动一个goroutine去获取url内容
  12. g.Go(func() error {
  13. resp, err := http.Get(url)
  14. if err == nil {
  15. fmt.Printf("获取%s成功\n", url)
  16. resp.Body.Close()
  17. }
  18. return err // 返回错误
  19. })
  20. }
  21. if err := g.Wait(); err != nil {
  22. // 处理可能出现的错误
  23. fmt.Println(err)
  24. return err
  25. }
  26. fmt.Println("所有goroutine均成功")
  27. return nil
  28. }

执行上面的fetchUrlDemo2函数会得到如下输出结果。

  1. 获取http://pkg.go.dev成功
  2. 获取http://www.liwenzhou.com成功
  3. Get "http://www.yixieqitawangzhi.com": dial tcp: lookup www.yixieqitawangzhi.com: no such host

当子任务的 goroutine 中对http://www.yixieqitawangzhi.com 发起 HTTP 请求时会返回一个错误,这个错误会由 errgroup.Group 的 Wait 方法返回。

通过阅读下方 errgroup.Group 的 Go 方法源码,我们可以看到当任意一个函数 f 返回错误时,会通过g.errOnce.Do只将第一个返回的错误记录,并且如果存在 cancel 方法则会调用cancel。

  1. func (g *Group) Go(f func() error) {
  2. g.wg.Add(1)
  3. go func() {
  4. defer g.wg.Done()
  5. if err := f(); err != nil {
  6. g.errOnce.Do(func() {
  7. g.err = err
  8. if g.cancel != nil {
  9. g.cancel()
  10. }
  11. })
  12. }
  13. }()
  14. }

那么如何创建带有 cancel 方法的 errgroup.Group 呢?

答案是通过 errorgroup 包提供的 WithContext 函数。

  1. func WithContext(ctx context.Context) (*Group, context.Context)

WithContext 函数接收一个父 context,返回一个新的 Group 对象和一个关联的子 context 对象。下面的代码片段是一个官方文档给出的示例。

  1. package main
  2. import (
  3. "context"
  4. "crypto/md5"
  5. "fmt"
  6. "io/ioutil"
  7. "log"
  8. "os"
  9. "path/filepath"
  10. "golang.org/x/sync/errgroup"
  11. )
  12. // Pipeline demonstrates the use of a Group to implement a multi-stage
  13. // pipeline: a version of the MD5All function with bounded parallelism from
  14. // https://blog.golang.org/pipelines.
  15. func main() {
  16. m, err := MD5All(context.Background(), ".")
  17. if err != nil {
  18. log.Fatal(err)
  19. }
  20. for k, sum := range m {
  21. fmt.Printf("%s:\t%x\n", k, sum)
  22. }
  23. }
  24. type result struct {
  25. path string
  26. sum [md5.Size]byte
  27. }
  28. // MD5All reads all the files in the file tree rooted at root and returns a map
  29. // from file path to the MD5 sum of the file's contents. If the directory walk
  30. // fails or any read operation fails, MD5All returns an error.
  31. func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) {
  32. // ctx is canceled when g.Wait() returns. When this version of MD5All returns
  33. // - even in case of error! - we know that all of the goroutines have finished
  34. // and the memory they were using can be garbage-collected.
  35. g, ctx := errgroup.WithContext(ctx)
  36. paths := make(chan string)
  37. g.Go(func() error {
  38. return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
  39. if err != nil {
  40. return err
  41. }
  42. if !info.Mode().IsRegular() {
  43. return nil
  44. }
  45. select {
  46. case paths <- path:
  47. case <-ctx.Done():
  48. return ctx.Err()
  49. }
  50. return nil
  51. })
  52. })
  53. // Start a fixed number of goroutines to read and digest files.
  54. c := make(chan result)
  55. const numDigesters = 20
  56. for i := 0; i < numDigesters; i++ {
  57. g.Go(func() error {
  58. for path := range paths {
  59. data, err := ioutil.ReadFile(path)
  60. if err != nil {
  61. return err
  62. }
  63. select {
  64. case c <- result{path, md5.Sum(data)}:
  65. case <-ctx.Done():
  66. return ctx.Err()
  67. }
  68. }
  69. return nil
  70. })
  71. }
  72. go func() {
  73. g.Wait()
  74. close(c)
  75. }()
  76. m := make(map[string][md5.Size]byte)
  77. for r := range c {
  78. m[r.path] = r.sum
  79. }
  80. // Check whether any of the goroutines failed. Since g is accumulating the
  81. // errors, we don't need to send them (or check for them) in the individual
  82. // results sent on the channel.
  83. if err := g.Wait(); err != nil {
  84. return nil, err
  85. }
  86. return m, nil
  87. }

或者这里有另外一个示例。

  1. func GetFriends(ctx context.Context, user int64) (map[string]*User, error) {
  2. g, ctx := errgroup.WithContext(ctx)
  3. friendIds := make(chan int64)
  4. // Produce
  5. g.Go(func() error {
  6. defer close(friendIds)
  7. for it := GetFriendIds(user); ; {
  8. if id, err := it.Next(ctx); err != nil {
  9. if err == io.EOF {
  10. return nil
  11. }
  12. return fmt.Errorf("GetFriendIds %d: %s", user, err)
  13. } else {
  14. select {
  15. case <-ctx.Done():
  16. return ctx.Err()
  17. case friendIds <- id:
  18. }
  19. }
  20. }
  21. })
  22. friends := make(chan *User)
  23. // Map
  24. workers := int32(nWorkers)
  25. for i := 0; i < nWorkers; i++ {
  26. g.Go(func() error {
  27. defer func() {
  28. // Last one out closes shop
  29. if atomic.AddInt32(&workers, -1) == 0 {
  30. close(friends)
  31. }
  32. }()
  33. for id := range friendIds {
  34. if friend, err := GetUserProfile(ctx, id); err != nil {
  35. return fmt.Errorf("GetUserProfile %d: %s", user, err)
  36. } else {
  37. select {
  38. case <-ctx.Done():
  39. return ctx.Err()
  40. case friends <- friend:
  41. }
  42. }
  43. }
  44. return nil
  45. })
  46. }
  47. // Reduce
  48. ret := map[string]*User{}
  49. g.Go(func() error {
  50. for friend := range friends {
  51. ret[friend.Name] = friend
  52. }
  53. return nil
  54. })
  55. return ret, g.Wait()
  56. }

可惜这两个示例不太好理解。。。