Code Call Flow
First, a for loop is used to fetch data from the DataLoader:
```python
for data, label in train_loader:
    ...
```
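For reference, here is a minimal runnable setup for such a loop (the dataset and hyperparameters below are illustrative, not from the original text):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# A toy dataset of 100 samples with 3 features and binary labels
train_set = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

for data, label in train_loader:
    print(data.shape, label.shape)  # torch.Size([16, 3]) torch.Size([16])
    break
```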
The for loop invokes the DataLoader's __iter__(self) method to obtain an iterator with which to traverse the dataset:
```python
class DataLoader(Generic[T_co]):
    ...
    def __iter__(self) -> '_BaseDataLoaderIter':
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()
```
Inside __iter__(self), the DataLoader calls self._get_iterator(), which selects the iterator according to num_workers and thereby decides between single-process and multi-process loading:
```python
class DataLoader(Generic[T_co]):
    ...
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
```
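This dispatch is easy to observe directly (a small sketch; the iterator class names are private implementation details, stable in recent PyTorch versions but not a public API):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":  # needed on platforms that spawn worker processes
    ds = TensorDataset(torch.arange(8.0))
    print(type(iter(DataLoader(ds, num_workers=0))).__name__)  # _SingleProcessDataLoaderIter
    print(type(iter(DataLoader(ds, num_workers=2))).__name__)  # _MultiProcessingDataLoaderIter
```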
Single-Process DataLoader
To keep the description clear, we first look at the single-process code path, i.e. the _SingleProcessDataLoaderIter(_BaseDataLoaderIter) class, whose parent class is _BaseDataLoaderIter(object):
```python
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        # Copy over a number of DataLoader attributes
        # and validate the user's input
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._index_sampler = loader._index_sampler
        ...

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)  # obtain an iterator over the sampler
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()  # the key line: this is where the data is fetched
            self._num_yielded += 1
            ...
            return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler)  # len(_BaseDataLoaderIter) == len(self._index_sampler)

    def __getstate__(self):
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
```
_BaseDataLoaderIter is the parent class of all DataLoaderIter variants.
Once the DataLoader has obtained the iterator, the for loop calls next() on it to get the next object and thereby iterate. __next__() in turn calls _next_data() to fetch the data; because the subclass _SingleProcessDataLoaderIter overrides that method, the subclass implementation is the one that runs.
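In other words, the for loop desugars to roughly the following (a sketch; train_loader is any DataLoader that yields (data, label) pairs):

```python
it = iter(train_loader)          # DataLoader.__iter__ -> _SingleProcessDataLoaderIter
while True:
    try:
        data, label = next(it)   # _BaseDataLoaderIter.__next__ -> self._next_data()
    except StopIteration:
        break
```

The _SingleProcessDataLoaderIter subclass itself: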
```python
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation,
            self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
```
From _SingleProcessDataLoaderIter's constructor we can see that, on top of the parent class _BaseDataLoaderIter, it defines a _dataset_fetcher and passes it _dataset, _auto_collation, _collate_fn and other parameters that determine how data is fetched. Its concrete implementation is explained shortly.
When _next_data() is called, it first calls _next_index() to obtain an index, then passes that index into _dataset_fetcher to retrieve the corresponding samples:
```python
class DataLoader(Generic[T_co]):
    ...
    @property
    def _auto_collation(self):
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler


class _BaseDataLoaderIter(object):
    ...
    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)  # obtain an iterator over _index_sampler
        ...

    def _next_index(self):
        # _sampler_iter comes from _index_sampler
        return next(self._sampler_iter)  # may raise StopIteration
```
As the code above shows, the DataLoader provides the sampler (which may be a batch_sampler or some other Sampler subclass), and _SingleProcessDataLoaderIter's _next_data() calls _next_index() to iterate the sampler and obtain indices.
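To make the indices concrete, here is what _index_sampler yields in the batched case (a small sketch using the public sampler classes):

```python
from torch.utils.data import BatchSampler, SequentialSampler

batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
print(list(batch_sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
# When batch_size is set, DataLoader builds such a batch_sampler internally,
# so _auto_collation is True and _next_index() returns a whole list of indices.
```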
DatasetFetcher
Now let's look at the fetcher. A fetcher takes an index and retrieves the corresponding elements, and there is one for Map-style datasets (_MapDatasetFetcher) and one for Iterable-style datasets (_IterableDatasetFetcher), so that the DataLoader can use the same fetch interface for both, keeping the code simple.
For Map-style datasets: the input index is used directly as the map's key, yielding the corresponding sample (the value):
```python
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # If a batch_sampler is set, _auto_collation is True and the
            # batch_sampler takes priority, so the fetcher receives a whole
            # batch of indices at once
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
```
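After indexing, fetch hands the list of samples to collate_fn, which by default stacks them into batch tensors. A small sketch of that last step (in recent PyTorch versions, 1.11+, the default is exposed as torch.utils.data.default_collate):

```python
import torch
from torch.utils.data import default_collate

samples = [(torch.tensor([i, i]), i % 2) for i in range(4)]  # four (data, label) pairs
data, label = default_collate(samples)
print(data.shape)  # torch.Size([4, 2])
print(label)       # tensor([0, 1, 0, 1])
```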
For Iterable-style datasets:
__init__ sets up the dataset's initial iterator, and fetch pulls the elements from it; at this point the index itself no longer plays much of a role (only its length does):

```python
class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # With a batch_sampler (i.e. auto_collation == True), simply walk
            # the iterator forward and pull len(possibly_batched_index)
            # samples (i.e. one batch)
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            # With a plain sampler, walk the iterator forward and pull one sample
            data = next(self.dataset_iter)
        return self.collate_fn(data)
```
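The effect is easy to observe with a tiny IterableDataset (a minimal sketch; the class name Stream is made up for illustration):

```python
import torch
from torch.utils.data import IterableDataset, DataLoader

class Stream(IterableDataset):
    def __iter__(self):
        return iter(range(10))

for batch in DataLoader(Stream(), batch_size=4):  # drop_last defaults to False
    print(batch)
# tensor([0, 1, 2, 3])
# tensor([4, 5, 6, 7])
# tensor([8, 9])        <- the final, shorter batch; dropped if drop_last=True
```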
Finally, the index is passed into the fetcher, and fetch returns the desired samples.
The full call chain is therefore:
loader.__iter__() —> self._get_iterator() —> class _SingleProcessDataLoaderIter —> class _BaseDataLoaderIter —> __next__() —> self._next_data() —> self._next_index() —> next(self._sampler_iter), i.e. next(iter(self._index_sampler)) —> obtain index —> self._dataset_fetcher.fetch(index) —> obtain data
Multi-Process DataLoader
For the multi-process case, borrowing the comment from the PyTorch source, the data flow works as follows:
```python
# Our data model looks like this (queues are indicated with curly brackets):
#
#                main process                              ||
#                     |                                    ||
#               {index_queue}                              ||
#                     |                                    ||
#              worker processes                            ||     DATA
#                     |                                    ||
#            {worker_result_queue}                         ||     FLOW
#                     |                                    ||
#      pin_memory_thread of main process                   ||   DIRECTION
#                     |                                    ||
#               {data_queue}                               ||
#                     |                                    ||
#                data output                               \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
#   `pin_memory=False`.
```
The DataLoader first spawns worker subprocesses via multiprocessing. Each subprocess's input and output pass through queues (multiprocessing.Queue instances); the main ones are listed below (a toy sketch of this producer/consumer pattern follows the list):
- index_queue: one queue per worker subprocess, holding the indices of the tasks that worker should process
- _worker_result_queue: holds the (idx, data) results of tasks the workers have finished
- data_queue: the queue of data after pin_memory processing; when pin_memory=False it is simply _worker_result_queue
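To illustrate the pattern (this is not PyTorch's actual _worker_loop; all names here are made up for the sketch):

```python
import multiprocessing as mp

def worker_loop(dataset, index_queue, result_queue):
    # Toy stand-in for _utils.worker._worker_loop: take (batch_idx, indices)
    # tasks from index_queue and put (batch_idx, data) results back.
    while True:
        task = index_queue.get()
        if task is None:  # sentinel: shut down
            break
        batch_idx, sample_indices = task
        result_queue.put((batch_idx, [dataset[i] for i in sample_indices]))

if __name__ == "__main__":
    dataset = [x * x for x in range(10)]
    index_queue, result_queue = mp.Queue(), mp.Queue()
    w = mp.Process(target=worker_loop, args=(dataset, index_queue, result_queue))
    w.start()
    index_queue.put((0, [0, 1, 2]))  # what _try_put_index does
    index_queue.put((1, [3, 4, 5]))
    print(result_queue.get())        # (0, [0, 1, 4])
    print(result_queue.get())        # (1, [9, 16, 25])
    index_queue.put(None)
    w.join()
```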
On top of the queues, a few important flag variables coordinate the workers (a worked example follows the list):
- _send_idx: the send index; records the idx of the batch about to be put into index_queue
- _rcvd_idx: the receive index; records the idx of the batch to be taken out of data_queue next
- _task_info: a dict storing information about data yet to be yielded; its key is the task idx (an integer index starting from 0), and its value is (worker_id,) while the data has not been fetched, or (worker_id, data) once it has
- _tasks_outstanding: an integer; the number of tasks/batches that have been dispatched to workers but not yet consumed (some may still be in preparation)
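As a concrete worked example (assuming the defaults prefetch_factor=2 and num_workers=2): right after initialization, _try_put_index() has run 2 * 2 = 4 times, so _send_idx == 4, _rcvd_idx == 0, _tasks_outstanding == 4, and _task_info holds four (worker_id,) entries for tasks 0-3, assigned round-robin to workers 0, 1, 0, 1.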
Each worker produces one batch at a time, and before a batch is returned, the indices of the next batch to process are placed into the corresponding index_queue. The subprocess setup in the constructor looks like this:
```python
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)
        ...
        self._worker_result_queue = multiprocessing_context.Queue()  # workers put their fetched data here; used for inter-process communication
        ...
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()  # index queue: one per subprocess, holds the indices it should process
            index_queue.cancel_join_thread()
            # _worker_loop takes indices from index_queue, processes the data
            # with collate_fn, and puts the finished batch into
            # _worker_result_queue (the idx sent into the queue is self._send_idx)
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,  # the loop each worker subprocess runs; passes data as (idx, data) pairs into _worker_result_queue
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()  # holds the results after pin_memory has been applied to the fetched data
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue
        ...
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers; records the idx of the batch about to be put into index_queue
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__; records the idx of the batch to be taken out of data_queue
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        # _tasks_outstanding is the number of dispatched tasks/batches not yet
        # consumed (some may still be in preparation); it starts at 0, is
        # incremented in self._try_put_index() and decremented in self._next_data()
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # this indicates status that a worker still has work to do *for this epoch*.
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                data = self._get_data()
                if isinstance(data, _utils.worker._ResumeIteration):
                    resume_iteration_cnt -= 1
        ...
        # prime the pipeline: put prefetch_factor * num_workers
        # (batch_idx, sampler_indices) pairs into the index queues
        # (2 * num_workers by default)
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()  # prefetch
```
When the DataLoader iterator is initialized, each worker's index_queue is pre-filled with two batches' worth of indices by default; _try_put_index() is what enqueues them, and the workers take the indices out to process:
```python
def _try_put_index(self):
    # self._prefetch_factor defaults to 2
    assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

    try:
        index = self._next_index()
    except StopIteration:
        return
    for _ in range(self._num_workers):  # find the next active worker, if any
        worker_queue_idx = next(self._worker_queue_idx_cycle)
        if self._workers_status[worker_queue_idx]:
            break
    else:
        # not found (i.e., didn't break)
        return

    self._index_queues[worker_queue_idx].put((self._send_idx, index))  # enqueue the task idx together with the sample indices
    self._task_info[self._send_idx] = (worker_queue_idx,)
    # one more dispatched batch is now outstanding
    self._tasks_outstanding += 1
    # _send_idx records how many index batches have been sent from the
    # sampler iterator to the index queues
    self._send_idx += 1
```
_next_data(self) is then called to read the data, with _process_data(self, data) used to return it:
```python
def _next_data(self):
    while True:
        # Find the next valid task: the idx of the batch to return next
        # (_rcvd_idx) must still be behind the dispatch index (_send_idx)
        while self._rcvd_idx < self._send_idx:
            info = self._task_info[self._rcvd_idx]
            worker_id = info[0]
            if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                break
            del self._task_info[self._rcvd_idx]
            self._rcvd_idx += 1
        else:
            # no valid `self._rcvd_idx` is found (i.e., didn't break)
            if not self._persistent_workers:
                self._shutdown_workers()
            raise StopIteration

        # Now `self._rcvd_idx` is the batch index we want to fetch

        # Check if the next sample has already been generated
        if len(self._task_info[self._rcvd_idx]) == 2:
            data = self._task_info.pop(self._rcvd_idx)[1]
            return self._process_data(data)

        assert not self._shutdown and self._tasks_outstanding > 0
        idx, data = self._get_data()  # calls self._try_get_data() to read from self._data_queue
        self._tasks_outstanding -= 1  # one fewer batch is outstanding
        if self._dataset_kind == _DatasetKind.Iterable:
            # Check for _IterableDatasetStopIteration
            if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                if self._persistent_workers:
                    self._workers_status[data.worker_id] = False
                else:
                    self._mark_worker_as_unavailable(data.worker_id)
                self._try_put_index()
                continue

        if idx != self._rcvd_idx:
            # store out-of-order samples
            self._task_info[idx] += (data,)
        else:
            del self._task_info[idx]
            return self._process_data(data)  # return the data

def _process_data(self, data):
    self._rcvd_idx += 1
    self._try_put_index()  # as above: enqueue the next indices and update the flags
    if isinstance(data, ExceptionWrapper):
        data.reraise()
    return data
```
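For example, if worker 1 finishes batch 3 before worker 0 finishes batch 2, the (3, data) pair read from _data_queue does not match _rcvd_idx == 2, so it is parked in _task_info[3] as (worker_id, data); the loop keeps waiting on _get_data() until batch 2 arrives, and on the following call batch 3 is served directly out of _task_info.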
In this way, the multi-process DataLoader loads data through the cooperation of multiple workers.
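Putting it together, a typical multi-process setup looks like this (a sketch; the dataset shapes and hyperparameters are arbitrary):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":  # required when workers are spawned (e.g. Windows/macOS)
    train_set = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True,
                              num_workers=2,                        # multi-process path
                              pin_memory=torch.cuda.is_available(),
                              prefetch_factor=2,                    # 2 batches pre-queued per worker
                              persistent_workers=True)              # keep workers alive across epochs
    for epoch in range(2):
        for data, label in train_loader:
            pass  # data.is_pinned() is True when pin_memory is enabled
```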
