多头注意力

:label:sec_multihead-attention

在实践中,当给定相同的查询、键和值的集合时, 我们希望模型可以基于相同的注意力机制学习到不同的行为, 然后将不同的行为作为知识组合起来, 捕获序列内各种范围的依赖关系 (例如,短距离依赖和长距离依赖关系)。 因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。

为此,与其只使用单独一个注意力汇聚, 我们可以用独立学习得到的$h$组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这$h$组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这$h$个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。 这种设计被称为多头注意力(multihead attention) :cite:Vaswani.Shazeer.Parmar.ea.2017。 对于$h$个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。 :numref:fig_multi-head-attention 展示了使用全连接层来实现可学习的线性变换的多头注意力。

多头注意力:多个头连结然后线性变换 :label:fig_multi-head-attention

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询$\mathbf{q} \in \mathbb{R}^{d_q}$、 键$\mathbf{k} \in \mathbb{R}^{d_k}$和 值$\mathbf{v} \in \mathbb{R}^{d_v}$, 每个注意力头$\mathbf{h}_i$($i = 1, \ldots, h$)的计算方法为:

\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},

其中,可学习的参数包括 $\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}$、 $\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}$和 $\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}$, 以及代表注意力汇聚的函数$f$。 $f$可以是 :numref:sec_attention-scoring-functions中的 加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着$h$个头连结后的结果,因此其可学习参数是 $\mathbf W_o\in\mathbb R^{p_o\times h p_v}$:

\mathbf W_o \begin{bmatrix}\mathbf h_1\\vdots\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.

基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

```{.python .input} from d2l import mxnet as d2l import math from mxnet import autograd, np, npx from mxnet.gluon import nn npx.set_np()

  1. ```{.python .input}
  2. #@tab pytorch
  3. from d2l import torch as d2l
  4. import math
  5. import torch
  6. from torch import nn

```{.python .input}

@tab tensorflow

from d2l import tensorflow as d2l import tensorflow as tf

  1. ## 实现
  2. 在实现过程中,我们[**选择缩放点积注意力作为每一个注意力头**]。
  3. 为了避免计算代价和参数代价的大幅增长,
  4. 我们设定$p_q = p_k = p_v = p_o / h$
  5. 值得注意的是,如果我们将查询、键和值的线性变换的输出数量设置为
  6. $p_q h = p_k h = p_v h = p_o$
  7. 则可以并行计算$h$个头。
  8. 在下面的实现中,$p_o$是通过参数`num_hiddens`指定的。
  9. ```{.python .input}
  10. #@save
  11. class MultiHeadAttention(nn.Block):
  12. """多头注意力"""
  13. def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
  14. **kwargs):
  15. super(MultiHeadAttention, self).__init__(**kwargs)
  16. self.num_heads = num_heads
  17. self.attention = d2l.DotProductAttention(dropout)
  18. self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
  19. self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
  20. self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
  21. self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
  22. def forward(self, queries, keys, values, valid_lens):
  23. # queries,keys,values的形状:
  24. # (batch_size,查询或者“键-值”对的个数,num_hiddens)
  25. # valid_lens 的形状:
  26. # (batch_size,)或(batch_size,查询的个数)
  27. # 经过变换后,输出的queries,keys,values 的形状:
  28. # (batch_size*num_heads,查询或者“键-值”对的个数,
  29. # num_hiddens/num_heads)
  30. queries = transpose_qkv(self.W_q(queries), self.num_heads)
  31. keys = transpose_qkv(self.W_k(keys), self.num_heads)
  32. values = transpose_qkv(self.W_v(values), self.num_heads)
  33. if valid_lens is not None:
  34. # 在轴0,将第一项(标量或者矢量)复制num_heads次,
  35. # 然后如此复制第二项,然后诸如此类。
  36. valid_lens = valid_lens.repeat(self.num_heads, axis=0)
  37. # output的形状:(batch_size*num_heads,查询的个数,
  38. # num_hiddens/num_heads)
  39. output = self.attention(queries, keys, values, valid_lens)
  40. # output_concat的形状:(batch_size,查询的个数,num_hiddens)
  41. output_concat = transpose_output(output, self.num_heads)
  42. return self.W_o(output_concat)

```{.python .input}

@tab pytorch

@save

class MultiHeadAttention(nn.Module): “””多头注意力””” def init(self, keysize, querysize, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs): super(MultiHeadAttention, self).__init(**kwargs) self.num_heads = num_heads self.attention = d2l.DotProductAttention(dropout) self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

  1. def forward(self, queries, keys, values, valid_lens):
  2. # queries,keys,values的形状:
  3. # (batch_size,查询或者“键-值”对的个数,num_hiddens)
  4. # valid_lens 的形状:
  5. # (batch_size,)或(batch_size,查询的个数)
  6. # 经过变换后,输出的queries,keys,values 的形状:
  7. # (batch_size*num_heads,查询或者“键-值”对的个数,
  8. # num_hiddens/num_heads)
  9. queries = transpose_qkv(self.W_q(queries), self.num_heads)
  10. keys = transpose_qkv(self.W_k(keys), self.num_heads)
  11. values = transpose_qkv(self.W_v(values), self.num_heads)
  12. if valid_lens is not None:
  13. # 在轴0,将第一项(标量或者矢量)复制num_heads次,
  14. # 然后如此复制第二项,然后诸如此类。
  15. valid_lens = torch.repeat_interleave(
  16. valid_lens, repeats=self.num_heads, dim=0)
  17. # output的形状:(batch_size*num_heads,查询的个数,
  18. # num_hiddens/num_heads)
  19. output = self.attention(queries, keys, values, valid_lens)
  20. # output_concat的形状:(batch_size,查询的个数,num_hiddens)
  21. output_concat = transpose_output(output, self.num_heads)
  22. return self.W_o(output_concat)
  1. ```{.python .input}
  2. #@tab tensorflow
  3. #@save
  4. class MultiHeadAttention(tf.keras.layers.Layer):
  5. """多头注意力"""
  6. def __init__(self, key_size, query_size, value_size, num_hiddens,
  7. num_heads, dropout, bias=False, **kwargs):
  8. super().__init__(**kwargs)
  9. self.num_heads = num_heads
  10. self.attention = d2l.DotProductAttention(dropout)
  11. self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
  12. self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
  13. self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
  14. self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
  15. def call(self, queries, keys, values, valid_lens, **kwargs):
  16. # queries,keys,values的形状:
  17. # (batch_size,查询或者“键-值”对的个数,num_hiddens)
  18. # valid_lens 的形状:
  19. # (batch_size,)或(batch_size,查询的个数)
  20. # 经过变换后,输出的queries,keys,values 的形状:
  21. # (batch_size*num_heads,查询或者“键-值”对的个数,
  22. # num_hiddens/num_heads)
  23. queries = transpose_qkv(self.W_q(queries), self.num_heads)
  24. keys = transpose_qkv(self.W_k(keys), self.num_heads)
  25. values = transpose_qkv(self.W_v(values), self.num_heads)
  26. if valid_lens is not None:
  27. # 在轴0,将第一项(标量或者矢量)复制num_heads次,
  28. # 然后如此复制第二项,然后诸如此类。
  29. valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)
  30. # output的形状:(batch_size*num_heads,查询的个数,
  31. # num_hiddens/num_heads)
  32. output = self.attention(queries, keys, values, valid_lens, **kwargs)
  33. # output_concat的形状:(batch_size,查询的个数,num_hiddens)
  34. output_concat = transpose_output(output, self.num_heads)
  35. return self.W_o(output_concat)

为了能够[使多个头并行计算], 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。

```{.python .input}

@save

def transpose_qkv(X, num_heads): “””为了多注意力头的并行计算而变换形状”””

  1. # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
  2. # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
  3. # num_hiddens/num_heads)
  4. X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
  5. # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
  6. # num_hiddens/num_heads)
  7. X = X.transpose(0, 2, 1, 3)
  8. # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
  9. # num_hiddens/num_heads)
  10. return X.reshape(-1, X.shape[2], X.shape[3])

@save

def transpose_output(X, num_heads): “””逆转transpose_qkv函数的操作””” X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) X = X.transpose(0, 2, 1, 3) return X.reshape(X.shape[0], X.shape[1], -1)

  1. ```{.python .input}
  2. #@tab pytorch
  3. #@save
  4. def transpose_qkv(X, num_heads):
  5. """为了多注意力头的并行计算而变换形状"""
  6. # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
  7. # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
  8. # num_hiddens/num_heads)
  9. X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
  10. # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
  11. # num_hiddens/num_heads)
  12. X = X.permute(0, 2, 1, 3)
  13. # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
  14. # num_hiddens/num_heads)
  15. return X.reshape(-1, X.shape[2], X.shape[3])
  16. #@save
  17. def transpose_output(X, num_heads):
  18. """逆转transpose_qkv函数的操作"""
  19. X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
  20. X = X.permute(0, 2, 1, 3)
  21. return X.reshape(X.shape[0], X.shape[1], -1)

```{.python .input}

@tab tensorflow

@save

def transpose_qkv(X, num_heads): “””为了多注意力头的并行计算而变换形状”””

  1. # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
  2. # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
  3. # num_hiddens/num_heads)
  4. X = tf.reshape(X, shape=(X.shape[0], X.shape[1], num_heads, -1))
  5. # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
  6. # num_hiddens/num_heads)
  7. X = tf.transpose(X, perm=(0, 2, 1, 3))
  8. # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
  9. # num_hiddens/num_heads)
  10. return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))

@save

def transpose_output(X, num_heads): “””逆转transpose_qkv函数的操作””” X = tf.reshape(X, shape=(-1, num_heads, X.shape[1], X.shape[2])) X = tf.transpose(X, perm=(0, 2, 1, 3)) return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))

  1. 下面我们使用键和值相同的小例子来[**测试**]我们编写的`MultiHeadAttention`类。
  2. 多头注意力输出的形状是(`batch_size``num_queries``num_hiddens`)。
  3. ```{.python .input}
  4. num_hiddens, num_heads = 100, 5
  5. attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
  6. attention.initialize()

```{.python .input}

@tab pytorch

num_hiddens, num_heads = 100, 5 attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5) attention.eval()

  1. ```{.python .input}
  2. #@tab tensorflow
  3. num_hiddens, num_heads = 100, 5
  4. attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
  5. num_hiddens, num_heads, 0.5)

```{.python .input}

@tab mxnet, pytorch

batch_size, num_queries = 2, 4 num_kvpairs, valid_lens = 6, d2l.tensor([3, 2]) X = d2l.ones((batch_size, num_queries, num_hiddens)) Y = d2l.ones((batch_size, num_kvpairs, num_hiddens)) attention(X, Y, Y, valid_lens).shape

  1. ```{.python .input}
  2. #@tab tensorflow
  3. batch_size, num_queries = 2, 4
  4. num_kvpairs, valid_lens = 6, d2l.tensor([3, 2])
  5. X = tf.ones((batch_size, num_queries, num_hiddens))
  6. Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
  7. attention(X, Y, Y, valid_lens, training=False).shape

小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
  • 基于适当的张量操作,可以实现多头注意力的并行计算。

练习

  1. 分别可视化这个实验中的多个头的注意力权重。
  2. 假设我们有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

:begin_tab:mxnet Discussions :end_tab:

:begin_tab:pytorch Discussions :end_tab: