Reference article:
https://blog.csdn.net/weixin_44791964/article/details/102853782
Reference video walkthrough:
https://www.bilibili.com/video/BV1E7411J72R?from=search&seid=3954447457476218575&spm_id_from=333.337.0.0
What is Focal loss
Focal loss is a loss formulation proposed by Kaiming He and his co-authors.
It has two important characteristics:
1. It controls the weighting of positive versus negative samples.
2. It controls the weighting of easy-to-classify versus hard-to-classify samples.
Positive and negative samples are defined as follows:
A single image can generate tens of thousands of candidate boxes, but only a small fraction of them actually contain an object. Boxes that contain an object are positive samples; boxes that do not are negative samples.
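Focal loss addresses the first point by weighting the standard cross entropy with a class-balancing factor \alpha_t (the code below defaults to \alpha = 0.25, the value recommended in the paper):

    CE(p_t) = -\alpha_t \log(p_t), \qquad
    \alpha_t =
    \begin{cases}
      \alpha,     & \text{positive sample} \\
      1 - \alpha, & \text{negative sample}
    \end{cases}

A single hyperparameter therefore controls the relative contribution of the two groups.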
Easy and hard samples are defined as follows:
Suppose we have a binary classification problem. Sample 1 has pt = 0.9 for class 1 and sample 2 has pt = 0.6 for class 1. The former is clearly more likely to belong to class 1, so it is an easy-to-classify sample; the latter may belong to class 1 but with far less confidence, so it is a hard-to-classify sample.
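Combining the two ideas, focal loss multiplies the \alpha-balanced cross entropy by a modulating factor (1 - p_t)^\gamma, so that well-classified samples contribute almost nothing to the loss. In the notation of the original paper:

    FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)

With the default \gamma = 2 used in the code below, the easy sample above (p_t = 0.9) is scaled by (1 - 0.9)^2 = 0.01, while the hard sample (p_t = 0.6) is scaled by (1 - 0.6)^2 = 0.16, so training effort is concentrated on the hard sample.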
The Keras implementation used in the reference article is shown below. The only two `backend.*` calls it relies on, where and gather_nd, map directly onto tf.where and tf.gather_nd, so importing TensorFlow under that name keeps the snippet self-contained.

import keras
import tensorflow as tf
# In the reference repository, `backend` is its own TensorFlow helper module
# (not keras.backend); TensorFlow provides the two functions used here.
import tensorflow as backend


def focal(alpha=0.25, gamma=2.0):
    def _focal(y_true, y_pred):
        # y_true: [batch_size, num_anchor, num_classes + 1]
        # y_pred: [batch_size, num_anchor, num_classes]
        labels         = y_true[:, :, :-1]
        anchor_state   = y_true[:, :, -1]   # -1: ignore, 0: background, 1: object
        classification = y_pred

        # Pick out the anchors that contain an object (positive samples).
        indices_for_object        = backend.where(keras.backend.equal(anchor_state, 1))
        labels_for_object         = backend.gather_nd(labels, indices_for_object)
        classification_for_object = backend.gather_nd(classification, indices_for_object)

        # Compute the weight of each positive anchor:
        # alpha_t = alpha where the label is 1 (otherwise 1 - alpha),
        # modulating factor = (1 - p_t) ** gamma.
        alpha_factor_for_object = keras.backend.ones_like(labels_for_object) * alpha
        alpha_factor_for_object = backend.where(keras.backend.equal(labels_for_object, 1),
                                                alpha_factor_for_object, 1 - alpha_factor_for_object)
        focal_weight_for_object = backend.where(keras.backend.equal(labels_for_object, 1),
                                                1 - classification_for_object, classification_for_object)
        focal_weight_for_object = alpha_factor_for_object * focal_weight_for_object ** gamma

        # Multiply the weight by the binary cross entropy.
        cls_loss_for_object = focal_weight_for_object * keras.backend.binary_crossentropy(
            labels_for_object, classification_for_object)

        # Pick out the anchors that are actually background (negative samples).
        indices_for_back        = backend.where(keras.backend.equal(anchor_state, 0))
        labels_for_back         = backend.gather_nd(labels, indices_for_back)
        classification_for_back = backend.gather_nd(classification, indices_for_back)

        # Compute the weight of each background anchor:
        # alpha_t = 1 - alpha, modulating factor = p ** gamma.
        alpha_factor_for_back = keras.backend.ones_like(labels_for_back) * (1 - alpha)
        focal_weight_for_back = classification_for_back
        focal_weight_for_back = alpha_factor_for_back * focal_weight_for_back ** gamma

        # Multiply the weight by the binary cross entropy.
        cls_loss_for_back = focal_weight_for_back * keras.backend.binary_crossentropy(
            labels_for_back, classification_for_back)

        # Normalizer: effectively the number of positive anchors (at least 1).
        normalizer = tf.where(keras.backend.equal(anchor_state, 1))
        normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx())
        normalizer = keras.backend.maximum(keras.backend.cast_to_floatx(1.0), normalizer)

        # Divide the summed losses by the number of positive anchors.
        cls_loss_for_object = keras.backend.sum(cls_loss_for_object)
        cls_loss_for_back   = keras.backend.sum(cls_loss_for_back)

        # Total loss.
        loss = (cls_loss_for_object + cls_loss_for_back) / normalizer
        return loss
    return _focal
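As a minimal sanity check, the closure returned by focal() can be evaluated on a tiny hand-made batch. This sketch assumes the definition above is in scope and a Keras 2.x installation with the TensorFlow backend (the API the snippet targets); the batch of one image, three anchors and two classes is made up purely to illustrate the expected y_true layout, whose last column is the anchor state.

import numpy as np
import keras

# One image, three anchors, two classes. The last column of y_true is the
# anchor state (-1 ignore, 0 background, 1 object); the remaining columns
# are the one-hot class labels.
y_true = np.array([[[1., 0.,  1.],     # positive anchor, class 0
                    [0., 0.,  0.],     # background anchor
                    [0., 0., -1.]]],   # ignored anchor
                  dtype=np.float32)
y_pred = np.array([[[0.9, 0.1],        # confident and correct -> tiny loss
                    [0.3, 0.2],        # background scored too high -> larger loss
                    [0.5, 0.5]]],      # ignored, does not enter the loss
                  dtype=np.float32)

loss_fn = focal(alpha=0.25, gamma=2.0)
loss = loss_fn(keras.backend.constant(y_true), keras.backend.constant(y_pred))
print(keras.backend.eval(loss))        # a small positive scalar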
