import torchvision.datasetsfrom torch import nnvgg16_false = torchvision.models.vgg16(pretrained=False)# vgg16_true = torchvision.models.vgg16(pretrained=True) # 会下载已经训练好的模型参数print(vgg16_false)'''修改网络现有结构'''# 在最后添加一个线性层vgg16_false.classifier.add_module("add_linear", nn.Linear(1000, 10))print(vgg16_false)# 修改第7层vgg16_false.classifier[6] = nn.Linear(4096, 10)print(vgg16_false)
