pytorch将cpu训练好的模型参数load到gpu上,或者gpu->cpu上

gpu -> cpu

  1. torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)

pytorch 模型部分参数的加载,修改模型之后导入原有的部分参数

pretrained_dict=torch.load(model_weight)
model_dict=myNet.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
myNet.load_state_dict(model_dict)