This post covers the network's training pipeline, reading and studying the code along the way. It mainly involves data loading and preprocessing plus model training; the file covered here is data_loaders.py.

Let's start by reading the code in data_loaders.py to get a picture of the whole data loading and preprocessing flow.

1. The data processing flow (data_loaders.py)

    import os

    import munch
    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms


    def mch(**kwargs):
        return munch.Munch(dict(**kwargs))


    def configure_metadata(metadata_root):
        metadata = mch()
        metadata.image_ids = os.path.join(metadata_root, 'image_ids.txt')
        metadata.image_ids_proxy = os.path.join(metadata_root,
                                                'image_ids_proxy.txt')
        metadata.class_labels = os.path.join(metadata_root, 'class_labels.txt')
        metadata.image_sizes = os.path.join(metadata_root, 'image_sizes.txt')
        metadata.localization = os.path.join(metadata_root, 'localization.txt')
        return metadata

These functions locate the metadata .txt files. Such a txt file typically stores data like the images' paths/names (ids) and labels (label), as in the screenshot below:

[screenshot: example metadata .txt file]

The code simply assigns each file path to a field of a Munch dictionary (a dict that also supports attribute-style access).
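As a quick sanity check, here is a minimal usage sketch; the metadata_root value is hypothetical:

    metadata = configure_metadata('metadata/CUB/train')  # hypothetical root
    print(metadata.image_ids)     # metadata/CUB/train/image_ids.txt
    print(metadata.class_labels)  # metadata/CUB/train/class_labels.txt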

1) def get_image_ids(metadata, proxy=False)

    def get_image_ids(metadata, proxy=False):
        """
        image_ids.txt has the structure

        <path>
        path/to/image1.jpg
        path/to/image2.jpg
        path/to/image3.jpg
        ...
        """
        image_ids = []
        suffix = '_proxy' if proxy else ''
        with open(metadata['image_ids' + suffix]) as f:
            for line in f.readlines():
                image_ids.append(line.strip('\n'))
        return image_ids

This parses image_ids.txt, whose contents look like this:

[screenshot: image_ids.txt contents]

Each line's image id is appended to a list; with proxy=True, the function reads image_ids_proxy.txt instead. Nothing complicated here.
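To see it in action without the real dataset, here is a self-contained sketch that writes a tiny image_ids.txt into a temporary directory and reads it back (contents made up):

    import os
    import tempfile

    tmp_dir = tempfile.mkdtemp()
    with open(os.path.join(tmp_dir, 'image_ids.txt'), 'w') as f:
        f.write('path/to/image1.jpg\npath/to/image2.jpg\n')

    metadata = configure_metadata(tmp_dir)
    print(get_image_ids(metadata))
    # ['path/to/image1.jpg', 'path/to/image2.jpg']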

2) def get_class_labels(metadata)

    def get_class_labels(metadata):
        """
        class_labels.txt has the structure

        <path>,<integer_class_label>
        path/to/image1.jpg,0
        path/to/image2.jpg,1
        path/to/image3.jpg,1
        ...
        """
        class_labels = {}
        with open(metadata.class_labels) as f:
            for line in f.readlines():
                image_id, class_label_string = line.strip('\n').split(',')
                class_labels[image_id] = int(class_label_string)
        return class_labels

The return value is a dictionary whose keys are image_ids and whose values are the (integer) class labels, as shown here:

[screenshot: class-label dictionary]
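One handy use of the returned dictionary is checking how balanced the classes are; a minimal sketch (paths and labels made up):

    from collections import Counter

    # The shape of what get_class_labels returns.
    class_labels = {'path/to/image1.jpg': 0,
                    'path/to/image2.jpg': 1,
                    'path/to/image3.jpg': 1}
    print(Counter(class_labels.values()))  # Counter({1: 2, 0: 1})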

3) def get_bounding_boxes(metadata)

    def get_bounding_boxes(metadata):
        """
        localization.txt (for bounding box) has the structure

        <path>,<x0>,<y0>,<x1>,<y1>
        path/to/image1.jpg,156,163,318,230
        path/to/image1.jpg,23,12,101,259
        path/to/image2.jpg,143,142,394,248
        path/to/image3.jpg,28,94,485,303
        ...
        One image may contain multiple boxes (multiple boxes for the same path).
        """
        boxes = {}
        with open(metadata.localization) as f:
            for line in f.readlines():
                # Note: these variable names do not match the column order in
                # the docstring. The columns are <x0>,<y0>,<x1>,<y1>, so the
                # tuple stored below is really (x0, y0, x1, y1).
                image_id, x0s, x1s, y0s, y1s = line.strip('\n').split(',')
                x0, x1, y0, y1 = int(x0s), int(x1s), int(y0s), int(y1s)
                if image_id in boxes:
                    boxes[image_id].append((x0, x1, y0, y1))
                else:
                    boxes[image_id] = [(x0, x1, y0, y1)]
        return boxes

This function is presumably used at evaluation time to measure how well the model localizes; it is not needed during training.

The returned boxes is again a dictionary, of the form {image_id: [(x0, y0, x1, y1), ...]} in file-column order (see the naming caveat in the comment above).
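The accumulation logic is easy to check in isolation. The sketch below replays the loop body on two made-up lines for the same image, using dict.setdefault as a compact equivalent of the if/else above:

    boxes = {}
    for line in ['path/to/image1.jpg,156,163,318,230',
                 'path/to/image1.jpg,23,12,101,259']:
        image_id, x0s, y0s, x1s, y1s = line.split(',')
        box = (int(x0s), int(y0s), int(x1s), int(y1s))
        boxes.setdefault(image_id, []).append(box)
    print(boxes)
    # {'path/to/image1.jpg': [(156, 163, 318, 230), (23, 12, 101, 259)]}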

4) def get_mask_paths(metadata)

    def get_mask_paths(metadata):
        """
        localization.txt (for masks) has the structure

        <path>,<link_to_mask_file>,<link_to_ignore_mask_file>
        path/to/image1.jpg,path/to/mask1a.png,path/to/ignore1.png
        path/to/image1.jpg,path/to/mask1b.png,
        path/to/image2.jpg,path/to/mask2a.png,path/to/ignore2.png
        path/to/image3.jpg,path/to/mask3a.png,path/to/ignore3.png
        ...
        One image may contain multiple masks (multiple mask paths for same image).
        One image contains only one ignore mask.
        """
        mask_paths = {}
        ignore_paths = {}
        with open(metadata.localization) as f:
            for line in f.readlines():
                image_id, mask_path, ignore_path = line.strip('\n').split(',')
                if image_id in mask_paths:
                    mask_paths[image_id].append(mask_path)
                    assert (len(ignore_path) == 0)
                else:
                    mask_paths[image_id] = [mask_path]
                    ignore_paths[image_id] = ignore_path
        return mask_paths, ignore_paths

This is not used during training either; I'll study it when it comes up in a later note. Marking it here!
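Still, for reference, here is roughly what the two returned dictionaries would contain for the example lines in the docstring (all paths made up):

    mask_paths = {
        'path/to/image1.jpg': ['path/to/mask1a.png', 'path/to/mask1b.png'],
        'path/to/image2.jpg': ['path/to/mask2a.png'],
        'path/to/image3.jpg': ['path/to/mask3a.png'],
    }
    ignore_paths = {
        'path/to/image1.jpg': 'path/to/ignore1.png',
        'path/to/image2.jpg': 'path/to/ignore2.png',
        'path/to/image3.jpg': 'path/to/ignore3.png',
    }
    # The assert enforces that only the first line for an image may carry an
    # ignore mask; later lines for the same image must leave that field empty.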

5) def get_image_sizes(metadata)

    def get_image_sizes(metadata):
        """
        image_sizes.txt has the structure

        <path>,<w>,<h>
        path/to/image1.jpg,500,300
        path/to/image2.jpg,1000,600
        path/to/image3.jpg,500,300
        ...
        """
        image_sizes = {}
        with open(metadata.image_sizes) as f:
            for line in f.readlines():
                image_id, ws, hs = line.strip('\n').split(',')
                w, h = int(ws), int(hs)
                image_sizes[image_id] = (w, h)
        return image_sizes

The return value is again a dictionary: {image_id: (w, h)}.

6) class WSOLImageLabelDataset(Dataset)

This is where the actual data loading happens.

    # Inherits from PyTorch's Dataset class.
    class WSOLImageLabelDataset(Dataset):
        def __init__(self, data_root, metadata_root, transform, proxy,
                     num_sample_per_class=0):
            self.data_root = data_root
            self.metadata = configure_metadata(metadata_root)
            self.transform = transform
            self.image_ids = get_image_ids(self.metadata, proxy=proxy)
            self.image_labels = get_class_labels(self.metadata)
            self.num_sample_per_class = num_sample_per_class
            self._adjust_samples_per_class()

        def _adjust_samples_per_class(self):
            if self.num_sample_per_class == 0:
                return
            image_ids = np.array(self.image_ids)
            # Pull out the label of every image_id into one array.
            image_labels = np.array([self.image_labels[_image_id]
                                     for _image_id in self.image_ids])
            # np.unique removes duplicates and sorts, so unique_labels holds
            # one entry per class present in the dataset.
            unique_labels = np.unique(image_labels)
            new_image_ids = []
            new_image_labels = {}
            for _label in unique_labels:
                # np.where(condition) returns the indices where the condition
                # holds, i.e. the positions of this class's samples.
                indices = np.where(image_labels == _label)[0]
                # Randomly subsample num_sample_per_class images per class.
                sampled_indices = np.random.choice(
                    indices, self.num_sample_per_class, replace=False)
                # Collect the sampled ids and labels for this class.
                sampled_image_ids = image_ids[sampled_indices].tolist()
                sampled_image_labels = image_labels[sampled_indices].tolist()
                new_image_ids += sampled_image_ids
                # Merge the sampled (id, label) pairs into one dictionary.
                new_image_labels.update(
                    **dict(zip(sampled_image_ids, sampled_image_labels)))
            self.image_ids = new_image_ids
            self.image_labels = new_image_labels

        def __getitem__(self, idx):
            image_id = self.image_ids[idx]
            image_label = self.image_labels[image_id]
            image = Image.open(os.path.join(self.data_root, image_id))
            image = image.convert('RGB')
            image = self.transform(image)
            return image, image_label, image_id

        def __len__(self):
            return len(self.image_ids)
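The only tricky part of the class is the per-class subsampling in _adjust_samples_per_class. Here is the same np.unique / np.where / np.random.choice pattern run standalone on toy data, keeping two samples per class:

    import numpy as np

    image_ids = np.array(['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg', 'e.jpg', 'f.jpg'])
    image_labels = np.array([0, 0, 0, 1, 1, 1])
    num_sample_per_class = 2

    new_image_ids = []
    for _label in np.unique(image_labels):              # one pass per class
        indices = np.where(image_labels == _label)[0]   # positions of this class
        sampled = np.random.choice(indices, num_sample_per_class, replace=False)
        new_image_ids += image_ids[sampled].tolist()
    print(new_image_ids)  # e.g. ['a.jpg', 'c.jpg', 'd.jpg', 'f.jpg']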

7) def get_data_loader

    # Module-level constants from the top of data_loaders.py
    # (standard ImageNet normalization statistics and the split names).
    _IMAGE_MEAN_VALUE = [0.485, 0.456, 0.406]
    _IMAGE_STD_VALUE = [0.229, 0.224, 0.225]
    _SPLITS = ('train', 'val', 'test')


    def get_data_loader(data_roots, metadata_root, batch_size, workers,
                        resize_size, crop_size, proxy_training_set,
                        num_val_sample_per_class=0):
        dataset_transforms = dict(
            train=transforms.Compose([
                transforms.Resize((resize_size, resize_size)),
                transforms.RandomCrop(crop_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(_IMAGE_MEAN_VALUE, _IMAGE_STD_VALUE)
            ]),
            val=transforms.Compose([
                transforms.Resize((crop_size, crop_size)),
                transforms.ToTensor(),
                transforms.Normalize(_IMAGE_MEAN_VALUE, _IMAGE_STD_VALUE)
            ]),
            test=transforms.Compose([
                transforms.Resize((crop_size, crop_size)),
                transforms.ToTensor(),
                transforms.Normalize(_IMAGE_MEAN_VALUE, _IMAGE_STD_VALUE)
            ]))

        loaders = {
            split: DataLoader(
                WSOLImageLabelDataset(
                    data_root=data_roots[split],
                    metadata_root=os.path.join(metadata_root, split),
                    transform=dataset_transforms[split],
                    # Only the training split may read the proxy image list.
                    proxy=proxy_training_set and split == 'train',
                    # Per-class subsampling only applies to the validation split.
                    num_sample_per_class=(num_val_sample_per_class
                                          if split == 'val' else 0)
                ),
                batch_size=batch_size,
                shuffle=split == 'train',
                num_workers=workers)
            for split in _SPLITS
        }
        return loaders
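Finally, a minimal sketch of how a training loop might consume the returned loaders; the paths, batch size, and image sizes here are made-up values, and metadata_root is assumed to contain train/, val/ and test/ subdirectories:

    loaders = get_data_loader(
        data_roots={'train': 'dataset/CUB', 'val': 'dataset/CUB',
                    'test': 'dataset/CUB'},   # hypothetical paths
        metadata_root='metadata/CUB',
        batch_size=32, workers=4,
        resize_size=256, crop_size=224,
        proxy_training_set=False)

    for images, labels, image_ids in loaders['train']:
        print(images.shape)  # torch.Size([32, 3, 224, 224])
        break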