Dataset：提供一种获取数据及其 label 的方式

    1. 如何获取每一个数据及其 label（对应 `__getitem__`）
    2. 告诉我们总共有多少个数据（对应 `__len__`）

    DataLoader：对 Dataset 中的数据进行打包（如按 batch 分批、打乱顺序），为后面的网络提供不同形式的数据

    import os  # operating-system interface: path joining and directory listing

    from PIL import Image
    from torch.utils.data import Dataset
    class mydata(Dataset):
        """Image dataset whose label is the name of the sub-folder holding the images.

        Every entry listed in ``root_dir/label_dir`` is treated as one sample;
        ``__getitem__`` returns ``(image, label)`` where ``label`` is simply
        ``label_dir``.

        Args:
            root_dir: Directory containing one sub-folder per label,
                e.g. ``'data/dataset/train'``.
            label_dir: Sub-folder to read; also the label of every sample,
                e.g. ``'ants'``.
        """

        def __init__(self, root_dir: str, label_dir: str) -> None:
            self.root_dir = root_dir
            self.label_dir = label_dir
            self.path = os.path.join(self.root_dir, self.label_dir)
            # os.listdir returns entries in arbitrary, filesystem-dependent order;
            # sort so that a given index always maps to the same image.
            self.img_path = sorted(os.listdir(self.path))

        def __getitem__(self, index: int):
            """Return ``(image, label)`` for the ``index``-th entry of ``self.img_path``."""
            img_name = self.img_path[index]
            img_item_path = os.path.join(self.path, img_name)
            img = Image.open(img_item_path)
            # Image.open is lazy and keeps the file open; loading the pixel data
            # now lets Pillow close the file (for single-frame images) instead of
            # holding one open descriptor per sample.
            img.load()
            label = self.label_dir
            return img, label

        def __len__(self) -> int:
            """Return the number of entries found in ``root_dir/label_dir``."""
            return len(self.img_path)
    if __name__ == "__main__":
        # Guarded so that importing this module does not touch the filesystem or
        # open an image viewer; the demo only runs when executed as a script.
        root_dir = 'data/dataset/train'
        ants_label_dir = 'ants'
        bees_label_dir = 'bees'
        ants_dataset = mydata(root_dir, ants_label_dir)
        bees_dataset = mydata(root_dir, bees_label_dir)

        # Adding two Datasets yields a ConcatDataset: the ants samples come first,
        # followed by the bees samples.
        train_dataset = ants_dataset + bees_dataset
        print(len(train_dataset))
        print(len(ants_dataset))
        print(len(bees_dataset))

        # Index 123 must exist, i.e. both folders together need at least 124 entries.
        img, label = train_dataset[123]
        img.show()