This article is based on PyTorch 1.7.0 (https://github.com/pytorch/pytorch/tree/v1.7.0); see also https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

Introduction to DataLoader

DataLoader supports:

  • map-style and iterable-style datasets
  • single-process and multi-process data loading
  • customizable loading order, automatic batching, and memory pinning

DataLoader attributes

| Attribute | Type (optional = may be omitted) | Description |
| --- | --- | --- |
| dataset | Dataset | the dataset to load data from |
| batch_size | int, optional | number of samples per batch |
| shuffle | bool, optional | if True, the data is reshuffled at every epoch, so each epoch visits the data in a different order; default False |
| sampler | Sampler or Iterable, optional | defines the strategy for drawing samples from dataset; if sampler is set, shuffle must not be set |
| batch_sampler | Sampler or Iterable, optional | like sampler, but yields one batch of indices at a time; mutually exclusive with batch_size, shuffle, sampler, and drop_last |
| num_workers | int, optional | number of subprocesses used for data loading; the default 0 loads data in the main process, where loading and training alternate and cannot run in parallel |
| collate_fn | callable, optional | merges a list of samples into a mini-batch of Tensors |
| pin_memory | bool, optional | if True, uses page-locked (pinned) host memory, which speeds up transfers from host memory to GPU memory (see the example below); default False. See also "When to set pin_memory to true?" and https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/ |
| drop_last | bool, optional | if True, drops the last incomplete batch when the dataset size is not divisible by batch_size |
| timeout | numeric, optional | timeout for collecting a batch from the worker subprocesses |
| worker_init_fn | callable, optional | if set, called in each worker subprocess after seeding and before data loading, with the worker id as its argument |
| prefetch_factor | int, optional | number of samples prefetched in advance by each worker subprocess |
| persistent_workers | bool, optional | if True, the worker subprocesses are kept alive after the dataset has been consumed once |
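
As a brief illustration of why pin_memory helps: pinned (page-locked) host memory allows asynchronous host-to-GPU copies, so the transfer can overlap with computation. A minimal sketch (the dataset shapes are made up for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))
loader = DataLoader(dataset, batch_size=32, pin_memory=True)

for x, y in loader:
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)  # asynchronous copy from pinned memory
        y = y.cuda(non_blocking=True)
    # ... training step ...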

DataLoader methods

`__init__`

`__init__()` validates its arguments and initializes attributes.

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataloader.py#L154]
class DataLoader(Generic[T_co]):
    ...
    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: _collate_fn_t = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")  # type: ignore

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    # Cannot statically verify that dataset is Sized
                    # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]
        self._iterator = None
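
Putting the defaults together: for a map-style dataset with auto-batching, DataLoader wraps a SequentialSampler (or a RandomSampler when shuffle=True) in a BatchSampler. A small sketch of that equivalence (the toy dataset is made up for illustration):

import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

dataset = torch.arange(10)  # a toy map-style dataset

# auto-batching with the default samplers...
loader_a = DataLoader(dataset, batch_size=4, drop_last=True)
# ...is equivalent to passing the composed batch_sampler explicitly:
loader_b = DataLoader(dataset,
                      batch_sampler=BatchSampler(SequentialSampler(dataset), 4, True))

for a, b in zip(loader_a, loader_b):
    assert torch.equal(a, b)  # identical batches: [0..3] and [4..7]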

`__iter__`

How `__iter__` is invoked

dataloader = DataLoader(...)
dataiter = iter(dataloader)  # equivalent to calling __iter__
for data in dataloader:  # this form also calls __iter__
    do_something(data)
"""
for data in dataloader:
    do_something(data)
can be rewritten as
dataiter = iter(dataloader)
while True:
    try:
        data = next(dataiter)  # next(dataiter) is really dataiter.__next__()
    except StopIteration:
        break
    do_something(data)
So iterating over a DataLoader ultimately calls __next__.
"""

`__iter__` analysis

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataloader.py#L339]
class DataLoader(Generic[T_co]):
    ...
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)
    ...
    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()
  1. self.num_workers == 0

     Every call to `__iter__` returns a new _SingleProcessDataLoaderIter object.

  2. self.persistent_workers is False and self.num_workers > 0

     Every call to `__iter__` returns a new _MultiProcessingDataLoaderIter object.

  3. self.persistent_workers is True and self.num_workers > 0

     Data is loaded by worker subprocesses that are kept alive: the first call to `__iter__` returns a _MultiProcessingDataLoaderIter object, and subsequent calls only reset that same object, so the workers can be reused, as the sketch below shows.
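
    A minimal sketch of the difference (the toy dataset is made up; with persistent_workers=True, the second iter() call hands back the same, reset iterator object):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if __name__ == '__main__':
        dataset = TensorDataset(torch.arange(8).float())

        plain = DataLoader(dataset, batch_size=4, num_workers=2)
        persistent = DataLoader(dataset, batch_size=4, num_workers=2,
                                persistent_workers=True)

        it1, it2 = iter(plain), iter(plain)
        print(it1 is it2)  # False: a fresh _MultiProcessingDataLoaderIter each time

        it3 = iter(persistent)
        list(it3)          # consume one epoch
        it4 = iter(persistent)
        print(it3 is it4)  # True: the same iterator object is reset and reused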

    _BaseDataLoaderIter

    _BaseDataLoaderIter is the base class; both _SingleProcessDataLoaderIter and
    _MultiProcessingDataLoaderIter inherit from it.

    [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataloader.py#L400]
    class _BaseDataLoaderIter(object):
     def __init__(self, loader: DataLoader) -> None:
         self._dataset = loader.dataset
         self._dataset_kind = loader._dataset_kind
         self._IterableDataset_len_called = loader._IterableDataset_len_called
         self._auto_collation = loader._auto_collation
         self._drop_last = loader.drop_last
         self._index_sampler = loader._index_sampler
         self._num_workers = loader.num_workers
         self._prefetch_factor = loader.prefetch_factor
         self._pin_memory = loader.pin_memory and torch.cuda.is_available()
         self._timeout = loader.timeout
         self._collate_fn = loader.collate_fn
         self._sampler_iter = iter(self._index_sampler)
         self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
         self._persistent_workers = loader.persistent_workers
         self._num_yielded = 0
    
     def __iter__(self) -> '_BaseDataLoaderIter':
         return self
    
     def _reset(self, loader, first_iter=False):
         self._sampler_iter = iter(self._index_sampler)
         self._num_yielded = 0
         self._IterableDataset_len_called = loader._IterableDataset_len_called
    
     def _next_index(self):
         return next(self._sampler_iter)  # may raise StopIteration
    
     def _next_data(self):
         raise NotImplementedError
    
     def __next__(self) -> Any:
         if self._sampler_iter is None:
             self._reset()
         data = self._next_data()
         self._num_yielded += 1
         if self._dataset_kind == _DatasetKind.Iterable and \
                 self._IterableDataset_len_called is not None and \
                 self._num_yielded > self._IterableDataset_len_called:
             warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                         "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                               self._num_yielded)
             if self._num_workers > 0:
                 warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                              "IterableDataset replica at each worker. Please see "
                              "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
             warnings.warn(warn_msg)
         return data
    
     next = __next__  # Python 2 compatibility
    
     def __len__(self) -> int:
         return len(self._index_sampler)
    
     def __getstate__(self):
         raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
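
    Note that `__len__` above delegates to the index sampler, so for an auto-batched map-style dataset len(dataloader) counts batches rather than samples. For example:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.zeros(10))
    print(len(DataLoader(dataset, batch_size=4)))                  # 3, i.e. ceil(10 / 4)
    print(len(DataLoader(dataset, batch_size=4, drop_last=True)))  # 2, i.e. floor(10 / 4)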
    

    _SingleProcessDataLoaderIter

    _SingleProcessDataLoaderIter must override `_next_data`, which is called by `__next__`.

    [https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataloader.py#L464]
    class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
     def __init__(self, loader):
         super(_SingleProcessDataLoaderIter, self).__init__(loader)
         assert self._timeout == 0
         assert self._num_workers == 0
    
         self._dataset_fetcher = _DatasetKind.create_fetcher(
             self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
    
     def _next_data(self):
         index = self._next_index()  # may raise StopIteration
         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
         if self._pin_memory:
             data = _utils.pin_memory.pin_memory(data)
         return data
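
    The fetcher hides the difference between dataset kinds. For a map-style dataset with auto-batching, _MapDatasetFetcher.fetch (torch/utils/data/_utils/fetch.py) essentially indexes the dataset once per sampled index and collates the result. A simplified paraphrase (not the verbatim source):

    class _MapDatasetFetcher:  # simplified; the real class also carries drop_last
        def __init__(self, dataset, auto_collation, collate_fn):
            self.dataset = dataset
            self.auto_collation = auto_collation
            self.collate_fn = collate_fn

        def fetch(self, possibly_batched_index):
            if self.auto_collation:
                # a list of indices produced by the BatchSampler
                data = [self.dataset[idx] for idx in possibly_batched_index]
            else:
                data = self.dataset[possibly_batched_index]
            return self.collate_fn(data)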
    

    _MultiProcessingDataLoaderIter

    _MultiProcessingDataLoaderIter must likewise override `_next_data`, which is called by `__next__`.

    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
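
    The same producer/consumer topology can be sketched with plain multiprocessing primitives (a toy illustration of the queue layout above, not PyTorch's actual worker code):

    import multiprocessing as mp

    def worker(index_queue, result_queue):
        # take an index task, "fetch" the data, send it back
        while True:
            task = index_queue.get()
            if task is None:  # final signal, mirroring DataLoader's shutdown
                break
            idx, index = task
            result_queue.put((idx, index * 10))  # stand-in for fetcher.fetch(index)

    if __name__ == '__main__':
        index_queue, result_queue = mp.Queue(), mp.Queue()
        w = mp.Process(target=worker, args=(index_queue, result_queue), daemon=True)
        w.start()
        for send_idx, index in enumerate([3, 1, 2]):
            index_queue.put((send_idx, index))   # main process -> {index_queue}
        for _ in range(3):
            print(result_queue.get())            # {worker_result_queue} -> main process
        index_queue.put(None)
        w.join()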

Because the multiprocessing data-loading logic is relatively involved, we first walk through `__init__`, then `_next_data`.

`__init__`

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataloader.py#L481]
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):   
    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
         # cycle([0, 1, 2]) -> 0 1 2 0 1 2 ...
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore
            # index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            # NB: Process.start() actually take some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # type: ignore
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        # Note that this indicates that a worker still has work to do *for this epoch*.
        # It does not mean that a worker is dead. In case of `_persistent_workers`, 
        # the worker will be reset to available in the next epoch.
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                data = self._get_data()
                if isinstance(data, _utils.worker._ResumeIteration):
                    resume_iteration_cnt -= 1
        # prime the prefetch loop
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()
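
The priming loop enqueues prefetch_factor index batches per worker, and _try_put_index (shown below) asserts that _tasks_outstanding never exceeds this bound. With toy numbers:

# With num_workers=4 and the default prefetch_factor=2, _reset primes
# 2 * 4 = 8 index batches, so up to 8 batches are in flight before the
# first __next__; each _process_data call then tops the pipeline back up.
num_workers, prefetch_factor = 4, 2
print(prefetch_factor * num_workers)  # 8 == the cap on _tasks_outstanding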

_worker_loop

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/_utils/worker.py]
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
                 num_workers, persistent_workers):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        # When using Iterable mode, some worker can exit earlier than others due
        # to the IterableDataset behaving differently for different workers.
        # When such things happen, an `_IterableDatasetStopIteration` object is
        # sent over to the main process with the ID of this worker, so that the
        # main process won't send more tasks to this worker, and will send
        # `None` to this worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing step and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put(r)
                iteration_end = False
                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r  # idx is the _send_idx of the task; index is a batch of dataset indices
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)  # fetch the data
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        #   (1) to save future `next(...)` calls, and
                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(worker_id))
            data_queue.put((idx, data))  # put the fetched data (or exception) on data_queue
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
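
The seeding and WorkerInfo setup at the top of _worker_loop is what torch.utils.data.get_worker_info() returns inside a worker. A sketch of using it to shard an IterableDataset across workers, following the pattern from the official docs (the dataset and range are contrived):

import math
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class RangeDataset(IterableDataset):
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        info = get_worker_info()  # None when loading in the main process
        if info is None:
            return iter(range(self.start, self.end))
        # shard the range so each worker yields a disjoint slice
        per_worker = int(math.ceil((self.end - self.start) / info.num_workers))
        start = self.start + info.id * per_worker
        return iter(range(start, min(start + per_worker, self.end)))

if __name__ == '__main__':
    loader = DataLoader(RangeDataset(0, 8), batch_size=4, num_workers=2)
    for batch in loader:
        print(batch)  # each element is fetched exactly once across the workers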

_next_data

As noted above, each iteration step calls `__next__`, which in turn calls `_next_data()`; we now analyze `_next_data`.

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataloader.py#L1038]
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)
    def _try_put_index(self):

        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return
        # put (self._send_idx, index) on the queue belonging to worker_queue_idx
        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data
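
The _rcvd_idx/_send_idx/_task_info bookkeeping above guarantees that batches are yielded in sampler order even when workers finish out of order. A toy simulation of that reordering (not PyTorch code):

# results arrive tagged with their send index, possibly out of order
arrivals = [(1, 'batch-1'), (0, 'batch-0'), (2, 'batch-2')]

task_info = {}   # mirrors _task_info: idx -> buffered data
rcvd_idx = 0     # mirrors _rcvd_idx: next idx to yield

for idx, data in arrivals:
    task_info[idx] = data
    while rcvd_idx in task_info:        # yield in order as soon as possible
        print(task_info.pop(rcvd_idx))  # batch-0, batch-1, batch-2
        rcvd_idx += 1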

_get_data

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataloader.py#L1005]
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def _get_data(self):
        # Fetches data from `self._data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
        # died at timeouts.
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data
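
With timeout > 0, _get_data makes a single bounded attempt and raises if no batch arrives in time. A contrived example (the slow dataset exists only to trigger the timeout):

import time
import torch
from torch.utils.data import DataLoader, Dataset

class SlowDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        time.sleep(10)  # deliberately slower than the timeout below
        return torch.tensor(idx)

if __name__ == '__main__':
    loader = DataLoader(SlowDataset(), num_workers=1, timeout=2)
    try:
        next(iter(loader))
    except RuntimeError as e:
        print(e)  # DataLoader timed out after 2 seconds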

_try_get_data

[https://github.com/pytorch/pytorch/blob/v1.7.0/torch/utils/data/dataloader.py#L859]
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died expectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._mark_worker_as_unavailable(worker_id)
            if len(failed_workers) > 0:
                pids_str = ', '.join(str(w.pid) for w in failed_workers)
                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
            if isinstance(e, queue.Empty):
                return (False, None)
            import tempfile
            import errno
            try:
                # Raise an exception if we are this close to the FDs limit.
                # Apparently, trying to open only one file is not a sufficient
                # test.
                # See NOTE [ DataLoader on Linux and open files limit ]
                fds_limit_margin = 10
                fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
            except OSError as e:
                if e.errno == errno.EMFILE:
                    raise RuntimeError(
                        "Too many open files. Communication with the"
                        " workers is no longer possible. Please increase the"
                        " limit using `ulimit -n` in the shell or change the"
                        " sharing strategy by calling"
                        " `torch.multiprocessing.set_sharing_strategy('file_system')`"
                        " at the beginning of your code") from None
            raise
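
The open-files probe above is why the error message suggests either raising the shell limit with ulimit -n or switching the tensor sharing strategy. The latter is a one-line change at the top of your script:

import torch.multiprocessing

# Share tensors through the file system instead of passing file descriptors,
# which avoids exhausting the per-process FD limit with many workers.
torch.multiprocessing.set_sharing_strategy('file_system')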