# resnet.py
import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """
    ResNet basic block.
    """
    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: input channel count
        :param ch_out: output channel count
        :param stride: stride of the first conv (downsamples when > 1)
        """
        super(ResBlk, self).__init__()

        # we add stride support for the ResBlk, which is distinct from tutorials.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            # The shortcut must also fire when only the stride changes (e.g.
            # ResBlk(512, 512, stride=2) below), otherwise the element-wise
            # add in forward() fails with a spatial shape mismatch.
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out)
            )
    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h/stride, w/stride]
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # shortcut.
        # extra module: [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
        # element-wise add:
        out = self.extra(x) + out
        out = F.relu(out)

        return out
class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(64)
        )
        # followed by 4 blocks
        # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, h, w] => [b, 512, h/2, w/2]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, h, w] => [b, 512, h/2, w/2]
        self.blk4 = ResBlk(512, 512, stride=2)

        self.outlayer = nn.Linear(512 * 1 * 1, 10)
    def forward(self, x):
        """
        :param x: [b, 3, h, w]
        :return: [b, 10] logits
        """
        x = F.relu(self.conv1(x))

        # [b, 64, h', w'] => [b, 512, h'', w'']
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # print('after conv:', x.shape)  # [b, 512, 1, 1] for 32x32 input
        # [b, 512, h, w] => [b, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # print('after pool:', x.shape)
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x
def main():
    blk = ResBlk(64, 128, stride=4)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('block:', out.shape)

    x = torch.randn(2, 3, 32, 32)
    model = ResNet18()
    out = model(x)
    print('resnet:', out.shape)


if __name__ == '__main__':
    main()
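
# --- quick check (not part of the original files) ---
# A minimal sketch verifying the shortcut condition above: with ch_in == ch_out
# but stride > 1, the element-wise add only works because ResBlk's `extra`
# branch also downsamples. Assumes the file is saved as resnet.py, as the
# training script's `from resnet import ResNet18` suggests.
import torch
from resnet import ResBlk

blk = ResBlk(512, 512, stride=2)  # same channel count, spatial downsampling
x = torch.randn(2, 512, 4, 4)
print(blk(x).shape)  # expected: torch.Size([2, 512, 2, 2])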
# train script (imports lenet5.py and resnet.py)
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim

from lenet5 import Lenet5
from resnet import ResNet18
def main():
    batchsz = 128

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # fall back to CPU if CUDA is unavailable
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # model = Lenet5().to(device)
    model = ResNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)
    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10]
            # label: [b]
            # loss: scalar tensor
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())
        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)

if __name__ == '__main__':
    main()
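
# --- normalization note (not part of the original files) ---
# The Normalize() values in the train script are the ImageNet channel
# statistics. A minimal sketch, assuming you prefer CIFAR-10's own per-channel
# mean/std, for computing them from the training set:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

ds = datasets.CIFAR10('cifar', True, transform=transforms.ToTensor(), download=True)
loader = DataLoader(ds, batch_size=1024)

count, s, sq = 0, torch.zeros(3), torch.zeros(3)
for x, _ in loader:
    count += x.size(0) * x.size(2) * x.size(3)  # pixels per channel
    s += x.sum(dim=[0, 2, 3])
    sq += (x ** 2).sum(dim=[0, 2, 3])

mean = s / count
std = (sq / count - mean ** 2).sqrt()
print('mean:', mean, 'std:', std)  # roughly (0.491, 0.482, 0.447) / (0.247, 0.243, 0.262)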
# lenet5.py
import torch
from torch import nn
from torch.nn import functional as F


class Lenet5(nn.Module):
    """
    LeNet-5 style network for the CIFAR-10 dataset.
    """
    def __init__(self):
        super(Lenet5, self).__init__()

        self.conv_unit = nn.Sequential(
            # x: [b, 3, 32, 32] => [b, 16, 28, 28]
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0),
            # => [b, 16, 14, 14]
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            # => [b, 32, 10, 10]
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0),
            # => [b, 32, 5, 5]
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        # flatten
        # fc unit
        self.fc_unit = nn.Sequential(
            nn.Linear(32 * 5 * 5, 32),
            nn.ReLU(),
            # nn.Linear(120, 84),
            # nn.ReLU(),
            nn.Linear(32, 10)
        )

        # sanity check: [b, 3, 32, 32] => [b, 32, 5, 5]
        tmp = torch.randn(2, 3, 32, 32)
        out = self.conv_unit(tmp)
        print('conv out:', out.shape)

        # # use Cross Entropy Loss
        # self.criteon = nn.CrossEntropyLoss()
    def forward(self, x):
        """
        :param x: [b, 3, 32, 32]
        :return: [b, 10] logits
        """
        batchsz = x.size(0)
        # [b, 3, 32, 32] => [b, 32, 5, 5]
        x = self.conv_unit(x)
        # [b, 32, 5, 5] => [b, 32*5*5]
        x = x.view(batchsz, 32 * 5 * 5)
        # [b, 32*5*5] => [b, 10]
        logits = self.fc_unit(x)

        # # [b, 10]
        # pred = F.softmax(logits, dim=1)
        # loss = self.criteon(logits, y)

        return logits
def main():
    net = Lenet5()
    tmp = torch.randn(2, 3, 32, 32)
    out = net(tmp)
    print('lenet out:', out.shape)


if __name__ == '__main__':
    main()
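
# --- shape trace (not part of the original files) ---
# Why fc_unit starts at 32*5*5: with no padding, each 5x5 conv shrinks the map
# by 4 and each 2x2 max-pool halves it:
#   [b, 3, 32, 32] -> conv5 -> [b, 16, 28, 28] -> pool -> [b, 16, 14, 14]
#                  -> conv5 -> [b, 32, 10, 10] -> pool -> [b, 32, 5, 5]
# A small sketch that derives the flatten size from a dummy pass instead of
# hard-coding it (assumes lenet5.py is importable, as the train script does):
import torch
from lenet5 import Lenet5

net = Lenet5()
with torch.no_grad():
    feat = net.conv_unit(torch.randn(1, 3, 32, 32))
print(feat.shape, feat.numel())  # torch.Size([1, 32, 5, 5]) 800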