






Image Classification using Transfer Learning in PyTorch

Recently PyTorch has gained a lot of popularity because it is simple to use and learn. Andrej Karpathy, Senior Director of AI at Tesla, said the following in his tweet.

Putting jokes aside, PyTorch is also very transparent and can help researchers and data scientists achieve high productivity.

This blog is part of the following series:

  • PyTorch for Beginners
  • PyTorch for Beginners: Basics
  • PyTorch for Beginners: Image Classification using Pre-trained models
  • Image Classification using Transfer Learning in PyTorch
  • PyTorch Model Inference using ONNX and Caffe2
  • PyTorch for Beginners: Semantic Segmentation using torchvision
  • Object Detection
  • Human Pose Keypoint Detection

In this post, we describe how to do image classification in PyTorch.

We will use a subset of the CalTech256 dataset to classify images of 10 different kinds of animals. We will go over the dataset preparation, data augmentation and then the steps to build the classifier. We use transfer learning to reuse the low-level image features, such as edges and textures, learnt by a pretrained model, ResNet50, and then train our classifier to learn the higher-level details in our dataset images, such as eyes and legs. ResNet50 has already been trained on ImageNet with millions of images.

We share a Python notebook with the complete code and show the important snippets in this post so that the reader can understand how it works.

While we have tried to make the blog post self-sufficient, we still encourage readers to familiarize themselves with the basics of PyTorch before proceeding further.

Dataset Preparation

The CalTech256 dataset has 30,607 images categorized into 256 different labeled classes along with another 'clutter' class.

Training on the whole dataset will take hours, so we will work on a subset of the dataset containing 10 animals: bear, chimp, giraffe, gorilla, llama, ostrich, porcupine, skunk, triceratops and zebra. That way we can experiment faster. The same code can then be used to train on the whole dataset too.

The number of images in these folders varies from 81 (for skunk) to 212 (for gorilla). We use the first 60 images in each of these categories for training, the next 10 images for validation and the rest for testing in our experiments below.

So finally we have 600 training images, 100 validation images, 409 test images and 10 classes of animals.

If you want to replicate the experiments, please follow the steps below:

  1. Download the CalTech256 dataset.
  2. Create three directories named train, valid and test.
  3. Create 10 sub-directories inside each of the train, valid and test directories. The sub-directories should be named bear, chimp, giraffe, gorilla, llama, ostrich, porcupine, skunk, triceratops and zebra.
  4. Move the first 60 images for bear in the Caltech256 dataset to the directory train/bear, and repeat this for every animal.
  5. Move the next 10 images for bear in the Caltech256 dataset to the directory valid/bear, and repeat this for every animal.
  6. Copy the remaining images for bear (i.e. the ones not included in the train or valid folders) to the directory test/bear. Repeat this for every animal.
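
The manual steps above could be scripted roughly as follows. This is only a sketch under stated assumptions: `split_category` is a hypothetical helper (not part of the original notebook), "first 60 images" is interpreted as the first 60 file names in sorted order, and every file is copied rather than moved so the downloaded dataset stays untouched. The demo at the bottom runs on a temporary folder with 81 empty placeholder files (the skunk count) instead of the real dataset.

```python
import shutil
import tempfile
from pathlib import Path


def split_category(src_dir, out_root, name, n_train=60, n_valid=10):
    """Copy one category's images into <out_root>/{train,valid,test}/<name>."""
    # Assumption: "first" images means first in sorted file-name order
    images = sorted(p for p in Path(src_dir).iterdir() if p.is_file())
    splits = {
        "train": images[:n_train],
        "valid": images[n_train:n_train + n_valid],
        "test": images[n_train + n_valid:],
    }
    counts = {}
    for split, files in splits.items():
        target = Path(out_root) / split / name
        target.mkdir(parents=True, exist_ok=True)
        for image_path in files:
            shutil.copy2(image_path, target / image_path.name)
        counts[split] = len(files)
    return counts


# Demo on placeholder files; replace the paths with your own dataset folders
with tempfile.TemporaryDirectory() as tmp:
    source = Path(tmp) / "source" / "skunk"
    source.mkdir(parents=True)
    for i in range(81):
        (source / "img_{:04d}.jpg".format(i)).touch()
    demo_counts = split_category(source, Path(tmp) / "dataset", "skunk")
    print(demo_counts)
```

In practice one would call `split_category` once per animal, passing the path of that animal's folder inside the downloaded dataset.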

Data Augmentation

The images in the available training set can be modified in a number of ways to incorporate more variation into the training process, so that the trained model generalizes better and performs well on different kinds of test data. The input data can also come in a variety of sizes, so the images need to be normalized to a fixed size and format before batches of data are used together for training.

Each input image is first passed through a number of transformations. We introduce variation by adding some randomness to the transformations. In each epoch, a single set of transformations is applied to each image. When we train for multiple epochs, the model gets to see more variations of the input images, with a new randomized variation of the transformations in each epoch. This results in data augmentation, and the model then tries to generalize more.
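
To make the "new random variation in every epoch" idea concrete, here is a small standard-library sketch that draws one set of augmentation parameters per epoch for a single image. It does not touch any image and is not torchvision code. `sample_augmentation` is a hypothetical helper, and the ranges are the ones used by the transforms in this post (crop area of 0.8 to 1.0, aspect ratio between 3/4 and 4/3, rotation within 15 degrees, flip probability of 0.5).

```python
import math
import random


def sample_augmentation(rng):
    """Draw one random set of augmentation parameters for a single image."""
    # Fraction of the original image area kept by the random crop
    area_fraction = rng.uniform(0.8, 1.0)
    # Aspect ratio drawn log-uniformly between 3/4 and 4/3
    aspect_ratio = math.exp(rng.uniform(math.log(3 / 4), math.log(4 / 3)))
    # Rotation angle in degrees
    angle = rng.uniform(-15.0, 15.0)
    # Horizontal flip with probability 0.5
    flip = rng.random() < 0.5
    return {
        "area_fraction": area_fraction,
        "aspect_ratio": aspect_ratio,
        "angle": angle,
        "flip": flip,
    }


rng = random.Random(0)  # seeded only to make the demo repeatable
for epoch in range(3):
    # The same image receives a different random variation in every epoch
    print(epoch + 1, sample_augmentation(rng))
```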

Below we see examples of the transformed versions of a Triceratops image.

[Figure: Data Augmentation. Transformed versions of a Triceratops image]

Let us go over the transformations we used for our data augmentation.

  • RandomResizedCrop crops the input image to a random size (within a scale range of 0.8 to 1.0 of the original area, with a random aspect ratio in the default range of 0.75 to 1.33). The crop is then resized to 256×256.
  • RandomRotation rotates the image by an angle randomly chosen between -15 and 15 degrees.
  • RandomHorizontalFlip randomly flips the image horizontally with a default probability of 50%.
  • CenterCrop crops a 224×224 image from the center.
  • ToTensor converts the PIL Image, which has values in the range 0-255, to a floating point Tensor and scales the values to the range 0-1 by dividing them by 255.
  • Normalize takes in a 3-channel Tensor and normalizes each channel by the input mean and standard deviation for that channel. The mean and standard deviation are given as 3-element vectors. Each channel in the tensor is normalized as T = (T - mean) / (standard deviation).

All the above transformations are chained together using Compose.

    # Applying Transforms to the Data
    from torchvision import transforms

    image_transforms = {
        # Training: random crop, rotation and flip provide the augmentation
        'train': transforms.Compose([
            transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ]),
        # Validation: deterministic resize and center crop only
        'valid': transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ]),
        # Test: same deterministic pipeline as validation
        'test': transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])
    }

Note that for the validation and test data, we do not apply the RandomResizedCrop, RandomRotation and RandomHorizontalFlip transformations. Instead, we just resize the images so that their shorter side is 256 pixels and crop out the center 224×224 in order to be able to use them with the pretrained model. The image is then converted to a tensor and normalized by the mean and standard deviation of all images in ImageNet.
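
As a worked example of the last two steps (ToTensor followed by Normalize), the plain-Python sketch below applies the same arithmetic to a single RGB pixel. The pixel value is made up for illustration, and the helper names `to_tensor_value` and `normalize` are hypothetical. Only the mean and standard deviation values come from the transforms above.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def to_tensor_value(pixel):
    """Scale an 8-bit pixel value (0-255) to the range 0-1."""
    return pixel / 255.0


def normalize(value, mean, std):
    """Apply T = (T - mean) / std to one channel value."""
    return (value - mean) / std


rgb_pixel = (124, 116, 104)  # hypothetical pixel, close to the ImageNet mean color
normalized = [
    normalize(to_tensor_value(p), m, s)
    for p, m, s in zip(rgb_pixel, IMAGENET_MEAN, IMAGENET_STD)
]
print([round(v, 4) for v in normalized])
```

A pixel near the dataset mean maps to values near zero in every channel, while pure white (255) in the red channel maps to (1 - 0.485) / 0.229, a little above 2.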

Data Loading

Next, let us see how to use the transformations defined above and load the data to be used for training.

    # Load the Data
    from torch.utils.data import DataLoader
    from torchvision import datasets

    # Set train, valid and test directory paths
    train_directory = 'train'
    valid_directory = 'valid'
    test_directory = 'test'

    # Batch size
    bs = 32

    # Number of classes
    num_classes = 10

    # Load Data from folders
    data = {
        'train': datasets.ImageFolder(root=train_directory, transform=image_transforms['train']),
        'valid': datasets.ImageFolder(root=valid_directory, transform=image_transforms['valid']),
        'test': datasets.ImageFolder(root=test_directory, transform=image_transforms['test'])
    }

    # Size of Data, to be used for calculating Average Loss and Accuracy
    train_data_size = len(data['train'])
    valid_data_size = len(data['valid'])
    test_data_size = len(data['test'])

    # Create iterators for the Data loaded using DataLoader module
    train_data_loader = DataLoader(data['train'], batch_size=bs, shuffle=True)
    valid_data_loader = DataLoader(data['valid'], batch_size=bs, shuffle=True)
    test_data_loader = DataLoader(data['test'], batch_size=bs, shuffle=True)

    # Print the train, validation and test set data sizes
    print(train_data_size, valid_data_size, test_data_size)




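    # Load pretrained ResNet50 Model
    from torchvision import models

    # Note: newer torchvision releases replace pretrained=True with a weights= argument
    resnet50 = models.resnet50(pretrained=True)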

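Canziani et al. list many pretrained models used for various practical applications, analyzing the accuracy obtained and the inference time needed for each model. ResNet50 is one of those with a good tradeoff between accuracy and inference time. When a model is loaded in PyTorch, all its parameters have their 'requires_grad' field set to true by default. This means that every operation on the parameters is recorded in the computation graph used for backpropagation during training, which increases memory requirements. Since most of the parameters in our pretrained model are already trained, we reset the requires_grad field to false.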

    # Freeze model parameters
    for param in resnet50.parameters():
        param.requires_grad = False


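    # Convert model to be used on GPU (falls back to the CPU if no GPU is available)
    import torch

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    resnet50 = resnet50.to(device)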


    # Define Optimizer and Loss Function
    import torch.nn as nn
    import torch.optim as optim

    # NLLLoss expects the model to output log-probabilities
    loss_func = nn.NLLLoss()
    optimizer = optim.Adam(resnet50.parameters())
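
Because NLLLoss works on log-probabilities, it may help to see the arithmetic for a single sample. The sketch below is plain Python rather than PyTorch, the scores are made up, and the helper names `log_softmax` and `nll_loss` are hypothetical. It shows how raw class scores become log-probabilities, how the loss is read off for the correct class, and why taking the exponential of the output recovers ordinary probabilities.

```python
import math


def log_softmax(scores):
    """Convert raw class scores into log-probabilities."""
    highest = max(scores)  # subtracting the maximum keeps exp() numerically stable
    log_sum = highest + math.log(sum(math.exp(s - highest) for s in scores))
    return [s - log_sum for s in scores]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the correct class for one sample."""
    return -log_probs[target]


scores = [2.0, 0.5, -1.0]  # hypothetical raw scores for 3 classes
log_probs = log_softmax(scores)
probs = [math.exp(lp) for lp in log_probs]  # exp() turns log-probabilities into probabilities
loss = nll_loss(log_probs, target=0)
print([round(p, 3) for p in probs], round(loss, 3))
```

The loss is small when the correct class has a probability close to 1 and grows as that probability shrinks.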



    import time
    import torch

    # Names used by this loop, bound to the objects defined earlier in the post
    model = resnet50
    loss_criterion = loss_func
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    epochs = 10   # assumed value for illustration; not specified in this section
    history = []

    for epoch in range(epochs):
        epoch_start = time.time()
        print("Epoch: {}/{}".format(epoch+1, epochs))

        # Set to training mode
        model.train()

        # Loss and Accuracy within the epoch
        train_loss = 0.0
        train_acc = 0.0
        valid_loss = 0.0
        valid_acc = 0.0

        for i, (inputs, labels) in enumerate(train_data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Clean existing gradients
            optimizer.zero_grad()

            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)

            # Compute loss
            loss = loss_criterion(outputs, labels)

            # Backpropagate the gradients
            loss.backward()

            # Update the parameters
            optimizer.step()

            # Compute the total loss for the batch and add it to train_loss
            train_loss += loss.item() * inputs.size(0)

            # Compute the accuracy
            ret, predictions = torch.max(outputs.data, 1)
            correct_counts = predictions.eq(labels.data.view_as(predictions))

            # Convert correct_counts to float and then compute the mean
            acc = torch.mean(correct_counts.type(torch.FloatTensor))

            # Compute total accuracy in the whole batch and add to train_acc
            train_acc += acc.item() * inputs.size(0)

            print("Batch number: {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}".format(i, loss.item(), acc.item()))

        # Validation - No gradient tracking needed
        with torch.no_grad():
            # Set to evaluation mode
            model.eval()

            # Validation loop
            for j, (inputs, labels) in enumerate(valid_data_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass - compute outputs on input data using the model
                outputs = model(inputs)

                # Compute loss
                loss = loss_criterion(outputs, labels)

                # Compute the total loss for the batch and add it to valid_loss
                valid_loss += loss.item() * inputs.size(0)

                # Calculate validation accuracy
                ret, predictions = torch.max(outputs.data, 1)
                correct_counts = predictions.eq(labels.data.view_as(predictions))

                # Convert correct_counts to float and then compute the mean
                acc = torch.mean(correct_counts.type(torch.FloatTensor))

                # Compute total accuracy in the whole batch and add to valid_acc
                valid_acc += acc.item() * inputs.size(0)

                print("Validation Batch number: {:03d}, Validation: Loss: {:.4f}, Accuracy: {:.4f}".format(j, loss.item(), acc.item()))

        # Find average training loss and training accuracy
        avg_train_loss = train_loss/train_data_size
        avg_train_acc = train_acc/float(train_data_size)

        # Find average validation loss and validation accuracy
        avg_valid_loss = valid_loss/valid_data_size
        avg_valid_acc = valid_acc/float(valid_data_size)

        history.append([avg_train_loss, avg_valid_loss, avg_train_acc, avg_valid_acc])

        epoch_end = time.time()

        print("Epoch : {:03d}, Training: Loss: {:.4f}, Accuracy: {:.4f}%, \n\t\tValidation : Loss : {:.4f}, Accuracy: {:.4f}%, Time: {:.4f}s".format(epoch+1, avg_train_loss, avg_train_acc*100, avg_valid_loss, avg_valid_acc*100, epoch_end-epoch_start))
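
The loop multiplies each batch's mean loss and accuracy by `inputs.size(0)` and divides the running total by the dataset size at the end. The plain-Python sketch below, with made-up loss values and the hypothetical helpers `batch_sizes` and `epoch_average`, illustrates why: with 600 training images and a batch size of 32 the last batch holds only 24 images, so simply averaging the per-batch means would give that smaller batch too much weight.

```python
def batch_sizes(dataset_size, batch_size):
    """Sizes of the batches produced for one pass over the dataset."""
    full, rest = divmod(dataset_size, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


def epoch_average(batch_means, sizes):
    """Per-sample average: weight every batch mean by its batch size."""
    total = sum(mean * size for mean, size in zip(batch_means, sizes))
    return total / sum(sizes)


sizes = batch_sizes(600, 32)  # 18 batches of 32 and a final batch of 24
# Hypothetical per-batch mean losses: 1.0 everywhere except 2.0 on the last batch
batch_means = [1.0] * (len(sizes) - 1) + [2.0]

weighted = epoch_average(batch_means, sizes)  # what the training loop computes
naive = sum(batch_means) / len(batch_means)   # ignores the smaller last batch
print(len(sizes), sizes[-1], weighted, naive)
```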




    import matplotlib.pyplot as plt
    import torch
    from PIL import Image


    def predict(model, test_image_name):
        transform = image_transforms['test']

        # Convert to RGB so that grayscale images also yield 3 channels
        test_image = Image.open(test_image_name).convert('RGB')
        plt.imshow(test_image)

        # Apply the test transforms and add a batch dimension
        test_image_tensor = transform(test_image)
        if torch.cuda.is_available():
            test_image_tensor = test_image_tensor.view(1, 3, 224, 224).cuda()
        else:
            test_image_tensor = test_image_tensor.view(1, 3, 224, 224)

        with torch.no_grad():
            model.eval()
            # Model outputs log probabilities
            out = model(test_image_tensor)
            ps = torch.exp(out)
            topk, topclass = ps.topk(1, dim=1)
            # idx_to_class maps a predicted class index back to its class name
            print("Output class : ", idx_to_class[topclass.cpu().numpy()[0][0]])
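
The `idx_to_class` lookup used by `predict` is not defined in the snippet above. With the `data` dictionary from this post it could presumably be derived by inverting `data['train'].class_to_idx`. The standard-library sketch below reproduces the same mapping without PyTorch, assuming ImageFolder's behavior of numbering classes by sorted sub-directory name. `build_class_maps` is a hypothetical helper, and the demo uses three empty temporary folders.

```python
import tempfile
from pathlib import Path


def build_class_maps(root):
    """Number the classes by sorted sub-directory name and invert the mapping."""
    classes = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    class_to_idx = {name: index for index, name in enumerate(classes)}
    idx_to_class = {index: name for name, index in class_to_idx.items()}
    return class_to_idx, idx_to_class


# Demo with three empty class folders created in a temporary directory
with tempfile.TemporaryDirectory() as root:
    for name in ["zebra", "bear", "llama"]:
        (Path(root) / name).mkdir()
    class_to_idx, idx_to_class = build_class_maps(root)

print(idx_to_class)
```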

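We have shown classification results on a small dataset. In future posts, we will apply the same transfer learning approach to harder datasets to solve harder real-life problems. Stay tuned!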


I would like to thank our intern Kushashwa Ravi Shrimali for writing part of the code for this post.


  • Griffin, Gregory and Holub, Alex and Perona, Pietro (2007). Caltech-256 Object Category Dataset.

  • Images used in the blog post: [1][2][3][4][5][6][7][8][9][10][11]


