model.py

  1. import torch.nn as nn
  2. import torch
  3. class BasicBlock(nn.Module):
  4. expansion = 1 #主分支中卷积核的个数是否发生变化
  5. def __init__(self, in_channel, out_channel, stride = 1, downsample = None):
  6. super(BasicBlock, self).__init__()
  7. self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3,
  8. stride=stride, padding=1, bias=False) # stride = 1 对应实线残差结构
  9. # output = (input - kernel_size + 2 * padding)/stride + 1
  10. self.bn1 = nn.BatchNorm2d(out_channel)
  11. self.relu = nn.ReLU()
  12. self.conv2 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3,
  13. stride=1, padding=1, bias=False)
  14. self.bn2 = nn.BatchNorm2d(out_channel)
  15. self.downsample = downsample
  16. def forward(self, x): #x - 特征矩阵
  17. identity = x #捷径分支上的输出值
  18. if self.downsample is not None:
  19. identity = self.downsample(x)
  20. out = self.conv1(x)
  21. out = self.bn1(out)
  22. out = self.relu(out)
  23. out = self.conv2(out)
  24. out = self.bn2(out)
  25. out += identity
  26. out = self.relu(out)
  27. return out # 18/34 定义完成
  28. class Bottleneck(nn.Module):
  29. expansion = 4
  30. def __init__(self, in_channel, out_channel, stride = 1, downsample = None):
  31. super(Bottleneck, self).__init__()
  32. self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1,
  33. stride=1, bias=False)
  34. self.bn1 = nn.BatchNorm2d(out_channel)
  35. self.conv2 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3,
  36. stride=stride, bias=False, padding=1) #两种不同的残差结构的步长不一致,采用传入的参数
  37. self.bn2 = nn.BatchNorm2d(out_channel) #padding大小一般设置为核大小的一半
  38. self.conv3 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel * self.expansion,
  39. kernel_size=1, stride=1, bias=False)
  40. self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
  41. self.relu = nn.ReLU(inplace=True) #inplace:是否进行覆盖运算
  42. self.downsample = downsample
  43. def forward(self, x):
  44. identity = x
  45. if self.downsample is not None:
  46. identity = self.downsample(x)
  47. out = self.conv1(x)
  48. out = self.bn1(out)
  49. out = self.relu(out)
  50. out = self.conv2(out)
  51. out = self.bn2(out)
  52. out = self.relu(out)
  53. out = self.conv3(out)
  54. out = self.bn3(out)
  55. out += identity
  56. out = self.relu(out)
  57. return out
  58. class ResNet(nn.Module):
  59. def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
  60. """
  61. block:对应残差结构,根据定义的层结构,传入不同的block
  62. blocks_num:对应所使用残差结构的数目,是一个列表参数,e.g.34layers,3,4,6,3
  63. include_top:便于搭建更复杂的网络
  64. """
  65. super(ResNet, self).__init__()
  66. self.include_top = include_top
  67. in_channel = 64 #通过max_pooling后得到的特征矩阵的深度
  68. #定义第一层卷积层
  69. self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
  70. padding=3, bias=False) #out_channel:卷积核个数
  71. self.bn1 = nn.BatchNorm2d(self.in_channel)
  72. self.relu = nn.ReLU(inplace=True)
  73. #定义max pooling downsample
  74. self.maxpooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  75. self.layer1 = self._make_layer(block, 64, blocks_num[0]) #Conv2对应的一系列残差结构
  76. self.layer2 = self._make_layer(block, 128, blocks_num[0], stride=2)
  77. self.layer3 = self._make_layer(block, 256, blocks_num[0], stride=2)
  78. self.layer4 = self._make_layer(block, 512, blocks_num[0], stride=2)
  79. if self.include_top:
  80. self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
  81. self.fc = nn.Linear(512 * block.expansion, num_classes)
  82. #初始化
  83. for m in self.modules():
  84. if isinstance(m, nn.Conv2d):
  85. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  86. def _make_layer(self, block, channel, block_num, stride=1):
  87. """
  88. channel:残差结构中主分支上第一个卷积核的个数:64,128,256,512
  89. """
  90. downsample = None
  91. if stride != 1 or self.in_channel != channel * block.expansion:
  92. #50/101/152 conv2对应的虚线残差结构只需要调整矩阵深度
  93. downsample = nn.Sequential(
  94. nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
  95. nn.BatchNorm2d(channel * block.expansion))
  96. layers = [] #空列表
  97. layers.append(block(self.in_channel, channel, downsample = downsample, stride = stride))
  98. self.in_channel = channel * block.expansion
  99. for _ in range(1, block_num):
  100. layers.append(block(self.in_channel, channel)) #输入特征矩阵的深度,主线分支上第一个卷积层的卷积核个数
  101. return nn.Sequential(*layers) #将list列表转换为非关键字参数传入
  102. def forward(self, x):
  103. x = self.conv1(x)
  104. x = self.bn1(x)
  105. x = self.relu(x)
  106. x = self.maxpool(x)
  107. x = self.layer1(x)
  108. x = self.layer2(x)
  109. x = self.layer3(x)
  110. x = self.layer4(x)
  111. if self.include_top:
  112. x = self.avgpool(x)
  113. x = torch.flatten(x, 1)
  114. x = self.fc(x)
  115. return x
  116. def resnet34(num_classes=1000, include_top=True):
  117. return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
  118. def resnet101(num_classes=1000, include_top=True):
  119. return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

train.py

  1. import torch
  2. import torch.nn as nn
  3. from torchvision import transforms, datasets
  4. import json
  5. import matplotlib.pyplot as plt
  6. import os
  7. import torch.optim as optim
  8. from model import resnet34, resnet101
  9. import torchvision.models.resnet
  10. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  11. print(device)
  12. data_transform = {
  13. "train": transforms.Compose([transforms.RandomResizedCrop(224),
  14. transforms.RandomHorizontalFlip(),
  15. transforms.ToTensor(),
  16. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
  17. "val": transforms.Compose([transforms.Resize(256),
  18. transforms.CenterCrop(224),
  19. transforms.ToTensor(),
  20. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
  21. data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path
  22. image_path = data_root + "/data_set/flower_data/" # flower data set path
  23. train_dataset = datasets.ImageFolder(root=image_path+"train",
  24. transform=data_transform["train"])
  25. train_num = len(train_dataset)
  26. # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
  27. flower_list = train_dataset.class_to_idx
  28. cla_dict = dict((val, key) for key, val in flower_list.items())
  29. # write dict into json file
  30. json_str = json.dumps(cla_dict, indent=4)
  31. with open('class_indices.json', 'w') as json_file:
  32. json_file.write(json_str)
  33. batch_size = 16
  34. train_loader = torch.utils.data.DataLoader(train_dataset,
  35. batch_size=batch_size, shuffle=True,
  36. num_workers=0)
  37. validate_dataset = datasets.ImageFolder(root=image_path + "val",
  38. transform=data_transform["val"])
  39. val_num = len(validate_dataset)
  40. validate_loader = torch.utils.data.DataLoader(validate_dataset,
  41. batch_size=batch_size, shuffle=False,
  42. num_workers=0)
  43. net = resnet34()
  44. # load pretrain weights
  45. model_weight_path = "./resnet34-pre.pth"
  46. missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)
  47. # for param in net.parameters():
  48. # param.requires_grad = False
  49. # change fc layer structure
  50. inchannel = net.fc.in_features
  51. net.fc = nn.Linear(inchannel, 5)
  52. net.to(device)
  53. loss_function = nn.CrossEntropyLoss()
  54. optimizer = optim.Adam(net.parameters(), lr=0.0001)
  55. best_acc = 0.0
  56. save_path = './resNet34.pth'
  57. for epoch in range(3):
  58. # train
  59. net.train()
  60. running_loss = 0.0
  61. for step, data in enumerate(train_loader, start=0):
  62. images, labels = data
  63. optimizer.zero_grad()
  64. logits = net(images.to(device))
  65. loss = loss_function(logits, labels.to(device))
  66. loss.backward()
  67. optimizer.step()
  68. # print statistics
  69. running_loss += loss.item()
  70. # print train process
  71. rate = (step+1)/len(train_loader)
  72. a = "*" * int(rate * 50)
  73. b = "." * int((1 - rate) * 50)
  74. print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
  75. print()
  76. # validate
  77. net.eval()
  78. acc = 0.0 # accumulate accurate number / epoch
  79. with torch.no_grad():
  80. for val_data in validate_loader:
  81. val_images, val_labels = val_data
  82. outputs = net(val_images.to(device)) # eval model only have last output layer
  83. # loss = loss_function(outputs, test_labels)
  84. predict_y = torch.max(outputs, dim=1)[1]
  85. acc += (predict_y == val_labels.to(device)).sum().item()
  86. val_accurate = acc / val_num
  87. if val_accurate > best_acc:
  88. best_acc = val_accurate
  89. torch.save(net.state_dict(), save_path)
  90. print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
  91. (epoch + 1, running_loss / step, val_accurate))
  92. print('Finished Training')

predict.py

  1. import torch
  2. from model import resnet34
  3. from PIL import Image
  4. from torchvision import transforms
  5. import matplotlib.pyplot as plt
  6. import json
  7. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  8. data_transform = transforms.Compose(
  9. [transforms.Resize(256),
  10. transforms.CenterCrop(224),
  11. transforms.ToTensor(),
  12. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
  13. # load image
  14. img = Image.open("../tulip.jpg")
  15. plt.imshow(img)
  16. # [N, C, H, W]
  17. img = data_transform(img)
  18. # expand batch dimension
  19. img = torch.unsqueeze(img, dim=0)
  20. # read class_indict
  21. try:
  22. json_file = open('./class_indices.json', 'r')
  23. class_indict = json.load(json_file)
  24. except Exception as e:
  25. print(e)
  26. exit(-1)
  27. # create model
  28. model = resnet34(num_classes=5)
  29. # load model weights
  30. model_weight_path = "./resNet34.pth"
  31. model.load_state_dict(torch.load(model_weight_path, map_location=device))
  32. model.eval()
  33. with torch.no_grad():
  34. # predict class
  35. output = torch.squeeze(model(img))
  36. predict = torch.softmax(output, dim=0)
  37. predict_cla = torch.argmax(predict).numpy()
  38. print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
  39. plt.show()