项目地址:

一、准备数据

在PyTorch中构建图片数据管道通常有两种方法:

  1. 第一种是使用torchvision中的datasets.ImageFolder来读取图片,然后用DataLoader来并行加载。
  2. 第二种是通过继承torch.utils.data.Dataset实现用户自定义读取逻辑,然后用DataLoader来并行加载。这是读取用户自定义数据集的通用方法,既可以读取图片数据集,也可以读取文本数据集。

这里使用第一种方法

  1. import torch
  2. import torch.nn as nn
  3. from torch.utils.data import Dataset, DataLoader
  4. from torchvision import transforms,datasets
  5. transform_train = transforms.Compose(
  6. [transforms.ToTensor()])
  7. transform_valid = transforms.Compose(
  8. [transforms.ToTensor()])
  9. ds_train = datasets.ImageFolder("../data/cifar2/train/",
  10. transform = transform_train, target_transform = lambda t: torch.tensor([t]).float())
  11. ds_valid = datasets.ImageFolder("../data/cifar2/test/",
  12. transform = transform_valid,target_transform= lambda t:torch.tensor([t]).float())
  13. print(ds_train.class_to_idx)
  14. dl_train = DataLoader(ds_train, batch_size = 50, shuffle = True, num_workers = 3)
  15. dl_valid = DataLoader(ds_train, batch_size = 50, shuffle = True, num_workers = 3)

查看部分样本

  1. from matplotlib import pyplot as plt
  2. plt.figure(figsize=(8,8))
  3. for i in range(9):
  4. img, label = ds_train[i]
  5. img = img.permute(1,2,0)
  6. ax = plt.subplot(3,3,i+1)
  7. ax.imshow(img.numpy())
  8. ax.set_title("label = %d"%label.item())
  9. ax.set_xticks([])
  10. ax.set_yticks([])
  11. plt.show()

PyTorch 中图片张量的默认维度顺序是: Batch, Channel, Height, Width (即 NCHW)。

  1. for x, y in dl_train:
  2. print(x.shape, y.shape)
  3. break
  4. >>> torch.Size([50, 3, 32, 32]) torch.Size([50, 1])

二、定义模型

使用PyTorch构建模型通常有三种方式:

  1. 使用nn.Sequential按层顺序构建模型
  2. 继承nn.Module基类构建自定义模型
  3. 继承nn.Module基类构建模型,并辅助应用模型容器(nn.Sequential、nn.ModuleList、nn.ModuleDict)

下面先测试nn.AdaptiveMaxPool2d的效果,再用第二种方式构建模型。
  1. #测试AdaptiveMaxPool2d的效果
  2. pool = nn.AdaptiveMaxPool2d((1,1))
  3. t = torch.randn(10,8,32,32)
  4. pool(t).shape
  1. class Net(nn.Module):
  2. def __init__(self):
  3. super(Net, self).__init__()
  4. self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
  5. self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
  6. self.conv2 = nn.Conv2d(in_channels=32, out_channles=64, kernel_size=5)
  7. self.dropout = nn.Dropout2d(p=0.1)
  8. self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))
  9. self.flatten = nn.Flatten()
  10. self.linear1 = nn.Linear(64,32)
  11. self.relu = nn.ReLU()
  12. self.linear2 = nn.Linear(32,1)
  13. self.sigmoid = nn.Sigmoid()
  14. def forward(self, x):
  15. x = self.conv1(x)
  16. x = self.pool(x)
  17. x = self.conv2(x)
  18. x = self.pool(x)
  19. x = self.dropout(x)
  20. x = self.adaptive_pool(x)
  21. x = self.flatten(x)
  22. x = self.linear1(x)
  23. x = self.relu(x)
  24. y = self.sigmoid(x)
  25. return y
  26. net = Net()
  27. print(net)

使用torchkeras打印网络结构摘要

  1. import torchkeras
  2. torchkears.summary(net, input_shape=(3, 32, 32))
  1. ----------------------------------------------------------------
  2. Layer (type) Output Shape Param #
  3. ================================================================
  4. Conv2d-1 [-1, 32, 30, 30] 896
  5. MaxPool2d-2 [-1, 32, 15, 15] 0
  6. Conv2d-3 [-1, 64, 11, 11] 51,264
  7. MaxPool2d-4 [-1, 64, 5, 5] 0
  8. Dropout2d-5 [-1, 64, 5, 5] 0
  9. AdaptiveMaxPool2d-6 [-1, 64, 1, 1] 0
  10. Flatten-7 [-1, 64] 0
  11. Linear-8 [-1, 32] 2,080
  12. ReLU-9 [-1, 32] 0
  13. Linear-10 [-1, 1] 33
  14. Sigmoid-11 [-1, 1] 0
  15. ================================================================
  16. Total params: 54,273
  17. Trainable params: 54,273
  18. Non-trainable params: 0
  19. ----------------------------------------------------------------
  20. Input size (MB): 0.011719
  21. Forward/backward pass size (MB): 0.359634
  22. Params size (MB): 0.207035
  23. Estimated Total Size (MB): 0.578388
  24. ----------------------------------------------------------------

三、训练模型

有三种典型训练循环代码风格:

  1. 脚本形式训练循环
  2. 函数形式训练循环
  3. 类形式训练循环

这里使用函数形式的训练循环:

  1. import pandas as pd
  2. from sklearn.metrics import roc_auc_score
  3. model = net
  4. model.optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  5. model.loss_func = torch.nn.BCELoss()
  6. modle.metric_func = lambda y_pred, y_true: roc_auc_score(y_true.data.numpy(), y_pred.data.numpy())
  7. modle.metric_name = "auc"
  8. def train_step(modle, features, labels):
  9. # 训练模式, dropout层发生作用,使用model.train()
  10. model.train()
  11. # 梯度清零
  12. model.optimizer.zero_grad()
  13. # 正向传播求损失
  14. predictions = model(features)
  15. loss = model.loss_func(predictions, labels)
  16. metric = model.metric_func(predictions, labels)
  17. # 反向传播求梯度
  18. loss.backward()
  19. model.optimizer.step()
  20. return loss.item(), metric.item()
  21. def valid_step(model, features, labels):
  22. # 预测模式, dropout层不发生作用
  23. model.eval()
  24. # 关闭梯度计算
  25. with torch.no_grad():
  26. predictions = model(features)
  27. loss = model.loss_func(predictions, labels)
  28. metric = model.metric_func(predictions, labels)
  29. return loss.item(), metric.item()
  30. # 测试train_step效果
  31. features, labels = next(iter(dl_train))
  32. train_step(model, features, labels)
  1. def train_model(model, epochs, dl_train, dl_valid, log_step_freq):
  2. metric_name = model.metric_name
  3. dfhistory = pd.DataFrame(columns = ["epoch", "loss", metric_name, "val_loss", "val_"+metric_name])
  4. print("Start Training...")
  5. nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  6. print("==========="*8 + "%s"%nowtime)
  7. for epoch in range(1, epochs+1):
  8. # 1, 训练循环
  9. loss_num = 0.0
  10. metric_sum = 0.0
  11. step = 1
  12. for step, (features, labels) in enumerate(dl_train, 1):
  13. loss, metric = train_step(model, features, labels)
  14. loss_sum += loss
  15. metric_sum += metric
  16. if step%log_step_freq == 0:
  17. print(("[step = %d] loss: %.3f, "+metric_name+":%.3f")%
  18. (step, loss_sum/step, metric_sum/step)
  19. )
  20. # 2, 验证循环
  21. val_loss_sum = 0.0
  22. val_metric_sum = 0.0
  23. val_step = 1
  24. for val_step, (features, labels) in enumerate(dl_valid, 1):
  25. val_loss, val_metric = valid_step(model, features, labels)
  26. val_loss_sum += val_loss
  27. val_metric_sum += val_metric
  28. # 3, 记录日志
  29. info = (epoch, loss_sum/step, metric_sum/step,
  30. val_loss_sum/val_step, val_metric_sum/val_step)
  31. dfhistory.loc[epoch-1] = info
  32. # 打印epoch级别日志
  33. print(("\nEPOCH = %d, loss = %.3f,"+ metric_name + \
  34. " = %.3f, val_loss = %.3f, "+"val_"+ metric_name+" = %.3f")
  35. %info)
  36. nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  37. print("\n"+"=========="*8 + "%s"%nowtime)