torchvision 官方文档：https://pytorch.org/vision/stable/index.html

  1. import torchvision
  2. # 下载数据集
  3. train_dataset = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
  4. test_dataset = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
  5. print(train_dataset[0])
  6. print(train_dataset.classes)
  7. img, target = train_dataset[0]
  8. print(img)
  9. print(target)
  10. print(train_dataset.classes[target])
  11. img.show()

image.png

CIFAR-10数据集

CIFAR-10 数据集主页：https://www.cs.toronto.edu/~kriz/cifar.html
image.png

  1. import torchvision
  2. from torch.utils.tensorboard import SummaryWriter
  3. dataset_transform = torchvision.transforms.Compose([
  4. torchvision.transforms.ToTensor()
  5. ])
  6. # 下载数据集
  7. train_dataset = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
  8. test_dataset = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
  9. # 使用TensorBoard显示
  10. writer = SummaryWriter("log")
  11. for i in range(10):
  12. img, target = train_dataset[i]
  13. writer.add_image("train_dataset", img, i)
  14. writer.close()

image.png