Iterator - 迭代器:一个接口,用来遍历容器的所有元素。
Golang 提供多种容器类型像Slice,Map 等,大多数场景下,简单的遍历就足够了,不需要用到迭代器。但是在一些特殊的场景下,迭代器更加便捷。例如:减少拷贝的次数;遍历可以被中断和恢复;Goroutines 协同遍历等。简单的说,就是给遍历加入更多的逻辑。

以下是实现迭代器的三种方式,首先我们定义一些简单的数据结构,Task 是我们处理的对象。

  1. type Task struct {
  2. id int
  3. desc string
  4. }
  5. func makeTasks() []*Task {
  6. var tasks []*Task
  7. for i := 0; i < 5; i++ {
  8. task := &Task{
  9. id: i + 100,
  10. desc: fmt.Sprintf("#%d Task", i+100),
  11. }
  12. tasks = append(tasks, task)
  13. }
  14. return tasks

一,用 Callback 实现迭代器

  1. func iterWithCallback(tasks []*Task, cb func(*Task) error) error {
  2. for _, task := range tasks {
  3. if err := cb(task); err != nil {
  4. return err
  5. }
  6. }
  7. return nil
  8. }
  9. func runIterWithCallback() {
  10. err := iterWithCallback(makeTasks(), func(task *Task) error {
  11. fmt.Println("Task ID: ", task.id)
  12. return nil
  13. })
  14. if err != nil {
  15. fmt.Println(err)
  16. }
  17. }

二,用 Next() 实现迭代器

  1. type Iterator interface {
  2. Next() bool // move to next one
  3. FilterNext() bool // move to next one with filter
  4. Task() *Task // return the current value
  5. Close() // close the iterator
  6. }
  7. func NewIterTasks(tasks []*Task) Iterator {
  8. return &sliceIter{tasks: tasks, index: -1}
  9. }
  10. func NewIterTasksWithCycle(tasks []*Task) Iterator {
  11. return &sliceIter{tasks: tasks, index: -1, cycle: true}
  12. }
  13. func NewIterTasksWithCycleFilter(tasks []*Task, filter func(*Task) bool) Iterator {
  14. return &sliceIter{tasks: tasks, index: -1, cycle: true, filter: filter}
  15. }
  16. type sliceIter struct {
  17. mu sync.Mutex
  18. tasks []*Task
  19. index int
  20. cycle bool
  21. filter func(*Task) bool
  22. }
  23. func (s *sliceIter) Next() bool {
  24. s.mu.Lock()
  25. defer s.mu.Unlock()
  26. l := len(s.tasks)
  27. // if s.tasks is nil, len(s.tasks) is 0
  28. if l == 0 {
  29. return false
  30. }
  31. s.index++
  32. if s.index == l {
  33. if s.cycle {
  34. s.index = 0
  35. } else {
  36. s.tasks = nil
  37. return false
  38. }
  39. }
  40. return true
  41. }
  42. func (s *sliceIter) FilterNext() bool {
  43. for s.Next() {
  44. if s.filter(s.Task()) {
  45. return true
  46. }
  47. }
  48. return false
  49. }
  50. func (s *sliceIter) Task() *Task {
  51. s.mu.Lock()
  52. defer s.mu.Unlock()
  53. if len(s.tasks) == 0 {
  54. return nil
  55. }
  56. return s.tasks[s.index]
  57. }
  58. func (s *sliceIter) Close() {
  59. s.mu.Lock()
  60. defer s.mu.Unlock()
  61. s.tasks = nil
  62. }

使用方式:

  1. // 1.
  2. iter := NewIterTasks(makeTasks())
  3. for iter.Next() {
  4. task := iter.Task()
  5. fmt.Println("Task ID ", task.id)
  6. }
  7. // 2.
  8. iter = NewIterTasksWithCycle(makeTasks())
  9. for iter.Next() {
  10. task := iter.Task()
  11. fmt.Println("Task ID ", task.id)
  12. }
  13. // 3.
  14. iter = NewIterTasksWithCycleFilter(makeTasks(), func(task *Task) bool {
  15. if task.id%2 == 0 {
  16. return true
  17. }
  18. return false
  19. })
  20. for iter.FilterNext() {
  21. task := iter.Task()
  22. fmt.Println("Task ID ", task.id)
  23. }

三,用 Channel 实现迭代器

实际上就是经典的生产者和消费者模型。一个或多个 Goroutines 负责生产数据,另外的一个或多个 Goroutines 负责处理数据。需要注意的是,生产的 Goroutines 退出的时候需要通知处理的 Goroutines,否则会产生 Leak。

以下是一个简单例子一个生产者和一个消费者。首先给 Task 加一个错误项,可以动态的获取错误。

  1. type TaskWithError struct {
  2. *Task
  3. Err error
  4. }
  5. func runIterWithChan() error {
  6. tC := make(chan TaskWithError)
  7. exitC := make(chan struct{})
  8. go func() {
  9. defer close(tC)
  10. for _, task := range makeTasks() {
  11. wrapTask := TaskWithError{task, nil}
  12. // processing...
  13. // perhaps it has error...
  14. // wrapTask.Err = errors.New(xxxx)
  15. select {
  16. case <-exitC:
  17. return
  18. default:
  19. }
  20. select {
  21. case <-exitC:
  22. return
  23. case tC <- wrapTask:
  24. }
  25. }
  26. }()
  27. for task := range tC {
  28. if task.Err != nil {
  29. // 如果考虑退出,应增加Channel 通知Goroutine 退出
  30. // 否则Goroutine Leak
  31. exitC <- struct{}{}
  32. return task.Err
  33. } else {
  34. fmt.Println("Task ID:", task.id)
  35. }
  36. }
  37. return nil
  38. }

参考链接: