- pytorch将cpu训练好的模型参数load到gpu上,或者gpu->cpu上">pytorch将cpu训练好的模型参数load到gpu上,或者gpu->cpu上
- pytorch 模型部分参数的加载,修改模型之后导入原有的部分参数">pytorch 模型部分参数的加载,修改模型之后导入原有的部分参数
pytorch将cpu训练好的模型参数load到gpu上,或者gpu->cpu上
gpu -> cpu
torch.load('modelparameters.pth', map_location=lambda storage, loc: storage)
pytorch 模型部分参数的加载,修改模型之后导入原有的部分参数
pretrained_dict=torch.load(model_weight)
model_dict=myNet.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
myNet.load_state_dict(model_dict)
