参考来源:
CSDN:超平实版 Pytorch Self-Attention: 参数详解(尤其是mask)(使用nn.MultiheadAttention)
Self-Attention 的结构图
self-attention 的具体结构参照下图。
(图中为输出第二项 attention output 的情况,k 与 q 为 key、query 的缩写)
本文中将使用 Pytorch 的 **torch.nn.MultiheadAttention** 来实现 **self-attention**。
2. forward 输入中的 query、key、value
首先,前三个输入是最重要的部分 query、key、value。由图 1 可知,我们 self-attention 的这三样东西其实是一样的,它们的形状都是:**(L,N,E)**。
**L**:输入 sequence 的长度(例如一个句子的长度)。**N**:批大小(例如一个批的句子个数)。**E**:词向量长度。
3. forward 的输出
输出的内容很少只有两项:
- attn_output
即通过 self-attention 之后,从每一个词语位置输出来的 attention。其形状为 (L,N,E),是和输入的 query 它们形状一样的。因为毕竟只是给 value 乘了一个 weight。
- attn_output_weights
即 attention weights,形状是 (N,L,L),因为每一个单词和任意另一个单词之间都会产生一个 weight,所以每一句句子的 weight 数量是 L*L。
4. 实例化一个 nn.MultiheadAttention
这里对 MultiheadAttention 进行一个实例化并传入一些参数,实例化之后我们得到的东西我们就可以向它传入 input 了。
实例化时的代码:
multihead_attn = nn.MultiheadAttention(embed_dim,num_heads)
其中,embed_dim 是每一个单词本来的词向量长度;num_heads 是我们 MultiheadAttention 的 head 的数量。
pytorch 的 MultiheadAttention 应该使用的是 Narrow self-attention 机制,即,把 embedding 分割成 num_heads 份,每一份分别拿来做一下 attention 。
也就是说:单词 1 的第一份、单词 2 的第一份、单词 3 的第一份…会当成一个 sequence,做一次我们图 1 所示的 self-attention。
然后,单词 1 的第二份、单词 2 的第二份、单词 3 的第二份…也会做一次
直到单词 1 的第 num_heads 份、单词 2 的第 num_heads 份、单词 3 的第 num_heads 份…也做完 self-attention。
从每一份我们都会得到一个 (L,N,E) 形状的输出,我们把这些全部 concat 在一起,会得到一个 (L,N,E*num_heads) 的张量。
这时候,我们拿一个矩阵,把这个张量的维度变回 (L,N,E) 即可输出。
5. 进行 forward 操作
我们把我们刚才实例化好的 multihead_attn 拿来进行 forward 操作(即输入 input 得到 output):
attn_output, attn_output_weights = multihead_attn(query, key, value)
6. 关于 mask
mask 可以理解成遮罩、面具,作用是帮助我们“遮挡”掉我们不需要的东西,即让被遮挡的东西不影响我们的 attention 过程。
在 forward 的时候,有两个 mask 参数可以设置:
**key_padding_mask**
每一个 batch 的每一个句子的长度一般是不可能完全相同的,所以我们会使用 padding 把一些空缺补上。而这里的这个 key_padding_mask 是用来“遮挡”这些 padding 的。
这个 mask 是二元(binary)的,也就是说,它是一个矩阵和我们 key 的大小是一样的,里面的值是 1 或 0,我们先取得 key 中有 padding 的位置,然后把 mask 里相应位置的数字设置为 1,这样 attention 就会把 key 相应的部分变为”-inf“。 (为什么变为 -inf 我们稍后再说)
**attn_mask**
这个 mask 经常是用来遮挡“正确答案”的:
假如你想要用这个模型每次预测下一个单词,我们每一个位置的 attention 输出是怎么得来的?是不是要看一遍整个序列,然后每一个单词都计算一个 attention weight?那也就是说,你在预测第 5 个词的时候,你其实会看到整个序列,这样的话你在预测之前不就已经知道第 5 个单词是什么了,这就是作弊了。
我们不想让模型作弊,因为在真实使用这个模型去预测的时候,我们是没有整个序列的信息的。那么怎么办?那就让第 5 个单词的 attention weight=0 吧,即声明:我不想看这个单词,我的注意力一点也别分给它。
如何让这个 weight=0:
我们先想象一下,我们目前拥有的 attention scores 是什么样的?(注:attention_score 是 attention_weight 的初始样子,经过 softmax 之后会变成 attention_weight。attention_score 和 weight 的形状是一样的,毕竟只有一个 softmax 的差别)
我们之前提到,attention weights 的形状是 L*L,因为每个单词两两之间都有一个 weight。
如下图所示,我用蓝笔圈出的部分,就是“我想要预测x2”时,整个 sequence 的 attention score 情况。我用叉划掉的地方,是我们希望 =0 的位置,因为我们想让 x2、x3、x4 的权值为 0,即:预测 x2 的时候,我们的注意力只能放在 x1 上。
对于其他行,你可以以此类推,发现我们需要一个三角形区域的 attention weight=0(见最底下的图), 这时候我们的 attn_mask 这时候就出场了,把这个 mask 做成三角形即可。
和 key_padding_mask 不同,我们的 attn_mask 不是 binary 的,而是一个“additive mask”。
什么是 additive mask 呢?就是我们 mask 上设置的值,会被加到我们原本的 attention score 上。我们要让三角形区域的 weight=0,我们这个三角 mask 设置什么值好呢?答案是 -inf,(这个 -inf 在 key_padding_mask 的讲解中也出现了,这里就来说说为什么要用-inf)。我们上面提到了,attention score 要经过一个 softmax 才变成 attention_weights。
我们都知道 softmax 的式子可以表示为:
当我们 attention score 的值设置为 -inf (可以看作这里式子里的 zj=−inf),于是通过 softmax 之后我们的 attention weight就会趋近于0了,这就是为什么我们这里的两个 mask 都要用到 -inf。
三角形的样子和 softmax 前后的情况可见下图:(图源见水印)
