开始的时候,我使用 pytorch 中预训练好的 Inception_v3 作为初始化,训练草图分类的模型。训练完后,导入训练好的模型进行特征提取,然后发现这时候计算的准确率居然是一片混乱,根本就不对。
检查一番之后,模型导入没有什么问题。
终于发现罪魁祸首,居然是 transform_input 这个参数
在创建模型并进行训练的时候,我是这样写的:
Model = inception_v3# 1、使用 ImageNet pretrained 的模型model = Model(pretrained=False, transform_input=True) # use the pretrained model# 2、然后修改最后的一个全连接层就 OK 了model.AuxLogits = InceptionAux(768, number_class)model.fc = nn.Linear(in_features=2048, out_features=number_class, bias=True)
在进行特征提取的时候,我是这样写的:
Model = inception_v3model = Model(num_classes=250) # use the pretrained modelmodel.load_state_dict(t.load(ckpt_path))
后来发现: 如果使用 pretrained=True,则会默认设置 transform_input=True.
if self.transform_input:x = x.clone()x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
这个时候,forward函数一开始就对数据进行处理了。
所以尽管我导入的ckpt之后两个模型的参数都是完全一样的,但是由于这个参数,forward函数中一开始就对输入进行了处理,导致我最后得到的结果是不准确的。
所以正确的做法应该是
Model = inception_v3model = Model(num_classes=250, transform_input=True) # use the pretrained modelmodel.load_state_dict(t.load(ckpt_path))
关于其他模型,检查了一下:
resnet、vgg 都没有这样的问题。
