dataloader的使用

dataset是准备数据集

dataloader是将准备好的数据集按照一定方式读取出来,输入到神经网络中。

  1. import torchvision
  2. from torch.utils.data import DataLoader
  3. from torch.utils.tensorboard import SummaryWriter
  4. trans_dataset = torchvision.transforms.Compose([
  5. torchvision.transforms.ToTensor()
  6. ])
  7. #dataset准备好数据集,dataloader将数据集按一定方式读取
  8. test_data = torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=trans_dataset,download=True)
  9. test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,drop_last=False)
  10. # img,target = test_data[0]
  11. # print(img)
  12. # print(img.shape)
  13. writer = SummaryWriter('dataloader')
  14. #每个epoch中dataloader读取一次,训练两个epoch
  15. for epoch in range(2):
  16. step = 0
  17. for data in test_loader:
  18. imgs,targets = data
  19. writer.add_images('epoch:{}'.format(epoch),imgs,step)
  20. step = step+1
  21. writer.close()
  22. print('over')