本文大量引用了 - CTC Algorithm Explained Part 1:Training the Network(CTC 算法详解之训练篇),只是用自己的语言理解了一下,原论文:Connectionist Temporal Classification: Labelling UnsegmSequence Data with Recurrent Neural Networ
解决的问题
套用知乎上的一句话,CTC Loss 要解决的问题就是当 label 长度小于模型输出长度时,如何做损失函数。
一般做分类时,已有的 softmax loss 都是模型输出长度和 label 长度相同且严格对齐,而语音识别或者手写体识别中,无法预知一句话或者一张图应该输出多长的文字,这时做法有两种:seq2seq+attention 机制,不限制输出长度,在最后加一个结束符号,让模型自动和 gt label 对齐;另一种是给定一个模型输出的最大长度,但是这些输出并没有对齐的 label 怎么办呢,这时就需要 CTC loss 了。
输出序列的扩展
所以,如果要计算?(?│?),可以累加其对应的全部输出序列 o (也即映射到最终 label 的 “路径”) 的概率即可,如下图。
前向和后向计算
由于我们没有每个时刻输出对应的 label,因此 CTC 使用最大似然进行训练 (CTC 假设输出的概率是(相对于输入)条件独立的)
给定输入 x x x,输出序列 o o o 的条件概率是:
p (π ∣ x) = ∏ y π t t , ∀ π ∈ L ′ T p(\pi|x) = \prod y^t{\pi_t}, \forall \pi \in L^{\prime T} p(π∣x)\=∏yπtt,∀π∈L′T
π t \pi _t πt 是序列 o o o 中的一个元素, y y y 为模型在所有时刻输出各个字符的概率,shape 为 T*C(T 是时刻,提前已固定。C 是字符类别数,所有字符 + blank(不是空格,是空) , y π t t y^t{\pi_t} yπtt 是模型 t 时刻输出为 π t \pi _t πt的概率
我们模型的目标就是给定输入 x,使得能映射到最终 label 的所有输出序列 o 的条件概率之和最大,该条件概率就是 p (π ∣ x) p(\pi|x) p(π∣x),和模型的输出概率 y y y 直接关联
那么我们如何计算这些条件概率之和呢?首先想到的就是暴力算法,一一找到可以映射到最终 label 的所有输出序列,然后概率连乘最后相加,但是很耗时,有木有更快的做法?联系一下 HMM 模型中的前向和后向算法,它就是利用动态规划求某个序列出现的概率,和此处我们要计算某个输出序列的条件概率很相似
比如 HMM 模型中,我们要求红白红出现的概率,我们就可以利用动态规划的思想,因为红白红包含子问题红白的产生,红白包含子问题红的产生,参考引用的图片。
而这里我们以 apple 这个 label 都可以由哪些输出序列映射过去为例(T 为 8):
其中的一种 a p _ p l e
当然其他也可以如 a p p _ p p l e,但是考虑到我们最终对输出序列的处理 (两个空字符之间的重复元素会去除,字符是从左到右的,且是依次的),我们的路径(状态转移) 不是随便的,根据这样的规则,我们可以找到所有可以映射到 apple 的输出序列
很明显可以看到这和 HMM 很像,包含很多相同子问题,可以用动态规划做
定义在时刻 t 经过节点 s 的全部前缀子路径的概率总和为前向概率 α t (s) \alpha_t (s) αt(s),如 α 3 ( 4 ) \alpha_3 (4) α3(4) 为在时刻 3 所有经过第 4 个节点的全部前缀子路径的概率总和: α 3 ( 4 ) \alpha_3 (4) α3(4) = p(_ap) + p(aap) + p(a_p) + p(app),该节点为 p
类似的定义在时刻 t 经过节点 s 的全部后缀子路径的概率总和为前向概率 β t (s) \betat (s) βt(s),如 β 6 ( 8 ) \beta_6 (8) β6(8) 为在时刻 6 所有经过第 8 个节点的全部后缀子路径的概率总和: β 3 ( 4 ) \beta_3 (4) β3(4) = p(lle) + p(l_e) + p(lee) + p(le),该节点为 l
总结
Focal CTC Loss
实现
参考论文 Focal CTC Loss for Chinese Optical Character Recognition on Unbalanced Datasets
- 语音识别:深入理解 CTC Loss 原理
- CTC Algorithm Explained Part 1:Training the Network(CTC 算法详解之训练篇)
- 隐马尔可夫 (HMM)、前 / 后向算法、Viterbi 算法 再次总结
- 【Learning Notes】CTC 原理及实现
- 统计学习方法 - p178
https://sundrops.blog.csdn.net/article/details/97136572?spm=1001.2101.3001.6650.4&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-4-97136572-blog-78941696.pc_relevant_multi_platform_whitelistv1&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-4-97136572-blog-78941696.pc_relevant_multi_platform_whitelistv1&utm_relevant_index=6