卷积神经网络出现的意义在于：卷积核不再依赖人工设计，而是通过样本自动学习得到。

LeNet结构

点击查看【processon】

特征图计算公式

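卷积计算输出的公式（输入边长 $W_{in}$，卷积核大小 $K$，填充 $P$，步长 $S$）：

$$W_{out} = \left\lfloor \frac{W_{in} - K + 2P}{S} \right\rfloor + 1$$

池化层同样适用该公式（2×2 最大池化即 $K = S = 2$，$P = 0$）。以 28×28 的 MNIST 输入为例：conv1（$K=5$）后为 24×24，池化后为 12×12；conv2（$K=5$）后为 8×8，池化后为 4×4。因此展平后的特征维度为 $16 \times 4 \times 4 = 256$，这正是下面代码中 `fc1` 的输入维度。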

MNIST入门LeNet网络 - 图1

模型 PyTorch 代码解析

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import matplotlib.pyplot as plt
  5. import copy
  6. from torchvision.utils import make_grid
  7. from matplotlib.pyplot import MultipleLocator
  class LeNet(nn.Module):
      """Small LeNet-style CNN classifier returning log-probabilities.

      Pipeline used by ``forward``:
          conv1 (5x5, 6 maps) -> ReLU -> 2x2 max-pool
          -> conv2 (5x5, 16 maps) -> ReLU -> 2x2 max-pool
          -> flatten -> fc1 -> log_softmax

      ``fc1`` expects 16 * 4 * 4 = 256 features. That is what a 28x28 input
      produces: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4. Other spatial
      sizes will fail with a shape mismatch at ``fc1``.

      Args:
          in_channels: channels of the input image (default 1, grayscale).
          num_classes: number of output classes (default 10).
      """

      def __init__(self, in_channels=1, num_classes=10):
          super().__init__()
          self.conv1 = nn.Conv2d(in_channels, 6, 5)
          self.conv2 = nn.Conv2d(6, 16, 5)
          # Single linear classifier head: 16 maps of 4x4 -> class scores.
          self.fc1 = nn.Linear(16 * 4 * 4, num_classes)
          # fc2/fc3 resemble the classic LeNet-5 head (120 -> 84 -> classes),
          # but forward() does NOT use them (fc1 already outputs class scores).
          # They are kept only so previously saved state_dicts load unchanged.
          self.fc2 = nn.Linear(120, 84)
          self.fc3 = nn.Linear(84, num_classes)

      def forward(self, x):
          """Return per-class log-probabilities of shape (batch, num_classes)."""
          x = F.max_pool2d(F.relu(self.conv1(x)), 2)
          x = F.max_pool2d(F.relu(self.conv2(x)), 2)
          # flatten (unlike view) also works on non-contiguous activations,
          # e.g. when the input uses the channels_last memory format.
          x = torch.flatten(x, 1)
          x = self.fc1(x)
          return F.log_softmax(x, dim=1)
  def show_graph(x, string, mean=0.1307, std=0.3081):
      """Show every channel of the first sample in ``x`` as a grayscale figure.

      One figure is opened per channel, titled ``LeNet_<string>_<k>`` with
      k starting at 1.

      Args:
          x: batched tensor. ``x[0]`` is taken as the first sample, and each
              ``x[0][k]`` is drawn as a 2-D image. Presumably the shape is
              (N, C, H, W); confirm at the call site.
          string: label inserted into each figure title.
          mean, std: the values are mapped back with ``value * std + mean``
              before display. The defaults reproduce the previously
              hard-coded constants (presumably MNIST normalization stats).
      """
      # detach() lets this run on tensors inside the autograd graph: numpy()
      # refuses tensors that require grad, and deepcopy refuses non-leaf ones.
      # The affine rescale yields a new array, so the caller's tensor is untouched.
      maps = x[0].detach().cpu().numpy() * std + mean
      for k, fmap in enumerate(maps):
          # Tick every pixel on small maps, every 5 pixels on larger ones.
          step = 1 if len(fmap) < 10 else 5
          # A fresh figure per channel, so every plot gets the tick locators
          # (before, they were set only on the first figure).
          fig, ax = plt.subplots()
          ax.xaxis.set_major_locator(MultipleLocator(step))
          ax.yaxis.set_major_locator(MultipleLocator(step))
          ax.imshow(fmap, cmap='gray')  # channel k, not always channel 0
          ax.set_title("LeNet_{}_{}".format(string, k + 1))
          plt.show()
          plt.close(fig)  # release the figure so repeated calls do not pile up

Reference

https://zhuanlan.zhihu.com/p/38200980

https://blog.csdn.net/CSSDCC/article/details/116461271?spm=1001.2014.3001.5501