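```python
# Data loading
import torchvision
import torch
from torchvision import transforms
from torch import nn
from datetime import datetime

# ToTensor scales pixels to [0, 1]; Normalize then maps each RGB channel to [-1, 1].
# Named `transform` so it does not shadow the `transforms` module.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Training split (downloaded on first run), reshuffled every epoch.
# On Windows, num_workers > 0 requires running this under `if __name__ == '__main__':`.
trainset = torchvision.datasets.CIFAR10(root='./classic_dataset/cifar-10-batches-py', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=3)

# Test split, stored in separate variables so it does not overwrite the training set.
# download=False works because the archive fetched above contains both splits.
testset = torchvision.datasets.CIFAR10(root='./classic_dataset/cifar-10-batches-py', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=3)

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
```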
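
`transforms.ToTensor()` converts 8-bit pixel values to floats in [0, 1], and `transforms.Normalize(mean, std)` then computes `(x - mean) / std` separately for each channel, so mean = std = 0.5 maps every channel to [-1, 1]. The pure-Python sketch below reproduces that arithmetic for a single RGB pixel; `normalize_rgb` is a hypothetical helper written for illustration, not part of torchvision.

```python
def normalize_rgb(pixel, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Hypothetical helper: ToTensor-style scaling followed by Normalize-style shift and scale."""
    # ToTensor step: 0..255 integers -> 0.0..1.0 floats
    scaled = [c / 255.0 for c in pixel]
    # Normalize step: (x - mean) / std, applied channel by channel
    return tuple((x - m) / s for x, m, s in zip(scaled, mean, std))


print(normalize_rgb((0, 255, 0)))     # (-1.0, 1.0, -1.0)
print(normalize_rgb((128, 64, 200)))  # every value lies in [-1, 1]
```

The same arithmetic explains why images have to be un-normalized with `x * 0.5 + 0.5` before they are displayed.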

Check the number of images in the training set with `len(trainset)`:

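
Calling `len()` on a map-style dataset invokes its `__len__` method. torchvision's `CIFAR10` implements it, so the training split should report 50,000 images (5,000 per class) and the test split 10,000. Because running the real class needs torch and a download, the sketch below shows the same protocol in plain Python. `ToyCIFARLike` is a hypothetical stand-in, not a torchvision class, and its labels are synthetic; its `targets` attribute mirrors the label list that torchvision's `CIFAR10` exposes, which makes a per-class count easy.

```python
from collections import Counter


class ToyCIFARLike:
    """Hypothetical stand-in for a map-style dataset; not torchvision's CIFAR10."""

    def __init__(self, n_per_class, n_classes=10):
        # Synthetic labels, balanced across classes like CIFAR-10
        self.targets = [c for c in range(n_classes) for _ in range(n_per_class)]

    def __len__(self):
        # len(dataset) returns the number of samples
        return len(self.targets)

    def __getitem__(self, idx):
        # A real dataset returns (image_tensor, label); the image is a placeholder here
        return None, self.targets[idx]


ds = ToyCIFARLike(n_per_class=5000)
print(len(ds))              # 50000
print(Counter(ds.targets))  # 5000 samples for each of the 10 labels
```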
Check the total number of batches with `len(trainloader)`:
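
For a map-style dataset, `len()` on a `DataLoader` returns the number of batches per epoch: the dataset size divided by `batch_size`, rounded up, because the last, smaller batch is kept unless `drop_last=True`. With `batch_size=4`, the 50,000-image CIFAR-10 training split therefore gives 12,500 batches and the 10,000-image test split 2,500. The plain-Python sketch below reproduces that count; `num_batches` is a hypothetical helper, not a PyTorch function.

```python
def num_batches(n_samples, batch_size, drop_last=False):
    """Hypothetical helper: batches per epoch for a map-style dataset."""
    if drop_last:
        # The incomplete final batch is discarded
        return n_samples // batch_size
    # Otherwise the final, possibly smaller batch is kept: ceiling division
    return (n_samples + batch_size - 1) // batch_size


print(num_batches(50000, 4))               # 12500 (training split)
print(num_batches(10000, 4))               # 2500  (test split)
print(num_batches(10, 4))                  # 3: batches of 4, 4 and 2
print(num_batches(10, 4, drop_last=True))  # 2
```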