两个类

Dataset

  • 提供一种方式去获取数据及其label
  1. 如何获取每一个数据及其label
  2. 告诉我们总共有多少数据
  • 在介绍Dataset时提到了一些有关于神经网络的东西,暂时不了解

    • 在训练时会有大量数据,但是其中有垃圾数据,Dataset的作用
      是从中提取有用数据,并将其打包递给神经网络
  • 示例代码:
  1. from PIL import Image
  2. from torch.utils.data import Dataset
  3. import os #系统类
  4. img_path = "dataset/train/ants/0013035.jpg"
  5. image = Image.open(img_path)
  6. # image.show()
  7. dir_path = "dataset/train/ants"
  8. img_path_list = os.listdir(dir_path)
  9. print(img_path_list[0])
  10. class MyData(Dataset):
  11. def __init__(self,root_dir,label_dir): #类似于构造函数
  12. #self相当于指定类中的一个全局变量,类似于java的this?
  13. self.root_dir = root_dir
  14. self.label_dir = label_dir
  15. self.path = os.path.join(self.root_dir,self.label_dir)
  16. self.img_path = os.listdir(self.path)
  17. #这个构造函数只用来初始化变量,没有其他操作
  18. # self.paht : 根目录 + 下一级目录, 是一个目录名
  19. # os.listdir(path) 返回指定路径下所有文件和文件夹的名字,并存放于一个列表中。
  20. # img_path : 是一个列表,里面保存了 path的所有图片的名字
  21. def __getitem__(self, index):
  22. img_name = self.img_path[index]
  23. img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
  24. #这样写不会造成路径重合吗? label_dir只是一个文件夹名,所以不会造成路径重合
  25. # root_dir : 根目录
  26. # label_dir : 根目录下面一级目录
  27. # img_name : 图片名称
  28. img = Image.open(img_item_path)
  29. label = self.label_dir
  30. return img,label
  31. def __len__(self):
  32. # 查看返回值的长度
  33. return len(self.img_path)
  34. # 将类实例化
  35. root_dir = "dataset/train"
  36. label_dir = "ants"
  37. bee_dir = "bees"
  38. ants_dataset = MyData(root_dir,label_dir)
  39. bees_dataset = MyData(root_dir,bee_dir)
  40. img,label = ants_dataset[0]
  41. img2,label = bees_dataset[0]
  42. #img.show()
  43. #img2.show()
  44. # 将两个数据集合二为一
  45. train_dataset = ants_dataset + bees_dataset
  46. print(len(train_dataset))
  47. print(len(ants_dataset))
  48. print(len(bees_dataset))
  49. img3,label = train_dataset[123]
  50. # img3.show()
  51. img4,label = train_dataset[130]
  52. # img4.show()
  53. # 这样可以用于扩充数据集

Dataloader

  • 为后面的网络提供不同的数据形式
  • loader,即为加载器,可以将数据加载到神经网络
  • 参数设置:
  1. batch_size : 每次加载多少个数据
  2. shuffle :若为True则顺序多次加载数据时数据的顺序不一样,为False时数据一样,默认为False
  3. num_workers : 进程数,一般设置为0,使用主进程,在Windows下设置数值大于0时容易出错,通常为0
  4. drop_last : true时舍弃数据,false时不舍弃
  5. transforms : 要设置为torchvision.transforms.ToTensor(),不然for循环会报错
  • 示例代码:
  1. import torchvision
  2. from torch.utils.data import DataLoader
  3. test_dataset = torchvision.datasets.CIFAR10(root="./download_data",train=False,transform=torchvision.transforms.ToTensor(),download=True)
  4. # 要设置transforms的参数,不然for循环会出错
  5. dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
  6. print(type(dataloader))
  7. for img,label in enumerate(dataloader):
  8. # img,label = data
  9. print(label)
  10. for data in dataloader:
  11. img,label = data
  12. print(label)
  13. '''
  14. img, target = test_dataset[0]
  15. print(type(img))
  16. print(target)
  17. # img.show()
  18. '''

打开图片

示例代码

  1. # 打开单个图片
  2. from PIL import Image
  3. img_paht = "path_str"
  4. img = Image.open(img_path)
  5. img.show()
  1. # 所有图片路径作为一个list
  2. import os
  3. dir_path = "dir_path" #文件夹路径
  4. img_path_list = os.listdir(dir_path)
  5. img_path_list[0] #获取第一个图片的路径
  • 都是基本代码,基本上一看就懂
  1. os.path.join(root_dir,label_dir)
  2. # 将两个字符串以路径的方式拼接,Windows中默认为 \ 这个函数
  3. # 会自动写为 /