16.practice_modules_1
torchvision中包含很多现有的模型,涵盖分类,检测,语义分割和视频检测
https://pytorch.org/vision/stable/models.html
我们使用vgg16的模型来用于CIFAR10数据集
通常现有的很多模型都是将vgg16看做一个前置模型,让其提取一些特征再在后面写网络组成一个完整的部分
先来看下载vgg16模型时。pretrained参数设置不同会有何区别
pretrained = False:只下载模型的各层,各层的参数是随机的初始参数(权重)
pretrained =True :下载模型的各层,同时包含各层已训练好的参数(权重) 即预训练好的模型
vgg16_pretrained = torchvision.models.vgg16(pretrained=True)print(vgg16_pretrained)#查看下载好的vgg16的网络模型
可以查看到最后线性层输出1000,即分类可以分为1000个类别。
而CIFAR10数据集只有10个类别,因此我们需要在网络上进行修改,让vgg16网络最后输出为10
两种方法:
①网络最后加一个线性层,输入为1000,输出为10
加线性层可以直接加载vgg16网络后面,也可以加载其中classifier后面
直接加在网络最后:
vgg16_pretrained.add_module("add_linear",nn.Linear(1000,10))
加载classifier最后:
vgg16_pretrained.classifier.add_module("add_linear",nn.Linear(1000,10))
②改动网络,使最后一层输出10
vgg16_pretrained.classifier[6] = nn.Linear(4096,10)
