该节博客主要讲解网络的训练流程,以及其中的代码阅读与学习。主要涉及数据的读取与处理和模型的训练,涉及文件有data_loaders.py。
首先我们先来阅读data_loaders.py文件中的代码,了解整个数据读取和处理的流程。
1. 数据的处理流程(data_loaders.py)
def mch(**kwargs):
    """Shorthand constructor: wrap the keyword arguments in a Munch.

    A Munch behaves like a dict whose keys are also readable/writable
    as attributes (metadata.image_ids etc.).
    """
    return munch.Munch(kwargs)
def configure_metadata(metadata_root):
    """Build the metadata bundle: one .txt path per annotation kind.

    Each attribute of the returned Munch is the absolute path of the
    corresponding annotation file under `metadata_root`.
    """
    metadata = mch()
    for name in ('image_ids', 'image_ids_proxy', 'class_labels',
                 'image_sizes', 'localization'):
        setattr(metadata, name, os.path.join(metadata_root, name + '.txt'))
    return metadata
以上是读取数据.txt文件的函数,一般txt文本存储着图片的路径,名称(id),标签(label)等的数据形式,截图如下:
以上的代码就是把各个txt文件的路径赋值给字典变量的对应字段。
1)def get_image_ids(metadata, proxy=False)
def get_image_ids(metadata, proxy=False):
    """Return the list of image ids (relative paths), one per line.

    image_ids.txt has the structure
    <path>
    path/to/image1.jpg
    path/to/image2.jpg
    path/to/image3.jpg
    ...

    When `proxy` is True, read image_ids_proxy.txt instead of
    image_ids.txt (a smaller proxy training set).
    """
    key = 'image_ids_proxy' if proxy else 'image_ids'
    with open(metadata[key]) as f:
        return [line.strip('\n') for line in f]
处理image_ids.txt文件,内容如下:
把每一行的图片的id存储到一个list里面。这里都比较简单。
2)def get_class_labels(metadata)
def get_class_labels(metadata):
    """Return {image_id: integer class label} parsed from class_labels.txt.

    class_labels.txt has the structure
    <path>,<integer_class_label>
    path/to/image1.jpg,0
    path/to/image2.jpg,1
    path/to/image3.jpg,1
    ...

    (The original docstring said "image_ids.txt" -- a copy-paste error;
    the function reads `metadata.class_labels`.)
    """
    with open(metadata.class_labels) as f:
        pairs = (line.strip('\n').split(',') for line in f)
        return {image_id: int(label) for image_id, label in pairs}
返回数据类型为一个字典,字典的key为image_id,value为class_labels。如图所示:
3)def get_bounding_boxes(metadata)
def get_bounding_boxes(metadata):
    """Parse localization.txt into {image_id: [box_tuple, ...]}.

    localization.txt (for bounding box) has the structure
    <path>,<x0>,<y0>,<x1>,<y1>
    path/to/image1.jpg,156,163,318,230
    path/to/image1.jpg,23,12,101,259
    path/to/image2.jpg,143,142,394,248
    path/to/image3.jpg,28,94,485,303
    ...
    One image may contain multiple boxes (multiple boxes for the same path).

    NOTE(review): the unpacking below names the fields in a different order
    than documented above -- the file's <y0> lands in `x1s` and its <x1> in
    `y0s`.  Consequently the stored tuple (x0, x1, y0, y1) holds the file's
    fields in their original file order (x0, y0, x1, y1) under misleading
    names.  Callers presumably depend on this tuple order; confirm against
    the evaluation code before renaming or reordering.
    """
    boxes = {}
    with open(metadata.localization) as f:
        for line in f.readlines():
            # Unpacking order intentionally left untouched; see NOTE above.
            image_id, x0s, x1s, y0s, y1s = line.strip('\n').split(',')
            x0, x1, y0, y1 = int(x0s), int(x1s), int(y0s), int(y1s)
            if image_id in boxes:
                # Same image id seen before: append another box.
                boxes[image_id].append((x0, x1, y0, y1))
            else:
                boxes[image_id] = [(x0, x1, y0, y1)]
    return boxes
这个函数应该是在evaluate的时候用于检测模型的效果,在train的过程中用不到。
返回的boxes也是一个字典,其形式为: {image_id: [(x0, x1, y0, y1)]}
4)def get_mask_paths(metadata)
def get_mask_paths(metadata):
    """Parse localization.txt (mask flavour) into two per-image dicts.

    localization.txt (for masks) has the structure
    <path>,<link_to_mask_file>,<link_to_ignore_mask_file>
    path/to/image1.jpg,path/to/mask1a.png,path/to/ignore1.png
    path/to/image1.jpg,path/to/mask1b.png,
    path/to/image2.jpg,path/to/mask2a.png,path/to/ignore2.png
    path/to/image3.jpg,path/to/mask3a.png,path/to/ignore3.png
    ...
    One image may contain multiple masks (multiple mask paths for same image).
    One image contains only one ignore mask.

    Returns:
        (mask_paths, ignore_paths): {image_id: [mask_path, ...]} and
        {image_id: ignore_path} where ignore_path comes from the first
        row of each image (empty string if absent).
    """
    mask_paths = {}
    ignore_paths = {}
    with open(metadata.localization) as f:
        for record in f:
            image_id, mask_path, ignore_path = record.strip('\n').split(',')
            if image_id not in mask_paths:
                # First row for this image also carries its ignore mask.
                mask_paths[image_id] = [mask_path]
                ignore_paths[image_id] = ignore_path
            else:
                mask_paths[image_id].append(mask_path)
                # Subsequent rows for the same image must not repeat an
                # ignore mask (NOTE: assert disappears under python -O).
                assert len(ignore_path) == 0
    return mask_paths, ignore_paths
在训练的过程中用不到,在后面的笔记中看到了再学习,Mark一下!!!!
5)def get_image_sizes
def get_image_sizes(metadata):
    """Parse image_sizes.txt into {image_id: (width, height)}.

    image_sizes.txt has the structure
    <path>,<w>,<h>
    path/to/image1.jpg,500,300
    path/to/image2.jpg,1000,600
    path/to/image3.jpg,500,300
    ...
    """
    image_sizes = {}
    with open(metadata.image_sizes) as f:
        for record in f:
            image_id, raw_w, raw_h = record.strip('\n').split(',')
            image_sizes[image_id] = (int(raw_w), int(raw_h))
    return image_sizes
# 返回同样是一个字典:
# type: {image_id:(w, h)}
6)Class WSOLImageLabelDataset(Dataset)
这里才是真正的数据的加载过程。
# 继承自Pytorch的Dataset类
class WSOLImageLabelDataset(Dataset):
    """Image-classification dataset for WSOL: yields (image, label, image_id).

    Image ids and integer labels are read from the metadata .txt files
    under `metadata_root`; images are loaded from `data_root` and passed
    through `transform`.
    """

    def __init__(self, data_root, metadata_root, transform, proxy,
                 num_sample_per_class=0):
        self.data_root = data_root
        self.metadata = configure_metadata(metadata_root)
        self.transform = transform
        self.image_ids = get_image_ids(self.metadata, proxy=proxy)
        self.image_labels = get_class_labels(self.metadata)
        self.num_sample_per_class = num_sample_per_class
        self._adjust_samples_per_class()

    def _adjust_samples_per_class(self):
        """Optionally subsample `num_sample_per_class` images per class.

        With the default of 0 the full id list is kept unchanged.
        Sampling uses the global numpy RNG (np.random.choice) without
        replacement, independently for each class label.
        """
        if self.num_sample_per_class == 0:
            return
        ids_arr = np.array(self.image_ids)
        # Labels gathered in the same order as self.image_ids.
        labels_arr = np.array([self.image_labels[image_id]
                               for image_id in self.image_ids])
        kept_ids = []
        kept_labels = {}
        # np.unique returns the sorted distinct labels.
        for label in np.unique(labels_arr):
            class_positions = np.where(labels_arr == label)[0]
            chosen = np.random.choice(class_positions,
                                      self.num_sample_per_class,
                                      replace=False)
            chosen_ids = ids_arr[chosen].tolist()
            chosen_labels = labels_arr[chosen].tolist()
            kept_ids += chosen_ids
            kept_labels.update(zip(chosen_ids, chosen_labels))
        self.image_ids = kept_ids
        self.image_labels = kept_labels

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        label = self.image_labels[image_id]
        # Force RGB so grayscale/RGBA files come out with 3 channels.
        image = Image.open(os.path.join(self.data_root, image_id))
        return self.transform(image.convert('RGB')), label, image_id

    def __len__(self):
        return len(self.image_ids)
7)def get_data_loader
def get_data_loader(data_roots, metadata_root, batch_size, workers,
                    resize_size, crop_size, proxy_training_set,
                    num_val_sample_per_class=0):
    """Build one DataLoader per split in _SPLITS.

    Train images are resized, randomly cropped and flipped; val/test
    images are deterministically resized straight to `crop_size`.
    Only the train loader is shuffled; only the val split is optionally
    subsampled via `num_val_sample_per_class`.
    """

    def _eval_pipeline():
        # Deterministic evaluation transform (no augmentation).
        return transforms.Compose([
            transforms.Resize((crop_size, crop_size)),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGE_MEAN_VALUE, _IMAGE_STD_VALUE),
        ])

    dataset_transforms = {
        'train': transforms.Compose([
            transforms.Resize((resize_size, resize_size)),
            transforms.RandomCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGE_MEAN_VALUE, _IMAGE_STD_VALUE),
        ]),
        'val': _eval_pipeline(),
        'test': _eval_pipeline(),
    }

    loaders = {}
    for split in _SPLITS:
        dataset = WSOLImageLabelDataset(
            data_root=data_roots[split],
            metadata_root=os.path.join(metadata_root, split),
            transform=dataset_transforms[split],
            # The proxy id list applies to training only.
            proxy=proxy_training_set and split == 'train',
            num_sample_per_class=(num_val_sample_per_class
                                  if split == 'val' else 0))
        loaders[split] = DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=split == 'train',
                                    num_workers=workers)
    return loaders