# Data loading section
import torchvision
import torch
from torchvision import transforms
from torch import nn
from datetime import datetime
# Per-channel preprocessing: ToTensor() converts the image to a float tensor,
# then Normalize computes (x - 0.5) / 0.5 for each of the 3 channels.
# Named ``transform`` (not ``transforms``) so the ``torchvision.transforms``
# module imported above is not shadowed for later code.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# NOTE(review): torchvision's CIFAR10 stores its files in a
# ``cifar-10-batches-py`` sub-folder of ``root``, so this root nests that
# folder name twice on disk — confirm this is intended.
DATA_ROOT = './classic_dataset/cifar-10-batches-py'
BATCH_SIZE = 4
# NOTE: with num_workers > 0, iterating these loaders at module level on
# spawn-based platforms (e.g. Windows/macOS) requires an
# ``if __name__ == "__main__":`` guard around the iteration.
NUM_WORKERS = 3

# Training split: downloaded if missing, reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True,
                                          num_workers=NUM_WORKERS)

# Test split, kept under separate names so it does not overwrite the
# training split. download=False relies on the train=True call above having
# already fetched the dataset into DATA_ROOT.
testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False,
                                         num_workers=NUM_WORKERS)

# Human-readable names, indexed by the integer CIFAR-10 label (0-9).
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# Check the number of images in the training set:
# Check the total number of batches: