两个类
Dataset
- 提供一种方式去获取数据及其label
- 如何获取每一个数据及其label
- 告诉我们总共有多少数据
在介绍Dataset时提到了一些有关于神经网络的东西,暂时不了解
- 在训练时会有大量数据,但是其中有垃圾数据,Dataset的作用
是从中提取有用数据,并将其打包递给神经网络
- 在训练时会有大量数据,但是其中有垃圾数据,Dataset的作用
- 示例代码:
from PIL import Image
from torch.utils.data import Dataset
import os #系统类
img_path = "dataset/train/ants/0013035.jpg"
image = Image.open(img_path)
# image.show()
dir_path = "dataset/train/ants"
img_path_list = os.listdir(dir_path)
print(img_path_list[0])
class MyData(Dataset):
def __init__(self,root_dir,label_dir): #类似于构造函数
#self相当于指定类中的一个全局变量,类似于java的this?
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path = os.listdir(self.path)
#这个构造函数只用来初始化变量,没有其他操作
# self.paht : 根目录 + 下一级目录, 是一个目录名
# os.listdir(path) 返回指定路径下所有文件和文件夹的名字,并存放于一个列表中。
# img_path : 是一个列表,里面保存了 path的所有图片的名字
def __getitem__(self, index):
img_name = self.img_path[index]
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
#这样写不会造成路径重合吗? label_dir只是一个文件夹名,所以不会造成路径重合
# root_dir : 根目录
# label_dir : 根目录下面一级目录
# img_name : 图片名称
img = Image.open(img_item_path)
label = self.label_dir
return img,label
def __len__(self):
# 查看返回值的长度
return len(self.img_path)
# 将类实例化
root_dir = "dataset/train"
label_dir = "ants"
bee_dir = "bees"
ants_dataset = MyData(root_dir,label_dir)
bees_dataset = MyData(root_dir,bee_dir)
img,label = ants_dataset[0]
img2,label = bees_dataset[0]
#img.show()
#img2.show()
# 将两个数据集合二为一
train_dataset = ants_dataset + bees_dataset
print(len(train_dataset))
print(len(ants_dataset))
print(len(bees_dataset))
img3,label = train_dataset[123]
# img3.show()
img4,label = train_dataset[130]
# img4.show()
# 这样可以用于扩充数据集
Dataloader
- 为后面的网络提供不同的数据形式
- loader,即为加载器,可以将数据加载到神经网络
- 参数设置:
- batch_size : 每次加载多少个数据
- shuffle :若为True则顺序多次加载数据时数据的顺序不一样,为False时数据一样,默认为False
- num_workers : 进程数,一般设置为0,使用主进程,在Windows下设置数值大于0时容易出错,通常为0
- drop_last : true时舍弃数据,false时不舍弃
- transforms : 要设置为torchvision.transforms.ToTensor(),不然for循环会报错
- 示例代码:
import torchvision
from torch.utils.data import DataLoader
test_dataset = torchvision.datasets.CIFAR10(root="./download_data",train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 要设置transforms的参数,不然for循环会出错
dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
print(type(dataloader))
for img,label in enumerate(dataloader):
# img,label = data
print(label)
for data in dataloader:
img,label = data
print(label)
'''
img, target = test_dataset[0]
print(type(img))
print(target)
# img.show()
'''
打开图片
示例代码
# 打开单个图片
from PIL import Image
img_paht = "path_str"
img = Image.open(img_path)
img.show()
# 所有图片路径作为一个list
import os
dir_path = "dir_path" #文件夹路径
img_path_list = os.listdir(dir_path)
img_path_list[0] #获取第一个图片的路径
- 都是基本代码,基本上一看就懂
os.path.join(root_dir,label_dir)
# 将两个字符串以路径的方式拼接,Windows中默认为 \ 这个函数
# 会自动写为 /