开始的时候,我使用 pytorch 中预训练好的 Inception_v3 作为初始化,训练草图分类的模型。训练完后,导入训练好的模型进行特征提取,然后发现这时候计算的准确率居然是一片混乱,根本就不对。

检查一番之后,模型导入没有什么问题。

终于发现罪魁祸首,居然是 transform_input 这个参数

在创建模型并进行训练的时候,我是这样写的:

  1. Model = inception_v3
  2. # 1、使用 ImageNet pretrained 的模型
  3. model = Model(pretrained=False, transform_input=True) # use the pretrained model
  4. # 2、然后修改最后的一个全连接层就 OK 了
  5. model.AuxLogits = InceptionAux(768, number_class)
  6. model.fc = nn.Linear(in_features=2048, out_features=number_class, bias=True)

在进行特征提取的时候,我是这样写的:

  1. Model = inception_v3
  2. model = Model(num_classes=250) # use the pretrained model
  3. model.load_state_dict(t.load(ckpt_path))

后来发现: 如果使用 pretrained=True,则会默认设置 transform_input=True.

  1. if self.transform_input:
  2. x = x.clone()
  3. x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
  4. x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
  5. x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5

这个时候,forward函数一开始就对数据进行处理了。
所以尽管我导入的ckpt之后两个模型的参数都是完全一样的,但是由于这个参数,forward函数中一开始就对输入进行了处理,导致我最后得到的结果是不准确的。

所以正确的做法应该是

  1. Model = inception_v3
  2. model = Model(num_classes=250, transform_input=True) # use the pretrained model
  3. model.load_state_dict(t.load(ckpt_path))

关于其他模型,检查了一下:
resnet、vgg 都没有这样的问题。