import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# from torchvision.models import resnet18
class ResBlk(nn.Module):
    """
    resnet block
    """

    def __init__(self, ch_in, ch_out):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        """
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # identity shortcut; replaced by a 1x1 conv projection when the channel count changes
        self.extra = nn.Sequential()
        if ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h, w]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h, w]
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # short cut: project x to [b, ch_out, h, w] if needed, then element-wise add
        out = self.extra(x) + out
        return out
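
# A minimal shape sanity check for ResBlk (an illustrative sketch, not part of the
# original script; 'blk' and 'tmp' are hypothetical names). With ch_in=16 and
# ch_out=32 the 1x1 shortcut projects the input so the element-wise add lines up:
#
#   blk = ResBlk(16, 32)
#   tmp = torch.randn(2, 16, 32, 32)
#   print(blk(tmp).shape)   # expected: torch.Size([2, 32, 32, 32])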
class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16)
        )
        # followed by 2 residual blocks (a trimmed-down ResNet)
        # [b, 16, h, w] => [b, 16, h, w]
        self.blk1 = ResBlk(16, 16)
        # [b, 16, h, w] => [b, 32, h, w]
        self.blk2 = ResBlk(16, 32)
        # # [b, 256, h, w] => [b, 512, h, w]
        # self.blk3 = ResBlk(128, 256)
        # # [b, 512, h, w] => [b, 1024, h, w]
        # self.blk4 = ResBlk(256, 512)

        # 32 channels x 32 x 32 spatial positions => 10 classes
        self.outlayer = nn.Linear(32 * 32 * 32, 10)

    def forward(self, x):
        """
        :param x: [b, 3, 32, 32]
        :return: [b, 10] class logits
        """
        x = F.relu(self.conv1(x))
        # [b, 16, 32, 32] => [b, 32, 32, 32]
        x = self.blk1(x)
        x = self.blk2(x)
        # x = self.blk3(x)
        # x = self.blk4(x)

        # print(x.shape)
        # flatten: [b, 32, 32, 32] => [b, 32*32*32]
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x
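
# A minimal forward-pass check for ResNet18 on a fake CIFAR-10 batch (illustrative
# sketch only; 'net' and 'dummy' are hypothetical names). The flatten feeding
# nn.Linear(32*32*32, 10) assumes 32x32 inputs, since no layer here downsamples:
#
#   net = ResNet18()
#   dummy = torch.randn(4, 3, 32, 32)
#   print(net(dummy).shape)   # expected: torch.Size([4, 10])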
def main():
    batchsz = 32

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # model = Lenet5().to(device)
    model = ResNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)
    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10]
            # label: [b]
            # loss: tensor scalar
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss of the last batch in this epoch
        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)

            acc = total_correct / total_num
            print(epoch, 'acc:', acc)


if __name__ == '__main__':
    main()