现有网络模型的使用及修改

  • 示例代码:
  1. import torchvision
  2. from torch import nn
  3. vgg16_false = torchvision.models.vgg16(pretrained=False)
  4. vgg16_true = torchvision.models.vgg16(pretrained=True)
  5. print(vgg16_true)
  6. '''
  7. 在现有的网络中添加
  8. '''
  9. # 加到classifirer外面
  10. # vgg16_true.add_module('add_linear',nn.Linear(1000,10))
  11. # 加到classifier里面
  12. vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
  13. print(vgg16_true)
  14. '''
  15. 在现有的网络中修改
  16. '''
  17. # 修改classifier的下标为6的层
  18. vgg16_true.classifier[6] = nn.Linear(2048,10)
  19. print(vgg16_true)