dataloader的使用
dataset是准备数据集
dataloader是将准备好的数据集按照一定方式读取出来,输入到神经网络中。
import torchvisionfrom torch.utils.data import DataLoaderfrom torch.utils.tensorboard import SummaryWritertrans_dataset = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])#dataset准备好数据集,dataloader将数据集按一定方式读取test_data = torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=trans_dataset,download=True)test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,drop_last=False)# img,target = test_data[0]# print(img)# print(img.shape)writer = SummaryWriter('dataloader')#每个epoch中dataloader读取一次,训练两个epochfor epoch in range(2):step = 0for data in test_loader:imgs,targets = datawriter.add_images('epoch:{}'.format(epoch),imgs,step)step = step+1writer.close()print('over')
