上一节介绍了如何训练输入和输出均为不定长序列的编码器—解码器。本节我们介绍如何使用编码器—解码器来预测不定长的序列。

上一节里已经提到,在准备训练数据集时,我们通常会在样本的输入序列和输出序列后面分别附上一个特殊符号”“表示序列的终止。我们在接下来的讨论中也将沿用上一节的全部数学符号。为了便于讨论,假设解码器的输出是一段文本序列。设输出文本词典10.10 束搜索 - 图1(包含特殊符号”“)的大小为10.10 束搜索 - 图2,输出序列的最大长度为10.10 束搜索 - 图3。所有可能的输出序列一共有10.10 束搜索 - 图4#card=math&code=%5Cmathcal%7BO%7D%28%5Cleft%7C%5Cmathcal%7BY%7D%5Cright%7C%5E%7BT%27%7D%29)种。这些输出序列中所有特殊符号”“后面的子序列将被舍弃。

10.10.1 贪婪搜索

让我们先来看一个简单的解决方案:贪婪搜索(greedy search)。对于输出序列任一时间步10.10 束搜索 - 图5,我们从10.10 束搜索 - 图6个词中搜索出条件概率最大的词

10.10 束搜索 - 图7%0A#card=math&code=y%20%20%7B%20t%20%5E%20%7B%20%5Cprime%20%7D%20%7D%20%3D%20%5Cunderset%20%7B%20y%20%5Cin%20%5Cmathcal%20%7B%20Y%20%7D%20%7D%20%7B%20%5Coperatorname%20%7B%20argmax%20%7D%20%7D%20P%20%5Cleft%28%20y%20%7C%20y%20%20%7B%201%20%7D%20%2C%20%5Cldots%20%2C%20y%20_%20%7B%20t%20%5E%20%7B%20%5Cprime%20%7D%20-%201%20%7D%20%2C%20c%20%5Cright%29%0A)

作为输出。一旦搜索出”“符号,或者输出序列长度已经达到了最大长度10.10 束搜索 - 图8,便完成输出。

我们在描述解码器时提到,基于输入序列生成输出序列的条件概率是10.10 束搜索 - 图9#card=math&code=%5Cprod%7Bt%27%3D1%7D%5E%7BT%27%7D%20P%28y%7Bt%27%7D%20%5Cmid%20y1%2C%20%5Cldots%2C%20y%7Bt%27-1%7D%2C%20%5Cboldsymbol%7Bc%7D%29)。我们将该条件概率最大的输出序列称为最优输出序列。而贪婪搜索的主要问题是不能保证得到最优输出序列。

下面来看一个例子。假设输出词典里面有“A”“B”“C”和“”这4个词。图10.9中每个时间步下的4个数字分别代表了该时间步生成“A”“B”“C”和“”这4个词的条件概率。在每个时间步,贪婪搜索选取条件概率最大的词。因此,图10.9中将生成输出序列“A”“B”“C”“”。该输出序列的条件概率是10.10 束搜索 - 图10

10.10_beam_search.svg

接下来,观察图10.10演示的例子。与图10.9中不同,图10.10在时间步2中选取了条件概率第二大的词“C”。由于时间步3所基于的时间步1和2的输出子序列由图10.9中的“A”“B”变为了图10.10中的“A”“C”,图10.10中时间步3生成各个词的条件概率发生了变化。我们选取条件概率最大的词“B”。此时时间步4所基于的前3个时间步的输出子序列为“A”“C”“B”,与图10.9中的“A”“B”“C”不同。因此,图10.10中时间步4生成各个词的条件概率也与图10.9中的不同。我们发现,此时的输出序列“A”“C”“B”“”的条件概率是10.10 束搜索 - 图12,大于贪婪搜索得到的输出序列的条件概率。因此,贪婪搜索得到的输出序列“A”“B”“C”“”并非最优输出序列。

10.10_s2s_prob1.svg

10.10.2 穷举搜索

如果目标是得到最优输出序列,我们可以考虑穷举搜索(exhaustive search):穷举所有可能的输出序列,输出条件概率最大的序列。

虽然穷举搜索可以得到最优输出序列,但它的计算开销10.10 束搜索 - 图14#card=math&code=%5Cmathcal%7BO%7D%28%5Cleft%7C%5Cmathcal%7BY%7D%5Cright%7C%5E%7BT%27%7D%29)很容易过大。例如,当10.10 束搜索 - 图1510.10 束搜索 - 图16时,我们将评估10.10 束搜索 - 图17个序列:这几乎不可能完成。而贪婪搜索的计算开销是10.10 束搜索 - 图18#card=math&code=%5Cmathcal%7BO%7D%28%5Cleft%7C%5Cmathcal%7BY%7D%5Cright%7CT%27%29),通常显著小于穷举搜索的计算开销。例如,当10.10 束搜索 - 图1910.10 束搜索 - 图20时,我们只需评估10.10 束搜索 - 图21个序列。

10.10.3 束搜索

束搜索(beam search)是对贪婪搜索的一个改进算法。它有一个束宽(beam size)超参数。我们将它设为10.10 束搜索 - 图22。在时间步1时,选取当前时间步条件概率最大的10.10 束搜索 - 图23个词,分别组成10.10 束搜索 - 图24个候选输出序列的首词。在之后的每个时间步,基于上个时间步的10.10 束搜索 - 图25个候选输出序列,从10.10 束搜索 - 图26个可能的输出序列中选取条件概率最大的10.10 束搜索 - 图27个,作为该时间步的候选输出序列。最终,我们从各个时间步的候选输出序列中筛选出包含特殊符号“”的序列,并将它们中所有特殊符号“”后面的子序列舍弃,得到最终候选输出序列的集合。

10.10_s2s_prob2.svg

图10.11通过一个例子演示了束搜索的过程。假设输出序列的词典中只包含5个元素,即10.10 束搜索 - 图29,且其中一个为特殊符号“”。设束搜索的束宽等于2,输出序列最大长度为3。在输出序列的时间步1时,假设条件概率10.10 束搜索 - 图30#card=math&code=P%28y_1%20%5Cmid%20%5Cboldsymbol%7Bc%7D%29)最大的2个词为10.10 束搜索 - 图3110.10 束搜索 - 图32。我们在时间步2时将对所有的10.10 束搜索 - 图33都分别计算10.10 束搜索 - 图34#card=math&code=P%28y_2%20%5Cmid%20A%2C%20%5Cboldsymbol%7Bc%7D%29)和10.10 束搜索 - 图35#card=math&code=P%28y_2%20%5Cmid%20C%2C%20%5Cboldsymbol%7Bc%7D%29),并从计算出的10个条件概率中取最大的2个,假设为10.10 束搜索 - 图36#card=math&code=P%28B%20%5Cmid%20A%2C%20%5Cboldsymbol%7Bc%7D%29)和10.10 束搜索 - 图37#card=math&code=P%28E%20%5Cmid%20C%2C%20%5Cboldsymbol%7Bc%7D%29)。那么,我们在时间步3时将对所有的10.10 束搜索 - 图38都分别计算10.10 束搜索 - 图39#card=math&code=P%28y_3%20%5Cmid%20A%2C%20B%2C%20%5Cboldsymbol%7Bc%7D%29)和10.10 束搜索 - 图40#card=math&code=P%28y_3%20%5Cmid%20C%2C%20E%2C%20%5Cboldsymbol%7Bc%7D%29),并从计算出的10个条件概率中取最大的2个,假设为10.10 束搜索 - 图41#card=math&code=P%28D%20%5Cmid%20A%2C%20B%2C%20%5Cboldsymbol%7Bc%7D%29)和10.10 束搜索 - 图42#card=math&code=P%28D%20%5Cmid%20C%2C%20E%2C%20%5Cboldsymbol%7Bc%7D%29)。如此一来,我们得到6个候选输出序列:(1)10.10 束搜索 - 图43;(2)10.10 束搜索 - 图44;(3)10.10 束搜索 - 图4510.10 束搜索 - 图46;(4)10.10 束搜索 - 图4710.10 束搜索 - 图48;(5)10.10 束搜索 - 图4910.10 束搜索 - 图5010.10 束搜索 - 图51和(6)10.10 束搜索 - 图5210.10 束搜索 - 图5310.10 束搜索 - 图54。接下来,我们将根据这6个序列得出最终候选输出序列的集合。

在最终候选输出序列的集合中,我们取以下分数最高的序列作为输出序列:

10.10 束搜索 - 图55%20%3D%20%5Cfrac%7B1%7D%7BL%5E%5Calpha%7D%20%5Csum%7Bt’%3D1%7D%5EL%20%5Clog%20P(y%7Bt’%7D%20%5Cmid%20y1%2C%20%5Cldots%2C%20y%7Bt’-1%7D%2C%20%5Cboldsymbol%7Bc%7D)%2C%0A#card=math&code=%5Cfrac%7B1%7D%7BL%5E%5Calpha%7D%20%5Clog%20P%28y1%2C%20%5Cldots%2C%20y%7BL%7D%29%20%3D%20%5Cfrac%7B1%7D%7BL%5E%5Calpha%7D%20%5Csum%7Bt%27%3D1%7D%5EL%20%5Clog%20P%28y%7Bt%27%7D%20%5Cmid%20y1%2C%20%5Cldots%2C%20y%7Bt%27-1%7D%2C%20%5Cboldsymbol%7Bc%7D%29%2C%0A)

其中10.10 束搜索 - 图56为最终候选序列长度,10.10 束搜索 - 图57一般可选为0.75。分母上的10.10 束搜索 - 图58是为了惩罚较长序列在以上分数中较多的对数相加项。分析可知,束搜索的计算开销为10.10 束搜索 - 图59#card=math&code=%5Cmathcal%7BO%7D%28k%5Cleft%7C%5Cmathcal%7BY%7D%5Cright%7CT%27%29)。这介于贪婪搜索和穷举搜索的计算开销之间。此外,贪婪搜索可看作是束宽为1的束搜索。束搜索通过灵活的束宽10.10 束搜索 - 图60来权衡计算开销和搜索质量。

小结

  • 预测不定长序列的方法包括贪婪搜索、穷举搜索和束搜索。
  • 束搜索通过灵活的束宽来权衡计算开销和搜索质量。

注:本节与原书基本相同,原书传送门