The Code Call Path

First, a for loop is used to pull data from the DataLoader:

    for data, label in train_loader:
        ......
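
As a concrete anchor for the walkthrough, here is a runnable miniature of that loop (the dataset and loader below are illustrative, not from the original text):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    features, labels = torch.randn(8, 3), torch.arange(8)
    train_loader = DataLoader(TensorDataset(features, labels), batch_size=4)

    for data, label in train_loader:
        print(data.shape, label)  # torch.Size([4, 3]) tensor([0, 1, 2, 3]) ...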

The for loop invokes the DataLoader's __iter__(self) method to obtain an iterator with which to traverse the dataset:

    class DataLoader(Generic[T_co]):
        ...
        def __iter__(self) -> '_BaseDataLoaderIter':
            if self.persistent_workers and self.num_workers > 0:
                if self._iterator is None:
                    self._iterator = self._get_iterator()
                else:
                    self._iterator._reset(self)
                return self._iterator
            else:
                return self._get_iterator()

Inside __iter__(self), the DataLoader calls self._get_iterator(), which obtains an iterator based on num_workers and thereby decides whether loading happens in a single process or across multiple processes.

    class DataLoader(Generic[T_co]):
        ...
        def _get_iterator(self) -> '_BaseDataLoaderIter':
            if self.num_workers == 0:
                return _SingleProcessDataLoaderIter(self)
            else:
                self.check_worker_number_rationality()
                return _MultiProcessingDataLoaderIter(self)
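
A quick way to observe this dispatch (a sketch only: the iterator class names are internal and may differ across PyTorch versions):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(4).float())
    print(type(iter(DataLoader(ds, num_workers=0))).__name__)
    # _SingleProcessDataLoaderIter; with num_workers > 0 it would be
    # _MultiProcessingDataLoaderIter (and iter() would spawn worker processes)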

Single-Process DataLoader

For a clearer description, let's first consider the single-process code, i.e. the _SingleProcessDataLoaderIter(_BaseDataLoaderIter) class, whose parent class is _BaseDataLoaderIter(object):

    class _BaseDataLoaderIter(object):
        def __init__(self, loader: DataLoader) -> None:
            # Copy a number of DataLoader attributes
            # and validate the user's input
            self._dataset = loader.dataset
            self._dataset_kind = loader._dataset_kind
            self._index_sampler = loader._index_sampler
            ...

        def __iter__(self) -> '_BaseDataLoaderIter':
            return self

        def _reset(self, loader, first_iter=False):
            self._sampler_iter = iter(self._index_sampler)  # obtain the sampler iterator
            self._num_yielded = 0
            self._IterableDataset_len_called = loader._IterableDataset_len_called

        def _next_index(self):
            return next(self._sampler_iter)  # may raise StopIteration

        def _next_data(self):
            raise NotImplementedError

        def __next__(self) -> Any:
            with torch.autograd.profiler.record_function(self._profile_name):
                if self._sampler_iter is None:
                    self._reset()
                data = self._next_data()  # the key line: this is where the data is fetched
                self._num_yielded += 1
                ...
                return data

        next = __next__  # Python 2 compatibility

        def __len__(self) -> int:
            return len(self._index_sampler)  # len(_BaseDataLoaderIter) == len(self._index_sampler)

        def __getstate__(self):
            raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

_BaseDataLoaderIter is the parent class of every DataLoaderIter.
Once the DataLoader has obtained an iterator, the for loop calls next() on it to get the next object and thus traverse the data. __next__() in turn calls _next_data() to fetch the data; because the subclass _SingleProcessDataLoaderIter overrides that method, the subclass's version is the one that runs.
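
The for loop is just sugar over this iterator protocol; a minimal sketch of driving it by hand (illustrative dataset):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=4)
    it = iter(loader)    # DataLoader.__iter__ -> _SingleProcessDataLoaderIter
    batch = next(it)     # _BaseDataLoaderIter.__next__ -> _next_data()
    print(batch)         # [tensor([0., 1., 2., 3.])]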

    class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
        def __init__(self, loader):
            super(_SingleProcessDataLoaderIter, self).__init__(loader)
            assert self._timeout == 0
            assert self._num_workers == 0
            self._dataset_fetcher = _DatasetKind.create_fetcher(
                self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

        def _next_data(self):
            index = self._next_index()  # may raise StopIteration
            data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
            if self._pin_memory:
                data = _utils.pin_memory.pin_memory(data)
            return data

As its constructor shows, _SingleProcessDataLoaderIter builds on _BaseDataLoaderIter by defining _dataset_fetcher, passing in _dataset, _auto_collation, _collate_fn, and similar parameters that determine how data is fetched. Its concrete implementation is explained shortly.
When _next_data() is called, it calls _next_index() to obtain an index, then passes that index into _dataset_fetcher to fetch the corresponding samples.

    class DataLoader(Generic[T_co]):
        ...
        @property
        def _auto_collation(self):
            return self.batch_sampler is not None

        @property
        def _index_sampler(self):
            if self._auto_collation:
                return self.batch_sampler
            else:
                return self.sampler

    class _BaseDataLoaderIter(object):
        ...
        def _reset(self, loader, first_iter=False):
            self._sampler_iter = iter(self._index_sampler)  # obtain the _index_sampler iterator
            ...

        def _next_index(self):
            # sampler_iter comes from index_sampler
            return next(self._sampler_iter)  # may raise StopIteration

As the code above shows, the DataLoader supplies a sampler (either batch_sampler or some other Sampler subclass), and _SingleProcessDataLoaderIter's _next_data() method calls _next_index() to step through that sampler and obtain indices.
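
To see what _index_sampler yields in the batched case, here is a small check using the public BatchSampler and SequentialSampler classes:

    from torch.utils.data import BatchSampler, SequentialSampler

    batch_sampler = BatchSampler(SequentialSampler(range(10)),
                                 batch_size=4, drop_last=False)
    print(list(batch_sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]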

DatasetFetcher

Now let's look at the fetcher. A fetcher takes an index and returns the corresponding elements, and it supports both map-style datasets (via _MapDatasetFetcher) and iterable-style datasets (via _IterableDatasetFetcher), so that the DataLoader can use the same fetch interface for both, keeping the code simple.

  • Map-style: the index is passed in directly and used as the map's key to retrieve the corresponding sample (the value):

      class _MapDatasetFetcher(_BaseDatasetFetcher):
          def __init__(self, dataset, auto_collation, collate_fn, drop_last):
              super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

          def fetch(self, possibly_batched_index):
              if self.auto_collation:
                  # If a batch_sampler exists, _auto_collation is True and the
                  # batch_sampler takes precedence, so what the fetcher receives
                  # is a whole batch of indices
                  data = [self.dataset[idx] for idx in possibly_batched_index]
              else:
                  data = self.dataset[possibly_batched_index]
              return self.collate_fn(data)
  • Iterable-style: the __init__ method sets up the dataset's initial iterator, and the fetch method retrieves elements from it; at this point the index no longer serves much purpose:

      class _IterableDatasetFetcher(_BaseDatasetFetcher):
          def __init__(self, dataset, auto_collation, collate_fn, drop_last):
              super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
              self.dataset_iter = iter(dataset)

          def fetch(self, possibly_batched_index):
              if self.auto_collation:
                  # With a batch_sampler (i.e. auto_collation == True), simply
                  # walk the iterator and pull len(possibly_batched_index)
                  # samples (i.e. one batch of samples)
                  data = []
                  for _ in possibly_batched_index:
                      try:
                          data.append(next(self.dataset_iter))
                      except StopIteration:
                          break
                  if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                      raise StopIteration
              else:
                  # With a plain sampler, simply pull a single sample
                  data = next(self.dataset_iter)
              return self.collate_fn(data)

Finally, the index is passed into the fetcher, which fetches the desired samples.
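
The fetcher's last step is collate_fn, which turns the list of samples into a batch. A quick look at the default one (exported as torch.utils.data.default_collate in recent PyTorch versions; in older versions it lives in the private torch.utils.data._utils.collate module):

    import torch
    from torch.utils.data import default_collate

    samples = [(torch.tensor([1., 2.]), 0),
               (torch.tensor([3., 4.]), 1)]   # two (feature, label) samples
    feats, labels = default_collate(samples)
    print(feats)   # tensor([[1., 2.], [3., 4.]])  -- features stacked
    print(labels)  # tensor([0, 1])                -- labels stacked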

The full call chain is therefore:
loader.__iter__() —> self._get_iterator() —> class _SingleProcessDataLoaderIter —> class _BaseDataLoaderIter —> __next__() —> self._next_data() —> self._next_index() —> next(self._sampler_iter), i.e. next(iter(self._index_sampler)) —> obtain index —> self._dataset_fetcher.fetch(index) —> obtain data
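
The same chain can be traced by hand on the private attributes (a sketch only: names such as _index_sampler and _MapDatasetFetcher are internal and may change between PyTorch versions):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data._utils.fetch import _MapDatasetFetcher  # private API

    dataset = TensorDataset(torch.arange(10).float())
    loader = DataLoader(dataset, batch_size=4)

    index = next(iter(loader._index_sampler))   # the batch_sampler here: [0, 1, 2, 3]
    fetcher = _MapDatasetFetcher(dataset, True, loader.collate_fn, False)
    print(fetcher.fetch(index))                 # one collated batch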

Multi-Process DataLoader

For the multi-process case, borrowing the comment from the PyTorch source itself, the flow works as follows:

    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.

First, the DataLoader spawns worker processes via multiprocessing, and each worker's input and output travel through queues (instances of multiprocessing.Queue); a toy sketch of the pattern follows the list below. The queues involved are:

  • index_queue: one per worker, holding the indices of the tasks that worker should process
  • _worker_result_queue: receives each finished task back from the workers as an (idx, data) pair
  • data_queue: the queue of data after pin_memory processing (when pin_memory is off, this is simply an alias for _worker_result_queue)
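
To make the queue topology concrete, here is a toy sketch of the index_queue / result_queue pattern with a single worker; this is illustrative code, not PyTorch's actual implementation:

    import multiprocessing as mp

    def toy_worker(dataset, index_queue, result_queue):
        while True:
            task = index_queue.get()
            if task is None:                  # sentinel: shut down
                break
            idx, sample_indices = task
            batch = [dataset[i] for i in sample_indices]   # "fetch"
            result_queue.put((idx, batch))

    if __name__ == '__main__':
        dataset = list(range(100))
        index_queue, result_queue = mp.Queue(), mp.Queue()
        w = mp.Process(target=toy_worker, args=(dataset, index_queue, result_queue))
        w.start()
        index_queue.put((0, [0, 1, 2, 3]))    # send_idx 0 plus one batch of indices
        print(result_queue.get())             # (0, [0, 1, 2, 3])
        index_queue.put(None)
        w.join()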

In addition, several important bookkeeping fields coordinate the workers (a hypothetical snapshot follows the list):

  • _send_idx: the send index, recording the idx of the next batch to be placed into an index_queue
  • _rcvd_idx: the receive index, recording the idx of the next batch to be taken out of data_queue and returned
  • _task_info: a dict holding information about data yet to be yielded; the key is the task idx (an integer index starting from 0), and the value is (worker_id,) if the data has not been fetched yet, or (worker_id, data) if it has
  • _tasks_outstanding: an integer, the number of tasks/batches dispatched to workers whose results have not yet been taken out of data_queue (some may still be in preparation)
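
A hypothetical snapshot of these fields mid-epoch may help (the values below are invented for illustration):

    batch4 = ['sample_8', 'sample_9']  # pretend worker 0 already returned this

    send_idx = 5         # batches 0..4 have been put into the index queues
    rcvd_idx = 3         # batches 0..2 have already been handed to the user
    task_info = {
        3: (1,),         # batch 3: sent to worker 1, data not fetched yet
        4: (0, batch4),  # batch 4: worker 0 finished out of order; data cached
    }
    tasks_outstanding = sum(1 for v in task_info.values() if len(v) == 1)
    print(tasks_outstanding)  # 1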

Each worker produces one batch at a time, and before a batch is returned the indices for the next batch to process are enqueued. The corresponding constructor and per-worker initialization look like this:

    class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
        def __init__(self, loader):
            super(_MultiProcessingDataLoaderIter, self).__init__(loader)
            ...
            self._worker_result_queue = multiprocessing_context.Queue()  # workers put the data they fetch here; used for inter-process communication
            ...
            self._workers_done_event = multiprocessing_context.Event()
            self._index_queues = []
            self._workers = []
            for i in range(self._num_workers):
                index_queue = multiprocessing_context.Queue()  # index queue: one per worker, holding the indices it should process
                index_queue.cancel_join_thread()
                # _worker_loop takes indices from index_queue, processes the data
                # through collate_fn, and puts the finished batch into data_queue.
                # (The idx sent into the queue is self._send_idx.)
                w = multiprocessing_context.Process(
                    target=_utils.worker._worker_loop,  # the loop each worker runs; it mainly puts (idx, data) pairs into _worker_result_queue
                    args=(self._dataset_kind, self._dataset, index_queue,
                          self._worker_result_queue, self._workers_done_event,
                          self._auto_collation, self._collate_fn, self._drop_last,
                          self._base_seed + i, self._worker_init_fn, i, self._num_workers,
                          self._persistent_workers))
                w.daemon = True
                w.start()
                self._index_queues.append(index_queue)
                self._workers.append(w)
            if self._pin_memory:
                self._pin_memory_thread_done_event = threading.Event()
                self._data_queue = queue.Queue()  # holds the results after pin_memory has been applied to the fetched data
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self._worker_result_queue, self._data_queue,
                          torch.cuda.current_device(),
                          self._pin_memory_thread_done_event))
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                # Similar to workers (see comment above), we only register
                # pin_memory_thread once it is started.
                self._pin_memory_thread = pin_memory_thread
            else:
                self._data_queue = self._worker_result_queue
            ...
            self._reset(loader, first_iter=True)

        def _reset(self, loader, first_iter=False):
            super()._reset(loader, first_iter)
            self._send_idx = 0  # idx of the next task to be sent to workers
            self._rcvd_idx = 0  # idx of the next task to be returned in __next__
            # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
            # map: task idx => - (worker_id,)       if data isn't fetched (outstanding)
            #                  \ (worker_id, data)  if data is already fetched (out-of-order)
            self._task_info = {}
            # _tasks_outstanding counts the tasks/batches currently dispatched
            # (some may still be in preparation); it starts at 0, is incremented
            # in self._try_put_index(), and decremented in self._next_data()
            self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
            # this indicates status that a worker still has work to do *for this epoch*.
            self._workers_status = [True for i in range(self._num_workers)]
            # We resume the prefetching in case it was enabled
            if not first_iter:
                for idx in range(self._num_workers):
                    self._index_queues[idx].put(_utils.worker._ResumeIteration())
                resume_iteration_cnt = self._num_workers
                while resume_iteration_cnt > 0:
                    data = self._get_data()
                    if isinstance(data, _utils.worker._ResumeIteration):
                        resume_iteration_cnt -= 1
            ...
            # At initialization, 2 * num_workers (batch_idx, sampler_indices)
            # pairs are placed into the index queues
            for _ in range(self._prefetch_factor * self._num_workers):
                self._try_put_index()  # prefetch
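
That last prefetch loop is driven by self._prefetch_factor, which is exposed as the DataLoader's prefetch_factor constructor argument (available in recent PyTorch versions and only meaningful with num_workers > 0); a minimal sketch of setting it explicitly:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = DataLoader(TensorDataset(torch.randn(64, 3)), batch_size=8,
                        num_workers=2, prefetch_factor=2)  # 2 is the default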

When the DataLoader is initialized, each worker's index_queue is therefore pre-filled with the indices of two batches by default; the worker then takes the indices it should process out of its index_queue:

    def _try_put_index(self):
        # self._prefetch_factor defaults to 2
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return
        self._index_queues[worker_queue_idx].put((self._send_idx, index))  # enqueue the task idx together with the data indices
        self._task_info[self._send_idx] = (worker_queue_idx,)
        # _tasks_outstanding += 1: one more batch is being prepared
        self._tasks_outstanding += 1
        # _send_idx records how many index batches have been sent from
        # the sampler iterator to the index queues
        self._send_idx += 1
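
The worker-selection loop above relies on itertools.cycle plus Python's for/else idiom, where the else branch runs only if the loop finishes without a break; a standalone illustration:

    from itertools import cycle

    workers_status = [False, True, False]   # only worker 1 is still active
    worker_queue_idx_cycle = cycle(range(len(workers_status)))

    for _ in range(len(workers_status)):
        worker_queue_idx = next(worker_queue_idx_cycle)
        if workers_status[worker_queue_idx]:
            break                            # found an active worker
    else:
        worker_queue_idx = None              # none active: give up
    print(worker_queue_idx)  # 1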

Reading a batch then goes through _next_data(self), with _process_data(self, data) responsible for returning the data:

    def _next_data(self):
        while True:
            while self._rcvd_idx < self._send_idx:  # ensure there are dispatched batches that have not been returned yet
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration
            # Now `self._rcvd_idx` is the batch index we want to fetch
            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)
            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()  # calls self._try_get_data() to read from self._data_queue
            self._tasks_outstanding -= 1  # one fewer batch is outstanding
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue
            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)  # return the data

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()  # as above: enqueue the next indices and update the bookkeeping
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

In this way, the multi-process DataLoader loads data through the coordinated work of multiple workers.