model.py
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    # expansion: whether (and by how much) the number of kernels changes
    # across the main branch of the block; 1 for the 18/34-layer basic block
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # stride = 1 corresponds to the solid-line residual structure
        # output = (input - kernel_size + 2 * padding) / stride + 1
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        # x: the input feature map
        identity = x  # output of the shortcut branch
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out
# 18/34-layer block definition complete


class Bottleneck(nn.Module):
    # in the 50/101/152-layer bottleneck, the last conv has 4x the kernels
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # the two residual-structure variants use different strides,
        # so take the passed-in parameter;
        # padding is usually set to half the kernel size
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)  # inplace: overwrite the input in place
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        """
        block: the residual block type, chosen per network depth
        blocks_num: number of residual blocks per stage, a list,
                    e.g. [3, 4, 6, 3] for 34 layers
        include_top: makes it easy to build more complex networks on top
        """
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64  # depth of the feature map after max pooling

        # first conv layer; out_channels = number of kernels
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        # max pooling downsample
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])  # residual blocks of the conv2_x stage
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """
        channel: number of kernels in the first conv of the main branch:
                 64 / 128 / 256 / 512
        """
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            # for 50/101/152, the dashed shortcut at conv2_x only needs
            # to adjust the feature-map depth
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel,
                            downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            # in_channel: depth of the input feature map;
            # channel: kernels in the first conv of the main branch
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)  # unpack the list into positional arguments

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
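A quick way to verify the model definition is a shape smoke test: a fresh resnet34 applied to a random ImageNet-sized batch should emit one logit per class. This snippet is not part of the original code; it only assumes model.py is on the import path.

import torch
from model import resnet34

net = resnet34(num_classes=5)    # 5 classes, matching the flower data set in train.py
x = torch.randn(2, 3, 224, 224)  # two fake RGB images, 224x224
net.eval()
with torch.no_grad():
    y = net(x)
print(y.shape)                   # expected: torch.Size([2, 5])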
train.py
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

from model import resnet34

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)

net = resnet34()
# load pretrained weights; the fc layer still has 1000 outputs here,
# hence strict=False
model_weight_path = "./resnet34-pre.pth"
missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path),
                                                    strict=False)
# optionally freeze the backbone:
# for param in net.parameters():
#     param.requires_grad = False

# change fc layer structure to match the 5 flower classes
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)
net.to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train progress
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # accumulate the number of correct predictions per epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / len(train_loader), val_accurate))

print('Finished Training')
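train.py expects ImageNet-pretrained weights at ./resnet34-pre.pth. One way to produce that file, sketched below as an assumption rather than the original author's method, is to save torchvision's own pretrained ResNet-34 state dict; its parameter names (conv1, bn1, layer1-4, fc) match this model.py, so load_state_dict with strict=False works. Note that newer torchvision versions replace pretrained=True with a weights= argument.

import torch
import torchvision.models as models

# downloads the ImageNet weights on first call (older torchvision API;
# newer versions use models.resnet34(weights="IMAGENET1K_V1"))
net = models.resnet34(pretrained=True)
torch.save(net.state_dict(), "./resnet34-pre.pth")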
predict.py
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()
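The single-image pipeline in predict.py extends naturally to batched inference: stack the transformed tensors along dim 0 and take the softmax along dim 1 so each row is one image's class distribution. The sketch below is an optional extension, not part of the original code; the ../test_images folder is hypothetical.

import os
import json

import torch
from PIL import Image
from torchvision import transforms

from model import resnet34

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

with open('./class_indices.json', 'r') as f:
    class_indict = json.load(f)

model = resnet34(num_classes=5)
model.load_state_dict(torch.load("./resNet34.pth", map_location="cpu"))
model.eval()

img_dir = "../test_images"  # hypothetical folder of .jpg files
paths = [os.path.join(img_dir, n) for n in sorted(os.listdir(img_dir))
         if n.lower().endswith(".jpg")]
# stack individual [C, H, W] tensors into one [N, C, H, W] batch
batch = torch.stack([data_transform(Image.open(p).convert("RGB")) for p in paths])

with torch.no_grad():
    probs = torch.softmax(model(batch), dim=1)  # one distribution per image
    classes = torch.argmax(probs, dim=1)

for path, cls, prob in zip(paths, classes, probs):
    print(path, class_indict[str(cls.item())], float(prob[cls]))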