现有网络模型的使用及修改
- 示例代码:
import torchvisionfrom torch import nnvgg16_false = torchvision.models.vgg16(pretrained=False)vgg16_true = torchvision.models.vgg16(pretrained=True)print(vgg16_true)'''在现有的网络中添加'''# 加到classifirer外面# vgg16_true.add_module('add_linear',nn.Linear(1000,10))# 加到classifier里面vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))print(vgg16_true)'''在现有的网络中修改'''# 修改classifier的下标为6的层vgg16_true.classifier[6] = nn.Linear(2048,10)print(vgg16_true)
