batch_size>1的情况下,同一个batch内图像size相同,不同batch的图像size不一定相同。

重点放在改写torch.utils.data.Dataset的collate_fn上

例子学习

https://github.com/ultralytics/yolov3 关键词:collate_fn

从中摘取了部分:

  1. class LoadImagesAndLabels(Dataset): # for training/testing
  2. def __get_item__(self, index):
  3. # 返回的包括 输入图片,labels,路径,图像高宽的tuple
  4. @staticmethod
  5. def collate_fn(batch):
  6. img, label, path, hw = list(zip(*batch)) # transposed
  7. for i, l in enumerate(label):
  8. l[:, 0] = i # add target image index for build_targets()
  9. # torch.stack(img, 0): 增加维度
  10. # torch.cat(label, 0): 元素拼接。 将labels按行拼接
  11. return torch.stack(img, 0), torch.cat(label, 0), path, hw

在这里,collate_fn的作用主要是处理labels(因为detection中的labels长度通常不固定)。torch.cat后labels的维度是(num_of_labels(all images), 6)
看到这里就明白了collate_fn是怎么统一维度的了,它将所有的label一个个排列了,而不是把用一个个list结构放到一个总的list中。

构造数据集时的代码如下:

  1. testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
  2. pin_memory=True,
  3. collate_fn=dataset.collate_fn)

写代码

我的Dataset在get_item之后是两张图片,sample_A和sample_B。
我期望不同batch中图片的尺寸不固定,但在一个batch内部图片尺寸需要统一。

最后实现的代码如下:

  1. class DatasetFolder(data.Dataset):
  2. def __getitem__(self, index):
  3. #...
  4. return sample_A, sample_B
  5. @staticmethod
  6. def collate_fn(batch):
  7. # different batch with different size
  8. # images in one batch with same size
  9. size = random.choice([(256, 256), (512, 512), (768, 768), (1024, 1024), (1280, 1280)])
  10. samples_A = []
  11. samples_B = []
  12. for sample_A, sample_B in batch:
  13. sample_A = F.interpolate(sample_A.unsqueeze(0), size=size).squeeze(0)
  14. sample_B = F.interpolate(sample_B.unsqueeze(0), size=size).squeeze(0)
  15. samples_A.append(sample_A)
  16. samples_B.append(sample_B)
  17. samples_A = torch.stack(samples_A, 0)
  18. samples_B = torch.stack(samples_B, 0)
  19. return samples_A, samples_B