1. 数据集预处理
将标签文件转为YOLO格式,即[x中心点坐标,y中心点坐标,w,h](都进行了归一化)
图片通过模型输出后,得到归一化并数据增强的图片信息
box预测为[x相对于网格左上角的偏移量,y相对于网格左上角的偏移量,w开根号,h开根号]
2. 获取正样本
- 将真实框标签信息乘上最后特征图的尺寸,将归一化的标签信息转为特征图上的信息
- 将真实框信息转换为[0,0,w,h],维度为[n,4],n代表图片中有几个目标(即有几个真实框)
- 将先验框信息转换为[0,0,w,h],维度为[9,4],9代表有9个先验框(三个特征层,每层三个先验框)
- 计算所有先验框和真实框的IOU,并找出IOU最大的先验框(返回先验框的索引),并找出该先验框属于哪个特征层;目标(真实框)所在的网格,最后正样本所在的位置即[[b, k, j, i],其中b代表该批次图片中的哪张图片,k代表第几个先验框,j代表目标所在网格的y坐标,i代表目标所在网格的x坐标。
将正样本的五个维度信息赋给代表正样本的变量(置信度设为1),负样本变量的正样本所在位置置为0,同时记录下当前目标相比当前层的比例系数当成loss的权重(即当前目标大小/特征层大小)
3. 获取负样本
将预测框信息转为x中心点坐标,y中心点坐标,,
- 根据yolo格式进行解码,调整将维度转为(n,4),n代表预测框数量,即b313*13(13为当前特征层的宽高)
计算所有调整后先验框框和真实框的IOU,将IOU大于阈值的先验框所在位置设为正样本,其余全为负样本
4. 计算loss
obj_mask代表正样本掩膜,有目标的位置为True,其他为false,
计算预测框和正样本的IOU,并计算IOU损失
def box_iou(self, b1, b2):
"""
输入格式:
----------
b1: tensor, shape=(batch, anchor_num, feat_w, feat_h, 4), xywh
b2: tensor, shape=(batch, anchor_num, feat_w, feat_h, 4), xywh
返回格式
-------
out: tensor, shape=(batch, anchor_num, feat_w, feat_h)
"""
#计算预测框左上角和右下角坐标
b1_xy = b1[..., :2]
b1_wh = b1[..., 2:4]
b1_wh_half = b1_wh / 2.
b1_mins = b1_xy - b1_wh_half
b1_maxes = b1_xy + b1_wh_half
#计算真实框左上角和右下角坐标
b2_xy = b2[..., :2]
b2_wh = b2[..., 2:4]
b2_wh_half = b2_wh / 2.
b2_mins = b2_xy - b2_wh_half
b2_maxes = b2_xy + b2_wh_half
#家孙真实框和预测框的IOU
intersect_mins = torch.max(b1_mins, b2_mins)
intersect_maxes = torch.min(b1_maxes, b2_maxes)
intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
b1_area = b1_wh[..., 0] * b1_wh[..., 1]
b2_area = b2_wh[..., 0] * b2_wh[..., 1]
union_area = b1_area + b2_area - intersect_area
iou = intersect_area / torch.clamp(union_area,min = 1e-6)
#计算两个框中心点距离
center_wh = b1_xy - b2_xy
#找到包裹两个框的最小框的左上角和右下角
enclose_mins = torch.min(b1_mins, b2_mins)
enclose_maxes = torch.max(b1_maxes, b2_maxes)
enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
if self.iou_type == 'ciou':
#计算两个框框中心点欧氏距离
center_distance = torch.sum(torch.pow(center_wh, 2), axis=-1)
#计算包裹框对角线距离
enclose_diagonal = torch.sum(torch.pow(enclose_wh, 2), axis=-1)
ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal, min = 1e-6)
#v 为长宽比惩罚项
v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min = 1e-6))), 2)
alpha = v / torch.clamp((1.0 - iou + v), min = 1e-6)
out = ciou - alpha * v
elif self.iou_type == 'siou':
#计算中心点距离
sigma = torch.pow(torch.sum(torch.pow(center_wh, 2), axis=-1), 0.5)
#求h和w方向上的sin比值
sin_alpha_1 = torch.clamp(torch.abs(center_wh[..., 0]) / torch.clamp(sigma, min = 1e-6), min = 0, max = 1)
sin_alpha_2 = torch.clamp(torch.abs(center_wh[..., 1]) / torch.clamp(sigma, min = 1e-6), min = 0, max = 1)
#求门限,二分之根号二,0.707
#如果门限大于0.707,代表某个方向的角度大于45度
#此时求取另一个方向的角度
threshold = pow(2, 0.5) / 2
sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
#alpha越接近45度,angel_cost越接近1,gamma越接近1
#alpha越接近0度,angel_cost越接近0,gamma越接近2
angle_cost = torch.cos(torch.asin(sin_alpha) * 2 - math.pi / 2)
gamma = 2 - angle_cost
#----------------------------------------------------#
# Distance cost
# 求中心与包裹框宽高的比值
rho_x = (center_wh[..., 0] / torch.clamp(enclose_wh[..., 0], min = 1e-6)) ** 2
rho_y = (center_wh[..., 1] / torch.clamp(enclose_wh[..., 1], min = 1e-6)) ** 2
distance_cost = 2 - torch.exp(-gamma * rho_x) - torch.exp(-gamma * rho_y)
#----------------------------------------------------#
# Shape cost
# 真实框和预测框的宽高差与最大值的比值
# 差异越小,costshape_cost越小
omiga_w = torch.abs(b1_wh[..., 0] - b2_wh[..., 0]) / torch.clamp(torch.max(b1_wh[..., 0], b2_wh[..., 0]), min = 1e-6)
omiga_h = torch.abs(b1_wh[..., 1] - b2_wh[..., 1]) / torch.clamp(torch.max(b1_wh[..., 1], b2_wh[..., 1]), min = 1e-6)
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
out = iou - 0.5 * (distance_cost + shape_cost)
return out
类别损失采用二值交叉熵损失函数(类别分类相对于二分类,分类当前目标是否为当前类别)
torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
- 置信度损失采用focal_loss
#正样本是
pos_neg_ratio = torch.where(obj_mask, torch.ones_like(conf) * self.alpha, torch.ones_like(conf) * (1 - self.alpha))
hard_easy_ratio = torch.where(obj_mask, torch.ones_like(conf) - conf, conf) ** self.gamma
loss_conf = torch.mean((self.BCELoss(conf, obj_mask.type_as(conf)) * pos_neg_ratio * hard_easy_ratio)[noobj_mask.bool() | obj_mask]) * self.focal_loss_ratio