参考来源:
CSDN:超平实版 Pytorch Self-Attention: 参数详解(尤其是mask)(使用nn.MultiheadAttention)

Self-Attention 的结构图

self-attention 的具体结构参照下图。
image.png
(图中为输出第二项 attention output 的情况,k 与 q 为 key、query 的缩写)
本文中将使用 Pytorch 的 **torch.nn.MultiheadAttention** 来实现 **self-attention**

2. forward 输入中的 query、key、value

首先,前三个输入是最重要的部分 query、key、value。由图 1 可知,我们 self-attention 的这三样东西其实是一样的,它们的形状都是:**(L,N,E)**

**L**:输入 sequence 的长度(例如一个句子的长度)。
**N**:批大小(例如一个批的句子个数)。
**E**:词向量长度。

3. forward 的输出

输出的内容很少只有两项:

  1. attn_output

即通过 self-attention 之后,从每一个词语位置输出来的 attention。其形状为 (L,N,E),是和输入的 query 它们形状一样的。因为毕竟只是给 value 乘了一个 weight

  1. attn_output_weights

attention weights,形状是 (N,L,L),因为每一个单词和任意另一个单词之间都会产生一个 weight,所以每一句句子的 weight 数量是 L*L

4. 实例化一个 nn.MultiheadAttention

这里对 MultiheadAttention 进行一个实例化并传入一些参数,实例化之后我们得到的东西我们就可以向它传入 input 了。
实例化时的代码:

  1. multihead_attn = nn.MultiheadAttention(
  2. embed_dim,
  3. num_heads
  4. )

其中,embed_dim 是每一个单词本来的词向量长度;num_heads 是我们 MultiheadAttentionhead 的数量。
pytorch 的 MultiheadAttention 应该使用的是 Narrow self-attention 机制,即,把 embedding 分割成 num_heads 份,每一份分别拿来做一下 attention
也就是说:单词 1 的第一份、单词 2 的第一份、单词 3 的第一份…会当成一个 sequence,做一次我们图 1 所示的 self-attention
然后,单词 1 的第二份、单词 2 的第二份、单词 3 的第二份…也会做一次
直到单词 1 的第 num_heads 份、单词 2 的第 num_heads 份、单词 3 的第 num_heads 份…也做完 self-attention
从每一份我们都会得到一个 (L,N,E) 形状的输出,我们把这些全部 concat 在一起,会得到一个 (L,N,E*num_heads) 的张量。
这时候,我们拿一个矩阵,把这个张量的维度变回 (L,N,E) 即可输出。

5. 进行 forward 操作

我们把我们刚才实例化好的 multihead_attn 拿来进行 forward 操作(即输入 input 得到 output):

  1. attn_output, attn_output_weights = multihead_attn(query, key, value)

6. 关于 mask

mask 可以理解成遮罩、面具,作用是帮助我们“遮挡”掉我们不需要的东西,即让被遮挡的东西不影响我们的 attention 过程。
forward 的时候,有两个 mask 参数可以设置:

  1. **key_padding_mask**

每一个 batch 的每一个句子的长度一般是不可能完全相同的,所以我们会使用 padding 把一些空缺补上。而这里的这个 key_padding_mask 是用来“遮挡”这些 padding 的。
这个 mask 是二元(binary)的,也就是说,它是一个矩阵和我们 key 的大小是一样的,里面的值是 1 或 0,我们先取得 key 中有 padding 的位置,然后把 mask 里相应位置的数字设置为 1,这样 attention 就会把 key 相应的部分变为”-inf“。 (为什么变为 -inf 我们稍后再说)

  1. **attn_mask**

这个 mask 经常是用来遮挡“正确答案”的:
假如你想要用这个模型每次预测下一个单词,我们每一个位置的 attention 输出是怎么得来的?是不是要看一遍整个序列,然后每一个单词都计算一个 attention weight?那也就是说,你在预测第 5 个词的时候,你其实会看到整个序列,这样的话你在预测之前不就已经知道第 5 个单词是什么了,这就是作弊了。
我们不想让模型作弊,因为在真实使用这个模型去预测的时候,我们是没有整个序列的信息的。那么怎么办?那就让第 5 个单词的 attention weight=0 吧,即声明:我不想看这个单词,我的注意力一点也别分给它。

如何让这个 weight=0
我们先想象一下,我们目前拥有的 attention scores 是什么样的?(注:attention_scoreattention_weight 的初始样子,经过 softmax 之后会变成 attention_weight
attention_scoreweight 的形状是一样的,毕竟只有一个 softmax 的差别)
我们之前提到,attention weights 的形状是 L*L,因为每个单词两两之间都有一个 weight
如下图所示,我用蓝笔圈出的部分,就是“我想要预测x2”时,整个 sequenceattention score 情况。我用叉划掉的地方,是我们希望 =0 的位置,因为我们想让 x2x3x4 的权值为 0,即:预测 x2 的时候,我们的注意力只能放在 x1 上。
image.png
对于其他行,你可以以此类推,发现我们需要一个三角形区域的 attention weight=0(见最底下的图), 这时候我们的 attn_mask 这时候就出场了,把这个 mask 做成三角形即可。
key_padding_mask 不同,我们的 attn_mask 不是 binary 的,而是一个“additive mask”。
什么是 additive mask 呢?就是我们 mask 上设置的值,会被加到我们原本的 attention score 上。我们要让三角形区域的 weight=0,我们这个三角 mask 设置什么值好呢?答案是 -inf,(这个 -infkey_padding_mask 的讲解中也出现了,这里就来说说为什么要用-inf)。我们上面提到了,attention score 要经过一个 softmax 才变成 attention_weights
我们都知道 softmax 的式子可以表示为:
超平实版 Pytorch Self-Attention:参数详解(尤其是 mask)(使用 nn.MultiheadAttention) - 图3
当我们 attention score 的值设置为 -inf (可以看作这里式子里的 zj=−inf),于是通过 softmax 之后我们的 attention weight就会趋近于0了,这就是为什么我们这里的两个 mask 都要用到 -inf
三角形的样子和 softmax 前后的情况可见下图:(图源见水印)
image.png