PyTorch代码重现

可关注文档中相关章节

强制确定性操作

避免使用非确定性算法
PyTorch 中,[torch.use_deterministic_algorithms()](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms)可以强制使用确定性算法而不是非确定性算法,并且如果已知操作是非确定性的(并且没有确定性的替代方案),则会抛出错误。

设置随机数种子

  1. def seed_torch(seed=1029):
  2. random.seed(seed)
  3. os.environ['PYTHONHASHSEED'] = str(seed)
  4. np.random.seed(seed)
  5. torch.manual_seed(seed)
  6. torch.cuda.manual_seed(seed)
  7. torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
  8. torch.backends.cudnn.benchmark = False
  9. torch.backends.cudnn.deterministic = True
  10. seed_torch()

参考自https://www.zdaiot.com/MLFrameworks/Pytorch/Pytorch随机种子/

PyTorch 1.9 版本前 DataLoader 中的隐藏 BUG

具体细节可见可能 95%的人还在犯的 PyTorch 错误 - serendipity 的文章 - 知乎
解决方法可参考文档

  1. def seed_worker(worker_id):
  2. worker_seed = torch.initial_seed() % 2**32
  3. numpy.random.seed(worker_seed)
  4. random.seed(worker_seed)
  5. DataLoader(..., worker_init_fn=seed_worker)