batch_size>1的情况下,同一个batch内图像size相同,不同batch的图像size不一定相同。
重点放在改写torch.utils.data.Dataset的collate_fn上
例子学习
https://github.com/ultralytics/yolov3 关键词:collate_fn
从中摘取了部分:
class LoadImagesAndLabels(Dataset): # for training/testingdef __get_item__(self, index):# 返回的包括 输入图片,labels,路径,图像高宽的tuple@staticmethoddef collate_fn(batch):img, label, path, hw = list(zip(*batch)) # transposedfor i, l in enumerate(label):l[:, 0] = i # add target image index for build_targets()# torch.stack(img, 0): 增加维度# torch.cat(label, 0): 元素拼接。 将labels按行拼接return torch.stack(img, 0), torch.cat(label, 0), path, hw
在这里,collate_fn的作用主要是处理labels(因为detection中的labels长度通常不固定)。torch.cat后labels的维度是(num_of_labels(all images), 6)
看到这里就明白了collate_fn是怎么统一维度的了,它将所有的label一个个排列了,而不是把用一个个list结构放到一个总的list中。
构造数据集时的代码如下:
testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,pin_memory=True,collate_fn=dataset.collate_fn)
写代码
我的Dataset在get_item之后是两张图片,sample_A和sample_B。
我期望不同batch中图片的尺寸不固定,但在一个batch内部图片尺寸需要统一。
最后实现的代码如下:
class DatasetFolder(data.Dataset):def __getitem__(self, index):#...return sample_A, sample_B@staticmethoddef collate_fn(batch):# different batch with different size# images in one batch with same sizesize = random.choice([(256, 256), (512, 512), (768, 768), (1024, 1024), (1280, 1280)])samples_A = []samples_B = []for sample_A, sample_B in batch:sample_A = F.interpolate(sample_A.unsqueeze(0), size=size).squeeze(0)sample_B = F.interpolate(sample_B.unsqueeze(0), size=size).squeeze(0)samples_A.append(sample_A)samples_B.append(sample_B)samples_A = torch.stack(samples_A, 0)samples_B = torch.stack(samples_B, 0)return samples_A, samples_B
