10.3.3. 缩放点积注意力

使用点积可以得到计算效率更高的评分函数。
但是点积操作要求查询和键具有相同的长度 d。
假设查询和键的所有元素都是独立的随机变量,
并且都满足均值为 0,和方差为 1。
那么两个向量的点积的均值为 0,方差为 d。
为确保无论向量长度如何,
点积的方差在不考虑向量长度的情况下仍然是 1,
则可以使用 缩放点积注意力(scaled dot-product attention) 评分函数:

10.3. 注意力评分函数 - 图1

将点积除以10.3. 注意力评分函数 - 图2。在实践中,
我们通常从小批量的角度来考虑提高效率,
例如基于 n 个查询和 m个键-值对计算注意力,
其中查询和键的长度为 d,值的长度为 v。
查询 10.3. 注意力评分函数 - 图3、键 10.3. 注意力评分函数 - 图4 和值 10.3. 注意力评分函数 - 图5 的缩放点积注意力是

10.3. 注意力评分函数 - 图6

在下面的缩放点积注意力的实现中,我们使用了 dropout 进行模型正则化。

  1. class DotProductAttention(nn.Module):
  2. """缩放点积注意力"""
  3. def __init__(self, dropout, **kwargs):
  4. super(DotProductAttention, self).__init__(**kwargs)
  5. self.dropout = nn.Dropout(dropout)
  6. # `queries` 的形状:(`batch_size`, 查询的个数, `d`)
  7. # `keys` 的形状:(`batch_size`, “键-值”对的个数, `d`)
  8. # `values` 的形状:(`batch_size`, “键-值”对的个数, 值的维度)
  9. # `valid_lens` 的形状: (`batch_size`,) 或者 (`batch_size`, 查询的个数)
  10. def forward(self, queries, keys, values, valid_lens=None):
  11. d = queries.shape[-1]
  12. # 设置 `transpose_b=True` 为了交换 `keys` 的最后两个维度
  13. scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
  14. self.attention_weights = masked_softmax(scores, valid_lens)
  15. return torch.bmm(self.dropout(self.attention_weights), values)
  1. queries = torch.normal(0, 1, (2, 1, 2))
  2. attention = DotProductAttention(dropout=0.5)
  3. attention.eval()
  4. attention(queries, keys, values, valid_lens)
  1. tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]],
  2. [[10.0000, 11.0000, 12.0000, 13.0000]]])
  1. d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
  2. xlabel='Keys', ylabel='Queries')

image.png