在训练DAN网络过程中,遇到了以下代码:

  1. def load_training(root_path, dir, batch_size, kwargs):
  2. transform = transforms.Compose(
  3. [transforms.Resize([256, 256]),
  4. transforms.RandomCrop(224),
  5. transforms.RandomHorizontalFlip(),
  6. transforms.ToTensor()]) %对图像进行各种操作
  7. data = datasets.ImageFolder(root=root_path + dir, transform=transform)
  8. print("data的类型{0}".format(type(data))) %dataImageFolder类型的
  9. train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True,
  10. drop_last=True, **kwargs)
  11. print("train_loader的类型{0}".format(type(train_loader)))
  12. return train_loader

因为我的不是图像数据,所以需要将数据调整一下,需要明白这些代码的意思以及它的数据变化情况,发现最后的关键是train_loader,所以将其类型打印

train_loader的类型

所以,我的数据只需要调整成这样即可。

DataLoader做了什么?

torch.utils.data.dataloader.DataLoader是一个数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。在训练模型时使用到此函数,用来吧训练数据分成多个小组,此函数每次抛出一组数据,直至把所有的数据都抛出,做一个数据的初始化。

语法

  1. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
  2. batch_sampler=None, num_workers=0, collate_fn=
  3. <function default_collate>,pin_memory=False,
  4. drop_last=False, timeout=0, worker_init_fn=None)

参数

  • dataset(Dataset):来源于加载数据的数据集
  • batch_size(int, optional):每批次加载的样本数(默认:1)
  • shuffle(bool, optional):设置为true的话会打乱每个批次的数据(默认:False
  • sampler(Sampler, optional):定义了从数据集中抽取样本的方法,如果已经定义好了,shuffle设定为False即可
  • batch_sampler(Sampler, optional):类似于上面的sampler,但是每次返回的是一批指标。与batch_sizeshufflesamplerdrop_last是互斥的关系。
  • num_workers(int, optional):使用多少子过程来加载数据。0就意味着数据会在主过程中加载(默认:0)
  • collate_fn(callable, optional):将一个列表中的样本整合为一个小批次样本(mini_batch)
  • pin_memory(bool, optional):如果设置为True的话,这个数据在返回之前会被拷贝到CUDA的锁定内存中去
  • drop_last(bool, optional):若果设置为True的话,在数据集不能够整除批次大小的情况下会扔掉最后一批不完整的数据集。如果设置为False的话,在数据集不能够整除批次大小的情况下仍旧会保留最后一批数据集,只不过这部分的数据集批次会很小(默认:False
  • timeout(numeric, optional):如果为正的情况下,这个timeout表示从工作者处理批次数据的时间,所以其一定非负(默认:0)
  • worker_init_fn(callable, optional):如果设置不是None的话,这将被带有id的工作者的子进程中作为输入值被调用,调用过程发生在数据加载后。(默认:None)

    例子

    下面的这个例子生成的是迭代数据: ```python “”” 批训练,把数据变成一小批一小批数据进行训练。 DataLoader就是用来包装所使用的数据,每次抛出一批数据 “”” import torch import torch.utils.data as Data

BATCH_SIZE = 5 x = torch.linspace(1, 10, 10) y = torch.linspace(10, 1, 10) torch_dataset = Data.TensorDataset(x, y) # 把数据放到数据库中

loader = Data.DataLoader( dataset=torch_dataset, # 从数据库中每次抽取batch_size大小的样本 batch_size=BATCH_SIZE, shuffle=True, num_workers=2, )

def show_batch(): for epoch in range(5): print(“epoch{0}”.format(epoch)) for step, (batch_x, batch_y) in enumerate(loader):

  1. # 训练
  2. print("steop:{}, batch_x:{},shape_x:{}, batch_y:{},shape_y:{}".format(step, batch_x, batch_x.shape,
  3. batch_y, batch_y.shape))

if name == ‘main‘: show_batch() ``` 在这个例子中,有一个enumerate语法,点击这里查看其语法。

epoch0 steop:0, batch_x:tensor([4., 8., 9., 7., 2.]),shape_x:torch.Size([5]), batch_y:tensor([7., 3., 2., 4., 9.]),shape_y:torch.Size([5])

steop:1, batch_x:tensor([ 5., 1., 6., 3., 10.]),shape_x:torch.Size([5]), batch_y:tensor([ 6., 10., 5., 8., 1.]),shape_y:torch.Size([5])

epoch1

steop:0, batch_x:tensor([7., 4., 5., 8., 2.]),shape_x:torch.Size([5]), batch_y:tensor([4., 7., 6., 3., 9.]),shape_y:torch.Size([5])

steop:1, batch_x:tensor([ 9., 3., 1., 6., 10.]),shape_x:torch.Size([5]), batch_y:tensor([ 2., 8., 10., 5., 1.]),shape_y:torch.Size([5])

epoch2

steop:0, batch_x:tensor([10., 6., 2., 7., 5.]),shape_x:torch.Size([5]), batch_y:tensor([1., 5., 9., 4., 6.]),shape_y:torch.Size([5])

steop:1, batch_x:tensor([4., 9., 3., 1., 8.]),shape_x:torch.Size([5]), batch_y:tensor([ 7., 2., 8., 10., 3.]),shape_y:torch.Size([5])

epoch3

steop:0, batch_x:tensor([ 7., 2., 1., 10., 3.]),shape_x:torch.Size([5]), batch_y:tensor([ 4., 9., 10., 1., 8.]),shape_y:torch.Size([5])

steop:1, batch_x:tensor([5., 6., 9., 8., 4.]),shape_x:torch.Size([5]), batch_y:tensor([6., 5., 2., 3., 7.]),shape_y:torch.Size([5])

epoch4

steop:0, batch_x:tensor([ 1., 7., 9., 10., 6.]),shape_x:torch.Size([5]), batch_y:tensor([10., 4., 2., 1., 5.]),shape_y:torch.Size([5])

steop:1, batch_x:tensor([2., 4., 3., 8., 5.]),shape_x:torch.Size([5]), batch_y:tensor([9., 7., 8., 3., 6.]),shape_y:torch.Size([5])

参考链接:

链接1:torch.utils.data.DataLoader使用方法