Original code:
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
for k, v in s.items():
    new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
This raises the error:
Error(s) in loading state_dict
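This error usually means the key names in the checkpoint do not match the ones the model expects. A minimal sketch for finding out which keys differ; `diff_keys` is a hypothetical helper, not part of PyTorch, and the key names in the example are made up for illustration:

```python
def diff_keys(model_keys, checkpoint_keys):
    """Hypothetical helper: report keys that exist on only one side."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    # Keys the model expects but the checkpoint does not provide.
    missing = sorted(model_keys - checkpoint_keys)
    # Keys the checkpoint provides but the model does not know.
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Illustrative key names only: a wrapped model vs. an unwrapped checkpoint.
missing, unexpected = diff_keys(
    ["module.features.0.weight"],
    ["features.0.weight"],
)
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, passing `model.state_dict().keys()` and `new_s.keys()` should show whether the `module.` prefix is the part that differs.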
Changed to:
checkpoint = _load(path)
s = checkpoint["state_dict"]
from collections import OrderedDict
new_s = OrderedDict()
for k, v in s.items():
    if 'module' not in k:
        k = 'module.' + k
    else:
        k = k.replace('features.module.', 'module.features.')
    new_s[k] = v
    # new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
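The same remapping can be written as a small reusable function. This is a sketch under stated assumptions, not the repository's own code: `remap_keys` and its `wrapped` flag are hypothetical names, plain integers stand in for tensors, and it assumes no parameter is itself legitimately named `module`.

```python
from collections import OrderedDict


def remap_keys(state_dict, wrapped=True):
    """Hypothetical helper: return a copy of state_dict whose keys fit a
    DataParallel-wrapped model (wrapped=True) or a plain one (wrapped=False)."""
    new_state = OrderedDict()
    for key, value in state_dict.items():
        # Drop every "module" segment to recover the plain parameter name.
        plain = ".".join(part for part in key.split(".") if part != "module")
        # A wrapped model expects a single leading "module." prefix.
        new_state["module." + plain if wrapped else plain] = value
    return new_state


# Plain integers stand in for tensors; the key names are illustrative.
ckpt = {"features.module.0.weight": 1, "classifier.bias": 2}
print(list(remap_keys(ckpt, wrapped=True)))
print(list(remap_keys(ckpt, wrapped=False)))
```

Passing `wrapped=True` mirrors the fix above (every key ends up with one leading `module.`), while `wrapped=False` gives the keys a model without `DataParallel` would expect.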
Reference: Error(s) in loading state_dict for DataParallel: · Issue #27 · bearpaw/pytorch-classification