# torchvision documentation: https://pytorch.org/vision/stable/index.html
import torchvision
# Download (download=True) and load both CIFAR-10 splits under ./dataset.
# No transform is given, so indexing a split yields an (image, label) pair
# whose label is an integer index into `classes`.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root="./dataset", train=is_train, download=True)
    for is_train in (True, False)
)

# Inspect the first training sample and the list of class names.
print(train_dataset[0], train_dataset.classes, sep="\n")

# Unpack the first sample and print the image, its label index, and the
# human-readable class name; then open the image in a viewer.
img, target = train_dataset[0]
print(img, target, train_dataset.classes[target], sep="\n")
img.show()
# CIFAR-10数据集 (CIFAR-10 dataset)
# https://www.cs.toronto.edu/~kriz/cifar.html
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Convert each PIL image to a tensor so it can be passed to
# SummaryWriter.add_image below.
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# Download (download=True) and load both CIFAR-10 splits under ./dataset,
# applying the tensor transform to every sample.
train_dataset = torchvision.datasets.CIFAR10(
    root="./dataset", train=True, transform=dataset_transform, download=True
)
test_dataset = torchvision.datasets.CIFAR10(
    root="./dataset", train=False, transform=dataset_transform, download=True
)

# Log the first ten training images to TensorBoard (log dir "log") under the
# "train_dataset" tag, using the sample index as the global step.
# The loop body must be indented to run at all. The context manager closes
# the writer even if logging raises part-way through.
with SummaryWriter("log") as writer:
    for i in range(10):
        img, target = train_dataset[i]
        writer.add_image("train_dataset", img, i)