Original code:

    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        # strip every 'module.' substring from the key
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
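The following small sketch shows what this loop does to the keys. The key names are hypothetical. They are modeled on the `features.module.` prefix that the fix below handles, and they assume the target model is wrapped entirely in `nn.DataParallel`, so it expects every key to start with `module.`.

```python
# Hypothetical checkpoint keys: `features` was wrapped in DataParallel during
# training, `classifier` was not (assumed layout).
ckpt_keys = ["features.module.0.weight", "features.module.0.bias", "classifier.1.weight"]

def strip_module(keys):
    """Same transformation as the original loop: drop every 'module.' substring."""
    return [k.replace("module.", "") for k in keys]

# Keys a model wrapped entirely in nn.DataParallel would expect (assumed).
expected = ["module.features.0.weight", "module.features.0.bias", "module.classifier.1.weight"]

stripped = strip_module(ckpt_keys)
print(stripped)
# ['features.0.weight', 'features.0.bias', 'classifier.1.weight']
print(set(stripped) & set(expected))
# set()
```

After stripping, none of the keys match what the wrapped model expects, so `load_state_dict` rejects the whole dictionary.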

Error:
    Error(s) in loading state_dict
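To see exactly which keys cause the error, compare the model's keys with the checkpoint's keys. In PyTorch, `model.load_state_dict(sd, strict=False)` returns an object with `missing_keys` and `unexpected_keys` fields. The pure-Python sketch below reproduces the same two categories on plain key lists. The helper name `diff_keys` and the example keys are hypothetical.

```python
def diff_keys(expected, provided):
    """Return (missing, unexpected) as sorted lists, mirroring the two
    categories PyTorch reports when a state_dict does not match a model."""
    expected, provided = set(expected), set(provided)
    missing = sorted(expected - provided)     # model wants these, checkpoint lacks them
    unexpected = sorted(provided - expected)  # checkpoint has these, model does not
    return missing, unexpected

# Hypothetical keys: model wrapped in DataParallel, checkpoint prefix stripped.
model_keys = ["module.features.0.weight", "module.classifier.1.weight"]
ckpt_keys = ["features.0.weight", "classifier.1.weight"]

missing, unexpected = diff_keys(model_keys, ckpt_keys)
print(missing)     # ['module.classifier.1.weight', 'module.features.0.weight']
print(unexpected)  # ['classifier.1.weight', 'features.0.weight']
```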

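Changed to the following. The model is wrapped in DataParallel, so its keys need the `module.` prefix rather than having it stripped: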

    from collections import OrderedDict

    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = OrderedDict()
    for k, v in s.items():
        if 'module' not in k:
            # key saved without a DataParallel wrapper: add the prefix
            k = 'module.' + k
        else:
            # only `features` was wrapped: 'features.module.' -> 'module.features.'
            k = k.replace('features.module.', 'module.features.')
        new_s[k] = v
        # new_s[k.replace('module.', '')] = v   (old approach, causes the error)
    model.load_state_dict(new_s)
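The same remapping can be isolated as a function and checked without loading a real checkpoint. This is a sketch only. The name `remap_for_dataparallel` is hypothetical, and plain integers stand in for the tensors a real `state_dict` would hold.

```python
from collections import OrderedDict

def remap_for_dataparallel(state_dict):
    """Remap checkpoint keys for a model wrapped in nn.DataParallel,
    using the same two rules as the fixed loop above."""
    new_sd = OrderedDict()
    for k, v in state_dict.items():
        if "module" not in k:
            # no DataParallel prefix anywhere: add one at the front
            k = "module." + k
        else:
            # `features` was wrapped on its own: move the prefix to the front
            k = k.replace("features.module.", "module.features.")
        new_sd[k] = v
    return new_sd

# Hypothetical checkpoint; ints stand in for tensors.
ckpt = OrderedDict([
    ("features.module.0.weight", 1),
    ("classifier.1.weight", 2),
])
print(list(remap_for_dataparallel(ckpt)))
# ['module.features.0.weight', 'module.classifier.1.weight']
```

`'module' not in k` is a substring test. A key that only happens to contain "module" somewhere in a layer name would skip the prefix. Checking `k.startswith('module.')` or the exact `features.module.` pattern would be stricter.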

Source: Error(s) in loading state_dict for DataParallel: · Issue #27 · bearpaw/pytorch-classification