代码调用过程
首先使用循环语句,从Dataloader获取data:
# Typical training loop: iterating train_loader triggers DataLoader.__iter__.
for data, label in train_loader:
......
for循环会调用dataloader的__iter__(self)
方法,以此获得迭代器来遍历dataset
class DataLoader(Generic[T_co]):
    ...
    def __iter__(self) -> '_BaseDataLoaderIter':
        """Return an iterator over the dataset.

        With persistent workers enabled, a single iterator (and its worker
        processes) is cached and reset between epochs instead of being
        rebuilt; otherwise a fresh iterator is created per epoch.
        """
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                # Reuse the cached iterator; just rewind its state.
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()
在 __iter__(self)
方法中,dataloader 调用了 self._get_iterator()
方法,根据 num_workers 获得迭代器,并指示是进行单进程还是多进程处理。
class DataLoader(Generic[T_co]):
    ...
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        """Build the concrete iterator: single-process when num_workers == 0,
        multi-process otherwise."""
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            # Warn if num_workers looks unreasonable for this machine.
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
单进程Dataloader
为了描述更加清晰,我们先考虑单进程的代码,也就是_SingleProcessDataLoaderIter(_BaseDataLoaderIter)
类,它的父类是class _BaseDataLoaderIter(object)
class _BaseDataLoaderIter(object):
    """Base class of all DataLoader iterators; drives the
    sampler -> index -> data pipeline via __next__."""

    def __init__(self, loader: DataLoader) -> None:
        # Copy a handful of DataLoader attributes and validate user input.
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._index_sampler = loader._index_sampler
        ...

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        # Fresh iterator over the (batch_)sampler at the start of an epoch.
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        # Subclasses implement the actual data retrieval.
        raise NotImplementedError

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()  # key line: fetch the next batch/sample
            self._num_yielded += 1
            ...
            return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler)  # len(_BaseDataLoaderIter) == len(self._index_sampler)

    def __getstate__(self):
        # Bug fix: the original passed the format string and the class name as
        # two separate arguments, so the "{}" placeholder was never filled in.
        raise NotImplementedError(
            "{} cannot be pickled".format(self.__class__.__name__))
_BaseDataLoaderIter
是所有DataLoaderIter的父类。
dataloader获得了迭代器之后,for 循环需要调用 next() 来获得下一个对象,从而实现遍历。__next__()
方法中则调用_next_data()
获取数据,这里因为子类_SingleProcessDataLoaderIter
重写了该方法,所以会调用子类的方法
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    """Iterator that fetches data in the main process (num_workers == 0)."""

    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        # Single-process iteration never times out and uses no workers.
        assert self._timeout == 0
        assert self._num_workers == 0
        # The fetcher encapsulates how samples are pulled from the dataset
        # (map-style vs. iterable-style) and how they are collated.
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
从_SingleProcessDataLoaderIter
的初始化参数可以看到,其在父类_BaseDataLoaderIter
的基础上定义了_dataset_fetcher
,并向其传入 _dataset
,_auto_collation
,_collate_fn
等参数,用于定义获取数据的方式。其具体实现会在稍后解释。
在_next_data()
被调用后,其会调用_next_index()
获取 index,并通过获得的 index 传入 _dataset_fetcher 中获取对应样本。
class DataLoader(Generic[T_co]):
    ...
    @property
    def _auto_collation(self):
        # True iff a batch_sampler is in use, i.e. the sampler yields whole
        # batches of indices rather than single indices.
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The sampler whose iterator drives _next_index(): the batch_sampler
        # when batching is automatic, otherwise the per-sample sampler.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler
class _BaseDataLoaderIter(object):
    ...
    def _reset(self, loader, first_iter=False):
        # Obtain a fresh iterator over _index_sampler.
        self._sampler_iter = iter(self._index_sampler)
        ...

    def _next_index(self):
        # _sampler_iter comes from _index_sampler.
        return next(self._sampler_iter)  # may raise StopIteration
从上面代码看出,dataloader 提供了 sampler(可以是batch_sampler
或者是其他sampler
子类),然后 _SingleProcessDataLoaderIter
在_next_data()
方法中调用_next_index()
方法迭代 sampler 获得索引。
DatasetFetcher
下面我们来看看 fetcher,fetcher 需要 index 来获取元素,并同时支持 Map-style dataset(对应 _MapDatasetFetcher)和 Iterable-style dataset(对应 _IterableDatasetFetcher),使其在 Dataloader 内能使用相同的接口 fetch,代码更加简洁。
对Map-Style:直接输入索引index,作为map的key,获得对应的样本(即value)
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Fetcher for map-style datasets: the index acts as the key, and
    dataset[index] returns the corresponding sample (the value)."""

    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # A batch_sampler is present (_auto_collation is True), so the
            # fetcher receives a whole batch of indices at once.
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
对Iterable-style:
__init__
方法内设置了dataset初始的迭代器,fetch方法内获取元素,此时index其实已经没有多大作用了class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
def fetch(self, possibly_batched_index):
if self.auto_collation:
# 对于batch_sampler(即auto_collation==True)
# 直接使用往后遍历并提取len(possibly_batched_index)个样本(即1个batch的样本)
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
raise StopIteration
else:
# 对于sampler,直接往后遍历并提取1个样本
data = next(self.dataset_iter)
return self.collate_fn(data)
最后,将索引index传入fetcher,fetch到想要的样本。
因此整个过程调用关系如下:
loader.__iter__() —> self._get_iterator() —> class _SingleProcessDataLoaderIter —> class _BaseDataLoaderIter —> __next__() —> self._next_data() —> self._next_index() —> next(self._sampler_iter) 即 next(iter(self._index_sampler)) —> 获得 index —> self._dataset_fetcher.fetch(index) —> 获得 data
多进程Dataloader
而对于多进程而言,借用 PyTorch 内源码的注释,其运行流程解释如下:
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
首先dataloader 基于 multiprocessing 产生多进程,每个子进程的输入输出通过两个主要的队列
(multiprocessing.Queue()
类)产生,分别为:
- index_queue:每个子进程的队列中需要处理的任务的下标
- _worker_result_queue:子进程处理完任务后,将结果以 (idx, data) 的形式放回的队列
- data_queue:表明经过 pin_memory 处理后的数据队列
并且有以下这些比较重要的 flag 参数来协调各个 worker 之间的工作:
- _send_idx: 发送索引,用来记录这次要放 index_queue 中 batch 的 idx
- _rcvd_idx: 接受索引,记录要从 data_queue 中取出的 batch 的 idx
- _task_info: 存储将要产生的 data 信息的 dict,key为 task idx(由 0 开始的整型索引),value 为 (worker_id,) 或 (worker_id, data),分别对应数据未取和已取的情况
- _tasks_outstanding: 整型,代表已经分发给 worker 但尚未在 __next__ 中返回的 task/batch 的数量(其中可能有些正在准备中)
每个 worker 一次产生一个 batch 的数据,返回 batch 数据前放入下一个批次要处理的数据下标,对应构造函数子进程初始化如下:
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    """Iterator that fetches data with num_workers worker processes."""

    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)
        ...
        # Queue into which every worker puts its fetched batches as
        # (idx, data) pairs; used for inter-process communication.
        self._worker_result_queue = multiprocessing_context.Queue()
        ...
        self._workers_done_event = multiprocessing_context.Event()
        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # One index queue per worker: holds the (batch) indices that
            # this worker should fetch next.
            index_queue = multiprocessing_context.Queue()
            index_queue.cancel_join_thread()
            # _worker_loop: take an index from index_queue, build the batch
            # via collate_fn, then put (idx, data) into _worker_result_queue
            # (the idx sent on the queue is self._send_idx).
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            # Holds batches after the pin_memory thread has pinned them.
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            # Without pin_memory the worker results are consumed directly.
            self._data_queue = self._worker_result_queue
        ...
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)       if data isn't fetched (outstanding)
        #                 \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        # Number of batches dispatched to workers but not yet yielded;
        # incremented in self._try_put_index(), decremented in self._next_data().
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # this indicates status that a worker still has work to do *for this epoch*.
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                data = self._get_data()
                if isinstance(data, _utils.worker._ResumeIteration):
                    resume_iteration_cnt -= 1
        ...
        # Prime the pipeline: enqueue _prefetch_factor * num_workers
        # (batch_idx, sampler_indices) pairs up front.
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()
dataloader初始化的时候,每个worker的index_queue会预先放入 _prefetch_factor(默认为 2)个 batch 的 index;worker 随后从 index_queue 中取出要处理的下标
def _try_put_index(self):
    """Dispatch the next batch of sampler indices to an active worker.

    Silently returns when the sampler is exhausted or when no worker is
    still active for this epoch.
    """
    # self._prefetch_factor defaults to 2; never exceed the prefetch budget.
    assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
    try:
        index = self._next_index()
    except StopIteration:
        return
    for _ in range(self._num_workers):  # find the next active worker, if any
        worker_queue_idx = next(self._worker_queue_idx_cycle)
        if self._workers_status[worker_queue_idx]:
            break
    else:
        # not found (i.e., didn't break)
        return
    # Enqueue (task idx, sample indices) for the chosen worker.
    self._index_queues[worker_queue_idx].put((self._send_idx, index))
    self._task_info[self._send_idx] = (worker_queue_idx,)
    # One more batch is now outstanding.
    self._tasks_outstanding += 1
    # _send_idx counts how many index batches have been dispatched so far.
    self._send_idx += 1
调用_next_data(self)
方法进行数据读取,其中_process_data(self, data)
用于返回数据。
def _next_data(self):
    """Return the batch with index self._rcvd_idx, waiting for workers and
    buffering out-of-order results as needed."""
    while True:
        # Skip over tasks whose worker died without producing data; stop at
        # the first task that either has data or whose worker is alive.
        while self._rcvd_idx < self._send_idx:
            info = self._task_info[self._rcvd_idx]
            worker_id = info[0]
            if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                break
            del self._task_info[self._rcvd_idx]
            self._rcvd_idx += 1
        else:
            # no valid `self._rcvd_idx` is found (i.e., didn't break)
            if not self._persistent_workers:
                self._shutdown_workers()
            raise StopIteration
        # Now `self._rcvd_idx` is the batch index we want to fetch
        # Check if the next sample has already been generated
        if len(self._task_info[self._rcvd_idx]) == 2:
            data = self._task_info.pop(self._rcvd_idx)[1]
            return self._process_data(data)
        assert not self._shutdown and self._tasks_outstanding > 0
        idx, data = self._get_data()  # pulls from self._data_queue via self._try_get_data()
        self._tasks_outstanding -= 1  # one fewer batch outstanding
        if self._dataset_kind == _DatasetKind.Iterable:
            # Check for _IterableDatasetStopIteration
            if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                if self._persistent_workers:
                    self._workers_status[data.worker_id] = False
                else:
                    self._mark_worker_as_unavailable(data.worker_id)
                self._try_put_index()
                continue
        if idx != self._rcvd_idx:
            # store out-of-order samples
            self._task_info[idx] += (data,)
        else:
            del self._task_info[idx]
            return self._process_data(data)  # hand the batch to the caller
def _process_data(self, data):
    """Advance bookkeeping, refill the index queues, then return (or
    re-raise) the batch produced by a worker."""
    self._rcvd_idx += 1
    self._try_put_index()  # as above: enqueue the next indices and update flags
    # Workers ship exceptions back wrapped; re-raise them in the main process.
    if isinstance(data, ExceptionWrapper):
        data.reraise()
    return data
这样,多进程模式的 dataloader 就能通过多个 worker 的协作来共同完成数据的加载。