I. Using and modifying existing network models

1. Using an existing network

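  ```python
  import torchvision
  # Parameters keep their default (randomly initialised) values
  vgg16_false = torchvision.models.vgg16(pretrained=False)
  # Parameters come from training on ImageNet (downloaded on first use)
  vgg16_true = torchvision.models.vgg16(pretrained=True)
  # On torchvision >= 0.13, `pretrained` is deprecated; the equivalents are
  # vgg16(weights=None) and vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
  print(vgg16_true)
  ```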

The printed network structure is:

  ```text
  VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (18): ReLU(inplace=True)
      (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (20): ReLU(inplace=True)
      (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (22): ReLU(inplace=True)
      (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (25): ReLU(inplace=True)
      (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (27): ReLU(inplace=True)
      (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (29): ReLU(inplace=True)
      (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
    (classifier): Sequential(
      (0): Linear(in_features=25088, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=4096, out_features=4096, bias=True)
      (4): ReLU(inplace=True)
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=4096, out_features=1000, bias=True)
    )
  )
  ```

2. Modifying VGG16 to classify the CIFAR10 dataset

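  • Add a fully connected layer (1000 → 10) to vgg16 itself. Note that `add_module` only registers the layer. torchvision's `VGG.forward()` calls only `features`, `avgpool` and `classifier`, so this top-level `add_linear` is never executed and the model still outputs 1000 values.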

    ```python
    from torch import nn
    vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
    ```

    The network structure becomes:

    ```text
    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU(inplace=True)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU(inplace=True)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
      (add_linear): Linear(in_features=1000, out_features=10, bias=True)
    )
    ```
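  • Add a fully connected layer inside vgg16's classifier. Because `classifier` is an `nn.Sequential`, an appended module runs after the existing layers, so the output becomes 10 values.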

    ```python
    from torch import nn
    vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
    ```

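    The network structure becomes as follows. This step continues from the previous one on the same `vgg16_true`, so the earlier top-level `add_linear` still appears at the end: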

    ```text
    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU(inplace=True)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU(inplace=True)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
        (add_linear): Linear(in_features=1000, out_features=10, bias=True)
      )
      (add_linear): Linear(in_features=1000, out_features=10, bias=True)
    )
    ```
  • Modify a specific layer

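    ```python
    from torch import nn
    # The same indexing works on classifier or on features.
    # classifier[6] receives 4096 features, so in_features must remain 4096.
    vgg16_false.classifier[6] = nn.Linear(in_features=4096, out_features=10, bias=True)
    ```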

II. Saving and loading models

1. Method 1 (save the model structure and parameters together)

  • Save

    ```python
    import torch
    import torchvision
    vgg16 = torchvision.models.vgg16(pretrained=True)
    # Save the model structure and parameters together
    torch.save(vgg16, "vgg16_method1.pth")
    ```
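  • Load

    ```python
    import torch
    import torchvision
    # Load the model (structure + parameters)
    # On PyTorch >= 2.6, torch.load defaults to weights_only=True; loading a whole
    # model then needs torch.load("vgg16_method1.pth", weights_only=False)
    model = torch.load("vgg16_method1.pth")
    ```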
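  • Pitfall of method 1

    Tip: if the model is a custom class, the loading script must import (or define) that class. The saved file stores a reference to the class, not its code.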
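The custom-model pitfall can be reproduced in a few lines. `TinyNet` is a hypothetical model made up for this sketch, and an in-memory `io.BytesIO` buffer stands in for the `.pth` file.

Loading works here only because `TinyNet` is defined in the same script. A separate script without that class definition would fail while unpickling. The `weights_only=False` argument assumes PyTorch 1.13 or later.

```python
import io

import torch
from torch import nn


# Hypothetical custom model, used only for this demonstration
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()
buffer = io.BytesIO()      # in-memory stand-in for a .pth file
torch.save(model, buffer)  # method 1: pickles structure and parameters
buffer.seek(0)

# The pickle refers to TinyNet by module and name, so the class must be
# available here. weights_only=False is required for whole-model files on
# recent PyTorch; use it only for files you trust.
loaded = torch.load(buffer, weights_only=False)

x = torch.randn(3, 4)
print(torch.equal(model(x), loaded(x)))  # True
```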

2. Method 2 (save only the parameters)

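  • Save

    ```python
    import torch
    import torchvision
    vgg16 = torchvision.models.vgg16(pretrained=False)
    # Method 2: save vgg16's state (its parameters) as a dict, without the structure.
    # [Officially recommended] The file is slightly smaller and, more importantly,
    # does not depend on pickling the model class.
    torch.save(vgg16.state_dict(), "vgg16_method2.pth")
    ```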
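  • Load

    ```python
    import torch
    import torchvision
    # Define the model structure first
    vgg16 = torchvision.models.vgg16(pretrained=False)
    # Combine parameters and structure: load the saved dict into the model
    vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
    ```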
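A round trip of method 2 can be sketched as below. To keep it light, a small `nn.Sequential` built by the hypothetical helper `build_model()` stands in for `torchvision.models.vgg16(pretrained=False)`. An in-memory buffer stands in for `vgg16_method2.pth`, and the save and load calls are the same.

`load_state_dict` returns the lists of missing and unexpected keys. This is a convenient check that the rebuilt structure matches the saved parameters.

```python
import io

import torch
from torch import nn


def build_model():
    # Hypothetical stand-in for torchvision.models.vgg16(pretrained=False):
    # method 2 needs the same architecture rebuilt before loading parameters.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = build_model()
buffer = io.BytesIO()                    # in-memory stand-in for the .pth file
torch.save(model.state_dict(), buffer)  # method 2: parameters only
buffer.seek(0)

restored = build_model()                        # 1. rebuild the structure
state = torch.load(buffer, map_location="cpu")  # 2. read the parameter dict
result = restored.load_state_dict(state)        # 3. copy the parameters in
print(result)  # <All keys matched successfully>

# Every parameter tensor in the restored model matches the original
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```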