import matplotlib.pyplot as plt
import numpy as np
import torch
# Neural-network building blocks (layers, functional ops) and optimizers.
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
# Preprocessing transforms and torchvision's built-in MNIST dataset.
import torchvision.transforms as transforms
from torchvision.datasets import mnist
    # MNIST handwritten-digit recognition with a fully connected network.

# ---- Hyperparameters ----
train_batch_size = 64
test_batch_size = 128
learning_rate = 0.01
num_epoches = 5
# `lr` is the name the optimizer setup reads; alias it to `learning_rate` so the
# two settings cannot silently drift apart.
lr = learning_rate
momentum = 0.5

# ---- Data download and preprocessing ----
# transforms.Compose chains the preprocessing steps in order. ToTensor converts
# each image to a float tensor; Normalize([0.5], [0.5]) then computes
# (x - 0.5) / 0.5 per channel. MNIST images are grayscale (one channel), so mean
# and std each hold a single value; a three-channel dataset would need three.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# download=True fetches the files into ./data only when they are missing;
# torchvision skips the download if they are already present.
train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform, download=True)
# A DataLoader is iterable: each iteration yields one (images, labels) batch.
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

# ---- Optional look at the data ----
# Set to True to display the first six test images with their labels.
show_examples = False
example_data, example_targets = next(iter(test_loader))
fig = plt.figure()
if show_examples:
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        # .item() so the title shows "7" rather than "tensor(7)".
        plt.title("Ground Truth: {}".format(example_targets[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    plt.show()
    # ---- Network definition ----
class Net(nn.Module):
    """Three-layer fully connected classifier.

    Layers 1 and 2 are ``Linear -> BatchNorm1d`` stages, each followed by a ReLU
    in ``forward``. Layer 3 is a plain ``Linear`` that produces ``out_dim`` raw
    scores (logits) with no activation.

    Args:
        in_dim: Number of input features per sample.
        n_hidden_1: Width of the first hidden layer.
        n_hidden_2: Width of the second hidden layer.
        out_dim: Number of output scores (classes).
    """

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        # nn.Sequential groups each stage's layers into a single callable module.
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        """Map a 2-D ``(batch, in_dim)`` tensor to ``(batch, out_dim)`` logits.

        Note: in train mode BatchNorm1d needs more than one sample per batch.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
    # ---- Training and evaluation ----
def run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in train mode and updated once
    per batch (zero_grad -> backward -> step). Otherwise the model is put in
    eval mode and run with gradient tracking disabled, so no autograd graph is
    built during evaluation.

    Both statistics are averaged over samples rather than over batches, so a
    smaller final batch is not over-weighted. The loss weighting assumes
    ``criterion`` returns the batch *mean* (the nn.CrossEntropyLoss default).

    Args:
        model: Network mapping flattened images to class logits.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        optimizer: Optimizer to step; ``None`` means evaluation only.

    Returns:
        ``(mean_loss, accuracy)``; ``(0.0, 0.0)`` if ``loader`` yields nothing.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.set_grad_enabled(training):
        for img, label in loader:
            img = img.to(device)
            label = label.to(device)
            # Flatten each image into one feature vector for the fully connected net.
            img = img.view(img.size(0), -1)
            out = model(img)
            loss = criterion(out, label)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = img.size(0)
            # Weight the batch-mean loss back up by the batch size.
            total_loss += loss.item() * batch_size
            total_correct += (out.argmax(dim=1) == label).sum().item()
            total_seen += batch_size
    if total_seen == 0:
        return 0.0, 0.0
    return total_loss / total_seen, total_correct / total_seen


# Use the first GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Instantiate the network: 28*28 flattened pixels -> 300 -> 100 -> 10 outputs.
model = Net(28 * 28, 300, 100, 10)
model.to(device)
# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# Learning-rate schedule: multiply lr by 0.1 after every 5 completed epochs.
# A hand-rolled `if epoch % 5 == 0` at the top of the loop would also fire at
# epoch 0 and cut lr to a tenth before any training happened.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Per-epoch history of training and test loss/accuracy.
losses = []
acces = []
eval_losses = []
eval_acces = []
for epoch in range(num_epoches):
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    # Advance the schedule only after this epoch's optimizer steps.
    scheduler.step()
    losses.append(train_loss)
    acces.append(train_acc)

    # Measure performance on the held-out test set.
    eval_loss, eval_acc = run_epoch(model, test_loader, criterion, device)
    eval_losses.append(eval_loss)
    eval_acces.append(eval_acc)

    print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'
          .format(epoch, train_loss, train_acc, eval_loss, eval_acc))
    # ---- Plot the per-epoch training loss ----
# pyplot draws on the current figure, creating one if none exists.
# `losses` holds one entry per epoch, so the x values are epoch indices.
epochs = np.arange(len(losses))
plt.title('train loss')
plt.plot(epochs, losses, label='Train Loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')
plt.show()