参考来源:
CSDN:超平实版 Pytorch Self-Attention: 参数详解(尤其是mask)(使用nn.MultiheadAttention)
Self-Attention 的结构图
self-attention 的具体结构参照下图。
(图中为输出第二项 attention output 的情况,k 与 q 为 key、query 的缩写)
本文中将使用 Pytorch 的 **torch.nn.MultiheadAttention**
来实现 **self-attention**
。
2. forward 输入中的 query、key、value
首先,前三个输入是最重要的部分 query、key、value。由图 1 可知,我们 self-attention 的这三样东西其实是一样的,它们的形状都是:**(L,N,E)**
。
**L**
:输入 sequence
的长度(例如一个句子的长度)。**N**
:批大小(例如一个批的句子个数)。**E**
:词向量长度。
3. forward 的输出
输出的内容很少只有两项:
- attn_output
即通过 self-attention
之后,从每一个词语位置输出来的 attention
。其形状为 (L,N,E)
,是和输入的 query
它们形状一样的。因为毕竟只是给 value
乘了一个 weight
。
- attn_output_weights
即 attention weights
,形状是 (N,L,L)
,因为每一个单词和任意另一个单词之间都会产生一个 weight
,所以每一句句子的 weight
数量是 L*L
。
4. 实例化一个 nn.MultiheadAttention
这里对 MultiheadAttention
进行一个实例化并传入一些参数,实例化之后我们得到的东西我们就可以向它传入 input
了。
实例化时的代码:
multihead_attn = nn.MultiheadAttention(
embed_dim,
num_heads
)
其中,embed_dim
是每一个单词本来的词向量长度;num_heads
是我们 MultiheadAttention
的 head
的数量。
pytorch 的 MultiheadAttention
应该使用的是 Narrow self-attention
机制,即,把 embedding
分割成 num_heads
份,每一份分别拿来做一下 attention
。
也就是说:单词 1 的第一份、单词 2 的第一份、单词 3 的第一份…会当成一个 sequence
,做一次我们图 1 所示的 self-attention
。
然后,单词 1 的第二份、单词 2 的第二份、单词 3 的第二份…也会做一次
直到单词 1 的第 num_heads
份、单词 2 的第 num_heads
份、单词 3 的第 num_heads
份…也做完 self-attention
。
从每一份我们都会得到一个 (L,N,E)
形状的输出,我们把这些全部 concat
在一起,会得到一个 (L,N,E*num_heads)
的张量。
这时候,我们拿一个矩阵,把这个张量的维度变回 (L,N,E)
即可输出。
5. 进行 forward 操作
我们把我们刚才实例化好的 multihead_attn
拿来进行 forward
操作(即输入 input
得到 output
):
attn_output, attn_output_weights = multihead_attn(query, key, value)
6. 关于 mask
mask
可以理解成遮罩、面具,作用是帮助我们“遮挡”掉我们不需要的东西,即让被遮挡的东西不影响我们的 attention
过程。
在 forward
的时候,有两个 mask
参数可以设置:
**key_padding_mask**
每一个 batch
的每一个句子的长度一般是不可能完全相同的,所以我们会使用 padding
把一些空缺补上。而这里的这个 key_padding_mask
是用来“遮挡”这些 padding
的。
这个 mask
是二元(binary
)的,也就是说,它是一个矩阵和我们 key
的大小是一样的,里面的值是 1 或 0,我们先取得 key
中有 padding
的位置,然后把 mask
里相应位置的数字设置为 1,这样 attention
就会把 key
相应的部分变为”-inf
“。 (为什么变为 -inf
我们稍后再说)
**attn_mask**
这个 mask
经常是用来遮挡“正确答案”的:
假如你想要用这个模型每次预测下一个单词,我们每一个位置的 attention
输出是怎么得来的?是不是要看一遍整个序列,然后每一个单词都计算一个 attention weight
?那也就是说,你在预测第 5 个词的时候,你其实会看到整个序列,这样的话你在预测之前不就已经知道第 5 个单词是什么了,这就是作弊了。
我们不想让模型作弊,因为在真实使用这个模型去预测的时候,我们是没有整个序列的信息的。那么怎么办?那就让第 5 个单词的 attention weight=0
吧,即声明:我不想看这个单词,我的注意力一点也别分给它。
如何让这个 weight=0
:
我们先想象一下,我们目前拥有的 attention scores
是什么样的?(注:attention_score
是 attention_weight
的初始样子,经过 softmax
之后会变成 attention_weight
。attention_score
和 weight
的形状是一样的,毕竟只有一个 softmax
的差别)
我们之前提到,attention weights
的形状是 L*L
,因为每个单词两两之间都有一个 weight
。
如下图所示,我用蓝笔圈出的部分,就是“我想要预测x2”时,整个 sequence
的 attention score
情况。我用叉划掉的地方,是我们希望 =0
的位置,因为我们想让 x2
、x3
、x4
的权值为 0
,即:预测 x2
的时候,我们的注意力只能放在 x1
上。
对于其他行,你可以以此类推,发现我们需要一个三角形区域的 attention weight=0
(见最底下的图), 这时候我们的 attn_mask
这时候就出场了,把这个 mask
做成三角形即可。
和 key_padding_mask
不同,我们的 attn_mask
不是 binary
的,而是一个“additive mask
”。
什么是 additive mask
呢?就是我们 mask
上设置的值,会被加到我们原本的 attention score
上。我们要让三角形区域的 weight=0
,我们这个三角 mask
设置什么值好呢?答案是 -inf
,(这个 -inf
在 key_padding_mask
的讲解中也出现了,这里就来说说为什么要用-inf
)。我们上面提到了,attention score
要经过一个 softmax
才变成 attention_weights
。
我们都知道 softmax
的式子可以表示为:
当我们 attention score
的值设置为 -inf
(可以看作这里式子里的 zj=−inf
),于是通过 softmax
之后我们的 attention weight
就会趋近于0
了,这就是为什么我们这里的两个 mask
都要用到 -inf
。
三角形的样子和 softmax
前后的情况可见下图:(图源见水印)