lesson5.pdf

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt

from utils import plot_image, plot_curve, one_hot

batch_size = 512
# step1. load dataset
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')
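
# Note (a rough sanity check, not part of the lesson code): Normalize((0.1307,), (0.3081,))
# rescales each pixel as x = (x - mean) / std, where 0.1307 and 0.3081 are the usual
# MNIST mean and std. Assuming ToTensor yields pixels in [0, 1], the printed min/max
# above should be about (0 - 0.1307) / 0.3081 ≈ -0.42 and (1 - 0.1307) / 0.3081 ≈ 2.82.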

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # xw+b
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        # x: [b, 784] (images flattened from [b, 1, 28, 28] by the caller)
        # h1 = relu(xw1+b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2+b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3+b3 (raw output, no activation)
        x = self.fc3(x)
        return x


net = Net()
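
# Quick shape check (a sketch, not in the original lesson): the network maps a
# flattened batch [b, 784] to outputs [b, 10], e.g.
#   print(net(torch.randn(4, 28*28)).shape)  # torch.Size([4, 10])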

# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
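
# Roughly what optim.SGD with momentum=0.9 does per parameter (a sketch of the
# standard momentum update, not the library source):
#   v = 0.9 * v + grad     # velocity: running mix of past gradients
#   w = w - 0.01 * v       # the "w' = w - lr*grad" step below, applied to v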

train_loss = []

for epoch in range(3):

    for batch_idx, (x, y) in enumerate(train_loader):

        # x: [b, 1, 28, 28], y: [b]
        # flatten: [b, 1, 28, 28] => [b, 784]
        x = x.view(x.size(0), 28*28)
        # => [b, 10]
        out = net(x)
        # [b, 10]
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
# we get optimal [w1, b1, w2, b2, w3, b3]

total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10] => pred: [b]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
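
# utils.py: the helper functions (plot_curve, plot_image, one_hot) imported at the
# top of the training script are defined below.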

import torch
from matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()

def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        # undo Normalize((0.1307,), (0.3081,)) so the digits display correctly
        plt.imshow(img[i][0]*0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out
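
# Example (a quick check, not part of the lesson code): for a label batch such as
# torch.tensor([3, 0]), one_hot returns
#   tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
#           [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])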