现有网络模型的使用及修改
- 示例代码:
import torchvision
from torch import nn
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_true)
'''
在现有的网络中添加
'''
# 加到classifirer外面
# vgg16_true.add_module('add_linear',nn.Linear(1000,10))
# 加到classifier里面
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
print(vgg16_true)
'''
在现有的网络中修改
'''
# 修改classifier的下标为6的层
vgg16_true.classifier[6] = nn.Linear(2048,10)
print(vgg16_true)