宝可梦数据集.pdf

    1. import torch
    2. import os, glob
    3. import random, csv
    4. from torch.utils.data import Dataset, DataLoader
    5. from torchvision import transforms
    6. from PIL import Image
    class Pokemon(Dataset):
        """Folder-per-class image dataset with a cached, shuffled 60/20/20 split.

        ``root`` is expected to contain one sub-directory per class; class labels
        are assigned in sorted directory-name order. On first use the image paths
        (``*.png``, ``*.jpg``, ``*.jpeg``) are shuffled once and cached in
        ``root/images.csv``; later instances reuse that file, so the split is
        stable across runs.

        Args:
            root: dataset root directory.
            resize: side length of the square image tensors returned.
            mode: ``'train'`` (first 60%), ``'val'`` (next 20%); any other value
                selects the remaining 20% (test split).
        """

        # Channel statistics used by Normalize and inverted by denormalize.
        MEAN = [0.485, 0.456, 0.406]
        STD = [0.229, 0.224, 0.225]

        def __init__(self, root, resize, mode):
            super(Pokemon, self).__init__()
            self.root = root
            self.resize = resize
            self.mode = mode

            # Class-directory name -> integer label, e.g. {'bulbasaur': 0, ...}.
            self.name2label = {}
            for name in sorted(os.listdir(root)):
                if not os.path.isdir(os.path.join(root, name)):
                    continue
                self.name2label[name] = len(self.name2label)

            self.images, self.labels = self.load_csv('images.csv')

            n = len(self.images)
            if mode == 'train':      # 0% -> 60%
                start, stop = 0, int(0.6 * n)
            elif mode == 'val':      # 60% -> 80%
                start, stop = int(0.6 * n), int(0.8 * n)
            else:                    # 80% -> 100%
                start, stop = int(0.8 * n), n
            self.images = self.images[start:stop]
            self.labels = self.labels[start:stop]

            # Build the pipeline once instead of on every __getitem__ call.
            # Random rotation is augmentation and is applied to the training split
            # only, so val/test accuracy is deterministic; every split still gets
            # the same resize -> center-crop geometry.
            big = int(resize * 1.25)
            steps = [transforms.Resize((big, big))]
            if mode == 'train':
                steps.append(transforms.RandomRotation(15))
            steps += [
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean=self.MEAN, std=self.STD),
            ]
            self.transform = transforms.Compose(steps)

        def load_csv(self, filename):
            """Return ``(image_paths, labels)`` read from ``root/filename``.

            If the CSV does not exist yet it is created first: every ``*.png``,
            ``*.jpg`` and ``*.jpeg`` file under each class directory is collected,
            shuffled once with :func:`random.shuffle`, and written as
            ``path,label`` rows.
            """
            csv_path = os.path.join(self.root, filename)
            if not os.path.exists(csv_path):
                images = []
                for name in self.name2label:
                    for pattern in ('*.png', '*.jpg', '*.jpeg'):
                        images += glob.glob(os.path.join(self.root, name, pattern))
                print(len(images), images)
                random.shuffle(images)
                with open(csv_path, mode='w', newline='') as f:
                    writer = csv.writer(f)
                    for img in images:
                        # The label is the image's parent directory name.
                        name = os.path.basename(os.path.dirname(img))
                        writer.writerow([img, self.name2label[name]])
                    print('writen into csv file:', filename)

            images, labels = [], []
            # newline='' is what the csv module expects for files it reads.
            with open(csv_path, newline='') as f:
                for img, label in csv.reader(f):
                    images.append(img)
                    labels.append(int(label))
            assert len(images) == len(labels)
            return images, labels

        def __len__(self):
            return len(self.images)

        def denormalize(self, x_hat):
            """Invert Normalize: ``x = x_hat * std + mean``.

            ``mean``/``std`` are reshaped to ``[3, 1, 1]`` so they broadcast over
            the trailing ``[c, h, w]`` dims; works for one image or a batch.
            They are created on ``x_hat``'s device.
            """
            mean = torch.tensor(self.MEAN, device=x_hat.device).unsqueeze(1).unsqueeze(1)
            std = torch.tensor(self.STD, device=x_hat.device).unsqueeze(1).unsqueeze(1)
            return x_hat * std + mean

        @staticmethod
        def _load_rgb(path):
            # Decode as 3-channel RGB; the context manager closes the file handle.
            with Image.open(path) as im:
                return im.convert('RGB')

        def __getitem__(self, idx):
            """Return ``(image_tensor, label_tensor)`` for sample ``idx``."""
            path, label = self.images[idx], self.labels[idx]
            img = self.transform(self._load_rgb(path))
            return img, torch.tensor(label)
    def main():
        """Visual smoke test for the Pokemon dataset.

        Loads the training split of the relative ``pokemon`` folder at 64x64,
        prints one sample's shapes, then pushes the sample and each shuffled
        batch of 32 to Visdom, sleeping 10 seconds between batches.
        """
        import time
        import torchvision
        import visdom

        viz = visdom.Visdom()

        train_set = Pokemon('pokemon', 64, 'train')

        # The Dataset supports __getitem__, so iter() walks it index by index.
        sample_img, sample_label = next(iter(train_set))
        print('sample:', sample_img.shape, sample_label.shape, sample_label)
        viz.image(train_set.denormalize(sample_img), win='sample_x',
                  opts=dict(title='sample_x'))

        batches = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=8)
        for batch_imgs, batch_labels in batches:
            viz.images(train_set.denormalize(batch_imgs), nrow=8, win='batch',
                       opts=dict(title='batch'))
            viz.text(str(batch_labels.numpy()), win='label',
                     opts=dict(title='batch-y'))
            time.sleep(10)


    if __name__ == '__main__':
        main()
    1. import torch
    2. from torch import nn
    3. from torch.nn import functional as F
    class ResBlk(nn.Module):
        """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

        Args:
            ch_in: number of input channels.
            ch_out: number of output channels.
            stride: stride of the first 3x3 conv and of the projection shortcut;
                values > 1 shrink the spatial dimensions.
        """

        def __init__(self, ch_in, ch_out, stride=1):
            super(ResBlk, self).__init__()
            self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
            self.bn1 = nn.BatchNorm2d(ch_out)
            self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(ch_out)

            # Identity shortcut by default. A 1x1 projection is required whenever
            # the main path changes the tensor shape: either the channel count, or
            # the spatial size via stride. Otherwise the residual add would fail.
            self.extra = nn.Sequential()
            if ch_out != ch_in or stride != 1:
                self.extra = nn.Sequential(
                    nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                    nn.BatchNorm2d(ch_out),
                )

        def forward(self, x):
            """Map ``[b, ch_in, h, w]`` to ``[b, ch_out, h', w']`` (h', w' set by stride)."""
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            # Element-wise add of the (possibly projected) shortcut.
            out = self.extra(x) + out
            return F.relu(out)
    class ResNet18(nn.Module):
        """Small residual classifier: a stride-3 stem conv followed by four ResBlks.

        Args:
            num_class: number of output logits.
        """

        def __init__(self, num_class):
            super(ResNet18, self).__init__()
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
                nn.BatchNorm2d(16),
            )
            # Spatial sizes for a 224x224 input:
            # stem 74 -> blk1 25 -> blk2 9 -> blk3 5 -> blk4 3.
            self.blk1 = ResBlk(16, 32, stride=3)    # [b, 16, 74, 74] => [b, 32, 25, 25]
            self.blk2 = ResBlk(32, 64, stride=3)    # => [b, 64, 9, 9]
            self.blk3 = ResBlk(64, 128, stride=2)   # => [b, 128, 5, 5]
            self.blk4 = ResBlk(128, 256, stride=2)  # => [b, 256, 3, 3]
            self.outlayer = nn.Linear(256 * 3 * 3, num_class)

        def forward(self, x):
            """Map ``[b, 3, h, w]`` images to ``[b, num_class]`` logits."""
            x = F.relu(self.conv1(x))
            x = self.blk1(x)
            x = self.blk2(x)
            x = self.blk3(x)
            x = self.blk4(x)
            # For 224x224 inputs blk4 already yields 3x3, and this pool is an exact
            # identity there. For other input sizes it resamples to the 3x3 grid
            # that outlayer was sized for, instead of failing on a shape mismatch.
            x = F.adaptive_avg_pool2d(x, (3, 3))
            x = x.view(x.size(0), -1)
            return self.outlayer(x)
    def main():
        """Print output shapes of one ResBlk and of ResNet18, plus the parameter count."""
        block = ResBlk(64, 128)
        block_out = block(torch.randn(2, 64, 224, 224))
        print('block:', block_out.shape)

        net = ResNet18(5)
        net_out = net(torch.randn(2, 3, 224, 224))
        print('resnet:', net_out.shape)

        n_params = sum(param.numel() for param in net.parameters())
        print('parameters size:', n_params)


    if __name__ == '__main__':
        main()
    1. import torch
    2. from torch import optim, nn
    3. import visdom
    4. import torchvision
    5. from torch.utils.data import DataLoader
    6. from pokemon import Pokemon
    7. from resnet import ResNet18
    batchsz = 32
    lr = 1e-3
    epochs = 10

    # Fall back to the CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(1234)

    # The three splits of the same 224x224 Pokemon dataset.
    train_db = Pokemon('pokemon', 224, mode='train')
    val_db = Pokemon('pokemon', 224, mode='val')
    test_db = Pokemon('pokemon', 224, mode='test')
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

    viz = visdom.Visdom()
    def evalute(model, loader):
        """Return the classification accuracy of ``model`` over ``loader``.

        Switches the model to eval mode (and leaves it there). It runs every batch
        without tracking gradients and compares the ``argmax`` of the logits with
        the labels.

        Args:
            model: module mapping a batch ``x`` to logits of shape ``[b, classes]``.
            loader: DataLoader yielding ``(x, y)`` pairs with integer labels ``y``.

        Returns:
            Fraction of correct predictions. Returns ``0.0`` for an empty dataset
            instead of raising ZeroDivisionError.
        """
        model.eval()
        total = len(loader.dataset)
        if total == 0:
            return 0.0

        # Evaluate on whatever device the model's parameters live on.
        param = next(model.parameters(), None)
        dev = param.device if param is not None else torch.device('cpu')

        correct = 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(dev), y.to(dev)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
        return correct / total
    def main():
        """Train ResNet18 from scratch, keep the best-validation weights, then test.

        The per-step training loss and per-epoch validation accuracy are plotted
        in Visdom. The weights with the highest validation accuracy are saved to
        ``best.mdl`` and reloaded for the final test-set evaluation.
        """
        model = ResNet18(5).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criteon = nn.CrossEntropyLoss()

        # Start below any reachable accuracy so the first epoch always writes
        # best.mdl. With 0, a run stuck at 0% val accuracy never saves a
        # checkpoint, and torch.load below would fail.
        best_acc, best_epoch = -1.0, 0
        global_step = 0
        viz.line([0], [-1], win='loss', opts=dict(title='loss'))
        viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

        for epoch in range(epochs):
            # evalute() leaves the model in eval mode; restore training mode
            # once at the start of every epoch.
            model.train()
            for x, y in train_loader:
                # x: [b, 3, 224, 224], y: [b]
                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss = criteon(logits, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                viz.line([loss.item()], [global_step], win='loss', update='append')
                global_step += 1

            # Validate after every epoch and checkpoint on improvement.
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

        print('best acc:', best_acc, 'best epoch:', best_epoch)

        # map_location lets a checkpoint saved on one device load on another.
        model.load_state_dict(torch.load('best.mdl', map_location=device))
        print('loaded from ckpt!')

        test_acc = evalute(model, test_loader)
        print('test acc:', test_acc)


    if __name__ == '__main__':
        main()
    1. import torch
    2. from torch import optim, nn
    3. import visdom
    4. import torchvision
    5. from torch.utils.data import DataLoader
    6. from pokemon import Pokemon
    7. # from resnet import ResNet18
    8. from torchvision.models import resnet18
    9. from utils import Flatten
    batchsz = 32
    lr = 1e-3
    epochs = 10

    # Fall back to the CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(1234)

    # The three splits of the same 224x224 Pokemon dataset.
    train_db = Pokemon('pokemon', 224, mode='train')
    val_db = Pokemon('pokemon', 224, mode='val')
    test_db = Pokemon('pokemon', 224, mode='test')
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

    viz = visdom.Visdom()
    def evalute(model, loader):
        """Return the classification accuracy of ``model`` over ``loader``.

        Switches the model to eval mode (and leaves it there). It runs every batch
        without tracking gradients and compares the ``argmax`` of the logits with
        the labels.

        Args:
            model: module mapping a batch ``x`` to logits of shape ``[b, classes]``.
            loader: DataLoader yielding ``(x, y)`` pairs with integer labels ``y``.

        Returns:
            Fraction of correct predictions. Returns ``0.0`` for an empty dataset
            instead of raising ZeroDivisionError.
        """
        model.eval()
        total = len(loader.dataset)
        if total == 0:
            return 0.0

        # Evaluate on whatever device the model's parameters live on.
        param = next(model.parameters(), None)
        dev = param.device if param is not None else torch.device('cpu')

        correct = 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(dev), y.to(dev)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
        return correct / total
    def main():
        """Fine-tune a pretrained torchvision resnet18 on the 5 Pokemon classes.

        Reuses every child of the pretrained network except the last, then adds
        Flatten and a new 5-way Linear head. Training, best-validation
        checkpointing to ``best.mdl`` and the final test evaluation work the same
        way as in the from-scratch script.
        """
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; confirm against the installed version.
        trained_model = resnet18(pretrained=True)
        model = nn.Sequential(
            *list(trained_model.children())[:-1],  # [b, 512, 1, 1]
            Flatten(),                             # [b, 512, 1, 1] => [b, 512]
            nn.Linear(512, 5),
        ).to(device)

        optimizer = optim.Adam(model.parameters(), lr=lr)
        criteon = nn.CrossEntropyLoss()

        # Start below any reachable accuracy so the first epoch always writes
        # best.mdl. With 0, a run stuck at 0% val accuracy never saves a
        # checkpoint, and torch.load below would fail.
        best_acc, best_epoch = -1.0, 0
        global_step = 0
        viz.line([0], [-1], win='loss', opts=dict(title='loss'))
        viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

        for epoch in range(epochs):
            # evalute() leaves the model in eval mode; restore training mode
            # once at the start of every epoch.
            model.train()
            for x, y in train_loader:
                # x: [b, 3, 224, 224], y: [b]
                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss = criteon(logits, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                viz.line([loss.item()], [global_step], win='loss', update='append')
                global_step += 1

            # Validate after every epoch and checkpoint on improvement.
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

        print('best acc:', best_acc, 'best epoch:', best_epoch)

        # map_location lets a checkpoint saved on one device load on another.
        model.load_state_dict(torch.load('best.mdl', map_location=device))
        print('loaded from ckpt!')

        test_acc = evalute(model, test_loader)
        print('test acc:', test_acc)


    if __name__ == '__main__':
        main()
    1. from matplotlib import pyplot as plt
    2. import torch
    3. from torch import nn
    class Flatten(nn.Module):
        """Collapse every dimension after the first: ``[b, d1, d2, ...] -> [b, d1*d2*...]``.

        A 1-D input of shape ``[n]`` becomes ``[n, 1]``.
        """

        def __init__(self):
            super(Flatten, self).__init__()

        def forward(self, x):
            # torch.Size.numel() multiplies the trailing dims without allocating a
            # tensor. reshape (unlike view) also accepts non-contiguous inputs
            # such as transposed or permuted tensors.
            return x.reshape(-1, x.shape[1:].numel())
    def plot_image(img, label, name):
        """Show up to six images in a 2x3 grid, titled ``"<name>: <label>"``.

        Only channel 0 of each image is drawn, in grayscale. Before drawing,
        ``x * 0.3081 + 0.1307`` is applied to undo a normalisation with those
        std/mean values.

        Args:
            img: indexable batch; ``img[i][0]`` is the 2-D plane that is drawn.
            label: indexable batch of scalar tensors (read with ``.item()``).
            name: prefix used in every subplot title.
        """
        plt.figure()
        # Draw at most 6 panels, but never index past the end of a smaller batch.
        for i in range(min(6, len(img))):
            plt.subplot(2, 3, i + 1)
            plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
            plt.title("{}: {}".format(name, label[i].item()))
            plt.xticks([])
            plt.yticks([])
        # Compute the layout once, after every subplot and title exists.
        plt.tight_layout()
        plt.show()