PyTorch代码重现
可关注文档中相关章节。
强制确定性操作
避免使用非确定性算法。
PyTorch 中,[torch.use_deterministic_algorithms()](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms)
可以强制使用确定性算法而不是非确定性算法,并且如果已知操作是非确定性的(并且没有确定性的替代方案),则会抛出错误。
设置随机数种子
def seed_torch(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
seed_torch()
参考自https://www.zdaiot.com/MLFrameworks/Pytorch/Pytorch随机种子/
PyTorch 1.9 版本前 DataLoader 中的隐藏 BUG
具体细节可见可能 95%的人还在犯的 PyTorch 错误 - serendipity 的文章 - 知乎
解决方法可参考文档:
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
numpy.random.seed(worker_seed)
random.seed(worker_seed)
DataLoader(..., worker_init_fn=seed_worker)