16.practice_modules_1

torchvision中包含很多现有的模型,涵盖分类,检测,语义分割和视频检测

https://pytorch.org/vision/stable/models.html

我们使用vgg16的模型来用于CIFAR10数据集

通常现有的很多模型都是将vgg16看做一个前置模型,让其提取一些特征再在后面写网络组成一个完整的部分

先来看下载vgg16模型时。pretrained参数设置不同会有何区别

pretrained = False:只下载模型的各层,各层的参数是随机的初始参数(权重)

pretrained =True :下载模型的各层,同时包含各层已训练好的参数(权重) 即预训练好的模型

  1. vgg16_pretrained = torchvision.models.vgg16(pretrained=True)
  2. print(vgg16_pretrained)#查看下载好的vgg16的网络模型

可以查看到最后线性层输出1000,即分类可以分为1000个类别。

而CIFAR10数据集只有10个类别,因此我们需要在网络上进行修改,让vgg16网络最后输出为10

两种方法:

①网络最后加一个线性层,输入为1000,输出为10

加线性层可以直接加载vgg16网络后面,也可以加载其中classifier后面

直接加在网络最后:

  1. vgg16_pretrained.add_module("add_linear",nn.Linear(1000,10))

加载classifier最后:

  1. vgg16_pretrained.classifier.add_module("add_linear",nn.Linear(1000,10))

②改动网络,使最后一层输出10

  1. vgg16_pretrained.classifier[6] = nn.Linear(4096,10)