通常 GPU 训练会用到 nn.DataParallel。如果保存的是被它包装后的模型的 state_dict,其中每个键都会带有 `module.` 前缀,所以转到 CPU 加载时不能仅仅使用 map_location='cpu',还需要去掉这个前缀。
样例:
# Load a checkpoint into a plain (non-DataParallel) model on CPU.
# Expects the checkpoint to be a dict with the weights under 'state_dict'.
if opt.pretrain_path:
    from collections import OrderedDict

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # map_location='cpu' remaps tensors that were saved from GPU onto the CPU.
    pretrain = torch.load(opt.pretrain_path, map_location='cpu')
    state_dict = pretrain['state_dict']

    # A model wrapped in nn.DataParallel saves its parameters under keys that
    # start with "module."; the unwrapped model expects the bare names.
    # Strip the prefix only when it is actually present, so checkpoints saved
    # without the wrapper (or with mixed keys) are not corrupted.
    prefix = 'module.'
    state_dict_new = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        state_dict_new[name] = v
    model.load_state_dict(state_dict_new)