去掉最后两层

  1. import torchvision.models as models
  2. from torch import nn
  3. net = models.resnet18(pretrained=True)
  4. net2 = nn.Sequential(*list(net.children())[:-2])

修改最后一层

  1. import torchvision.models as models
  2. from torch import nn
  3. net = models.resnet18(pretrained=True)
  4. print(net) #查看最后一层的情况
  5. net.fc = nn.Linear(512, 2) #将最后一层改为二分类

参考这三篇

详情点图访问链接(主要看这篇)
image.png

详情点图访问链接
image.png

详情点图访问链接
image.png