Iterator - 迭代器:一个接口,用来遍历容器的所有元素。
Golang 提供多种容器类型像Slice,Map 等,大多数场景下,简单的遍历就足够了,不需要用到迭代器。但是在一些特殊的场景下,迭代器更加便捷。例如:减少拷贝的次数;遍历可以被中断和恢复;Goroutines 协同遍历等。简单的说,就是给遍历加入更多的逻辑。
以下是实现迭代器的三种方式,首先我们定义一些简单的数据结构,Task 是我们处理的对象。
type Task struct {
id int
desc string
}
func makeTasks() []*Task {
var tasks []*Task
for i := 0; i < 5; i++ {
task := &Task{
id: i + 100,
desc: fmt.Sprintf("#%d Task", i+100),
}
tasks = append(tasks, task)
}
return tasks
一,用 Callback 实现迭代器
func iterWithCallback(tasks []*Task, cb func(*Task) error) error {
for _, task := range tasks {
if err := cb(task); err != nil {
return err
}
}
return nil
}
func runIterWithCallback() {
err := iterWithCallback(makeTasks(), func(task *Task) error {
fmt.Println("Task ID: ", task.id)
return nil
})
if err != nil {
fmt.Println(err)
}
}
二,用 Next() 实现迭代器
type Iterator interface {
Next() bool // move to next one
FilterNext() bool // move to next one with filter
Task() *Task // return the current value
Close() // close the iterator
}
func NewIterTasks(tasks []*Task) Iterator {
return &sliceIter{tasks: tasks, index: -1}
}
func NewIterTasksWithCycle(tasks []*Task) Iterator {
return &sliceIter{tasks: tasks, index: -1, cycle: true}
}
func NewIterTasksWithCycleFilter(tasks []*Task, filter func(*Task) bool) Iterator {
return &sliceIter{tasks: tasks, index: -1, cycle: true, filter: filter}
}
type sliceIter struct {
mu sync.Mutex
tasks []*Task
index int
cycle bool
filter func(*Task) bool
}
func (s *sliceIter) Next() bool {
s.mu.Lock()
defer s.mu.Unlock()
l := len(s.tasks)
// if s.tasks is nil, len(s.tasks) is 0
if l == 0 {
return false
}
s.index++
if s.index == l {
if s.cycle {
s.index = 0
} else {
s.tasks = nil
return false
}
}
return true
}
func (s *sliceIter) FilterNext() bool {
for s.Next() {
if s.filter(s.Task()) {
return true
}
}
return false
}
func (s *sliceIter) Task() *Task {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.tasks) == 0 {
return nil
}
return s.tasks[s.index]
}
func (s *sliceIter) Close() {
s.mu.Lock()
defer s.mu.Unlock()
s.tasks = nil
}
使用方式:
// 1.
iter := NewIterTasks(makeTasks())
for iter.Next() {
task := iter.Task()
fmt.Println("Task ID ", task.id)
}
// 2.
iter = NewIterTasksWithCycle(makeTasks())
for iter.Next() {
task := iter.Task()
fmt.Println("Task ID ", task.id)
}
// 3.
iter = NewIterTasksWithCycleFilter(makeTasks(), func(task *Task) bool {
if task.id%2 == 0 {
return true
}
return false
})
for iter.FilterNext() {
task := iter.Task()
fmt.Println("Task ID ", task.id)
}
三,用 Channel 实现迭代器
实际上就是经典的生产者和消费者模型。一个或多个 Goroutines 负责生产数据,另外的一个或多个 Goroutines 负责处理数据。需要注意的是,生产的 Goroutines 退出的时候需要通知处理的 Goroutines,否则会产生 Leak。
以下是一个简单例子一个生产者和一个消费者。首先给 Task 加一个错误项,可以动态的获取错误。
type TaskWithError struct {
*Task
Err error
}
func runIterWithChan() error {
tC := make(chan TaskWithError)
exitC := make(chan struct{})
go func() {
defer close(tC)
for _, task := range makeTasks() {
wrapTask := TaskWithError{task, nil}
// processing...
// perhaps it has error...
// wrapTask.Err = errors.New(xxxx)
select {
case <-exitC:
return
default:
}
select {
case <-exitC:
return
case tC <- wrapTask:
}
}
}()
for task := range tC {
if task.Err != nil {
// 如果考虑退出,应增加Channel 通知Goroutine 退出
// 否则Goroutine Leak
exitC <- struct{}{}
return task.Err
} else {
fmt.Println("Task ID:", task.id)
}
}
return nil
}
参考链接: