Dataset：提供一种方式去获取数据及其 label。
需要解决两个问题：1. 如何获取每一个数据及其 label（对应 `__getitem__`）；2. 告诉我们总共有多少个数据（对应 `__len__`）。
Dataloader：从 Dataset 中取数据，为后面的网络提供不同的数据形式（例如按 batch 打包、打乱顺序）。
import os  # operating-system utilities: path joining and directory listing

from PIL import Image
from torch.utils.data import Dataset


class mydata(Dataset):
    """Image dataset whose label is the name of the folder the images live in.

    Every entry in ``os.path.join(root_dir, label_dir)`` is treated as one
    sample, and that sample's label is the ``label_dir`` string itself
    (e.g. ``"ants"`` or ``"bees"``).

    Args:
        root_dir: Directory that contains the label sub-folders.
        label_dir: Name of the sub-folder to load; also used as the label.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so a given index always refers to the same file across runs
        # and machines.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, index):
        """Return ``(image, label)`` for the file at position ``index``.

        Raises:
            IndexError: if ``index`` is outside ``range(len(self))``.
        """
        img_name = self.img_path[index]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        # Image.open is lazy and keeps the file open until the pixel data is
        # read; load() reads it now so Pillow can close the handle (for
        # single-frame formats) instead of leaking one per sample.
        img.load()
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries in the label folder."""
        return len(self.img_path)


if __name__ == "__main__":
    root_dir = 'data/dataset/train'
    ants_label_dir = 'ants'
    bees_label_dir = 'bees'
    ants_dataset = mydata(root_dir, ants_label_dir)
    bees_dataset = mydata(root_dir, bees_label_dir)
    # Dataset.__add__ builds a ConcatDataset: ants samples come first,
    # followed by bees samples.
    train_dataset = ants_dataset + bees_dataset
    print(len(train_dataset))
    print(len(ants_dataset))
    print(len(bees_dataset))
    img, label = train_dataset[123]
    img.show()
