这块之前已经操练过了,把网络结构与训练函数放在这里。

网络结构

  1. import time
  2. import torch
  3. from torch import nn, optim
  4. import torchvision
  5. from utils import *
  6. class LeNet(nn.Module):
  7. def __init__(self):
  8. super(LeNet, self).__init__()
  9. self.conv = nn.Sequential(
  10. nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
  11. nn.ReLU(),
  12. nn.MaxPool2d(2, 2), # kernel_size, stride
  13. nn.Conv2d(6, 16, 5),
  14. nn.ReLU(),
  15. nn.MaxPool2d(2, 2)
  16. )
  17. self.fc = nn.Sequential(
  18. nn.Linear(16*4*4, 512),
  19. nn.LeakyReLU(negative_slope=0.01),
  20. nn.Dropout(0.5),
  21. # nn.ReLU(),
  22. nn.Linear(512, 256),
  23. # nn.LeakyReLU(negative_slope=0.01),
  24. nn.Sigmoid(),
  25. nn.Dropout(0.5),
  26. nn.Linear(256, 200)
  27. )
  28. def forward(self, img):
  29. feature = self.conv(img)
  30. output = self.fc(feature.view(img.shape[0], -1))
  31. return output
  32. if __name__ == "__main__":
  33. net = LeNet()
  34. print(net)
  35. batch_size = 64
  36. resize = [28, 28]
  37. train_iter, test_iter = load_data_cnn(batch_size, resize=resize)
  38. lr, num_epochs = 0.001, 50
  39. loss = torch.nn.CrossEntropyLoss()
  40. optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=0.001)
  41. train(net, train_iter, test_iter, loss=loss, optimizer=optimizer, device='cuda', num_epochs=num_epochs, file='cnn.pt')

数据集与训练函数

  1. import torch
  2. from torch.utils.data import *
  3. import struct
  4. import os
  5. import time
  6. import torchvision
  7. from PIL import Image
  8. from torchvision.datasets.vision import VisionDataset
  9. from torchvision.transforms import ToPILImage
  10. # from IPython.display import Image
  11. import matplotlib.pyplot as plt
  12. from IPython import display
  13. import hiddenlayer as hl
  14. class ImageSet(Dataset):
  15. def __init__(self, path, dimensions, classes, size):
  16. super(ImageSet, self).__init__()
  17. data_raw = ()
  18. label_raw = []
  19. for i in range(classes):
  20. file_name = os.path.join(path, 'f' + str(i) + '.dat')
  21. with open(file_name, 'rb') as f:
  22. data_raw += struct.unpack('f' * size * dimensions, f.read(4 * size * dimensions))
  23. for j in range(size):
  24. label_raw.append(float(i))
  25. self.train_data = torch.tensor(data_raw).clone().view(-1, dimensions)
  26. self.train_label = torch.tensor(label_raw)
  27. def __getitem__(self, index):
  28. return self.train_data[index], self.train_label[index].long()
  29. def __len__(self):
  30. return self.train_label.size()[0]
  31. class ImageSetCNN(VisionDataset):
  32. def __init__(self, root, dim_x, dim_y, classes, size, transform=None, target_transform=None):
  33. super(ImageSetCNN, self).__init__(root, transform=transform, target_transform=target_transform)
  34. data_raw = ()
  35. label_raw = []
  36. for i in range(classes):
  37. file_name = os.path.join(root, 'f' + str(i) + '.dat')
  38. with open(file_name, 'rb') as f:
  39. data_raw += struct.unpack('f' * size * dim_x * dim_y, f.read(4 * size * dim_x * dim_y))
  40. for j in range(size):
  41. # one_hot = torch.zeros(classes)
  42. # one_hot[i] = 1
  43. label_raw.append(i)
  44. self.train_data = torch.tensor(data_raw).clone().view(-1, dim_x, dim_y)
  45. self.train_label = torch.tensor(label_raw)
  46. def __getitem__(self, index):
  47. img, target = self.train_data[index], int(self.train_label[index])
  48. img = Image.fromarray(img.numpy(), mode='L')
  49. # img.show()
  50. if self.transform is not None:
  51. img = self.transform(img)
  52. if self.target_transform is not None:
  53. target = self.target_transform(target)
  54. return img, target
  55. def __len__(self):
  56. return self.train_label.size()[0]
  57. def load_data(batch_size):
  58. training_set = ImageSet('train', 440, 200, 144)
  59. test_set = ImageSet('test', 440, 200, 18)
  60. num_workers = 4
  61. train_iter = DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  62. test_iter = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  63. return train_iter, test_iter
  64. def load_data_cnn(batch_size, resize=None):
  65. trans = []
  66. if resize:
  67. trans.append(torchvision.transforms.Resize(size=resize))
  68. trans.append(torchvision.transforms.ToTensor())
  69. transform = torchvision.transforms.Compose(trans)
  70. training_set = ImageSetCNN('train', 20, 22, 200, 144, transform=transform)
  71. test_set = ImageSetCNN('test', 20, 22, 200, 18, transform=transform)
  72. train_iter = DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=4)
  73. test_iter = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)
  74. return train_iter, test_iter
  75. def train(net, train_iter, test_iter, loss, optimizer, device, num_epochs, file):
  76. # 记录训练过程的指标
  77. history = hl.History()
  78. # 使用canvas进行可视化
  79. canvas = hl.Canvas()
  80. net = net.to(device)
  81. print("training on ", device)
  82. best_acc = 0
  83. no_progress = 0
  84. # loss = torch.nn.CrossEntropyLoss()
  85. for epoch in range(num_epochs):
  86. train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
  87. for X, y in train_iter:
  88. X = X.to(device)
  89. y = y.to(device)
  90. y_hat = net(X)
  91. l = loss(y_hat, y)
  92. optimizer.zero_grad()
  93. l.backward()
  94. optimizer.step()
  95. train_l_sum += l.cpu().item()
  96. train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
  97. n += y.shape[0]
  98. batch_count += 1
  99. test_acc = evaluate(test_iter, net)
  100. print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
  101. % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
  102. history.log(epoch,
  103. train_loss=train_l_sum / batch_count,
  104. train_acc=train_acc_sum / n,
  105. test_acc=test_acc)
  106. # 可视化
  107. with canvas:
  108. canvas.draw_plot(history["train_loss"])
  109. canvas.draw_plot(history["train_acc"])
  110. canvas.draw_plot(history["test_acc"])
  111. if test_acc > best_acc:
  112. no_progress = 0
  113. torch.save(net.state_dict(), file)
  114. best_acc = test_acc
  115. print('best model saved to ' + file)
  116. else:
  117. no_progress += 1
  118. if no_progress > 20:
  119. break
  120. def evaluate(data_iter, net, device=None):
  121. if device is None and isinstance(net, torch.nn.Module):
  122. device = list(net.parameters())[0].device
  123. acc_sum, n = 0.0, 0
  124. with torch.no_grad():
  125. for X, y in data_iter:
  126. net.eval() # 评估模式, 关闭dropout正则化层
  127. acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
  128. net.train() # 改回训练模式
  129. n += y.shape[0]
  130. return acc_sum / n