import math
import torch
from torch import nn
from d2l import torch as d2l

To allow multiple attention heads to be computed in parallel, the MultiHeadAttention class (defined below in Section 10.5.1) uses the two transposition functions defined next. Specifically, transpose_output reverses the operation of transpose_qkv.

def transpose_qkv(X, num_heads):
    # Shape of input X:
    # (batch_size, no. of queries or key-value pairs, num_hiddens)
    # Shape of output X:
    # (batch_size, no. of queries or key-value pairs, num_heads,
    #  num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Shape of output X:
    # (batch_size, num_heads, no. of queries or key-value pairs,
    #  num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)

    # Shape of output:
    # (batch_size * num_heads, no. of queries or key-value pairs,
    #  num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
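
As a quick sanity check, the following minimal sketch (using made-up sizes: batch_size=2, 3 queries, num_hiddens=8, and num_heads=4, purely for illustration) verifies that transpose_output recovers the original tensor after transpose_qkv:

X = torch.randn(2, 3, 8)              # (batch_size, no. of queries, num_hiddens)
Y = transpose_qkv(X, num_heads=4)     # shape: (2 * 4, 3, 8 / 4) = (8, 3, 2)
Z = transpose_output(Y, num_heads=4)  # shape: (2, 3, 8), same as X
print(torch.allclose(X, Z))           # True: the round trip is lossless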

10.5.1. Model

Each attention head uses scaled dot-product attention as its attention pooling; the d2l.DotProductAttention class was defined in Section 10.3.3.
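
For reference, the core computation inside scaled dot-product attention is roughly the following (a simplified sketch that omits the masking of padded positions via valid_lens and the dropout used by d2l.DotProductAttention; it is not the library implementation):

def scaled_dot_product(queries, keys, values):
    # queries: (batch, no. of queries, d)
    # keys, values: (batch, no. of key-value pairs, d)
    d = queries.shape[-1]
    # Scores scaled by sqrt(d): (batch, no. of queries, no. of key-value pairs)
    scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
    weights = nn.functional.softmax(scores, dim=-1)
    # Weighted sum of values: (batch, no. of queries, d)
    return torch.bmm(weights, values)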

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens:
        # (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        #  num_hiddens / num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (a scalar or a vector) num_heads
            # times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
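
To see the class in action, here is a minimal usage sketch; the sizes (num_hiddens=100, num_heads=5, and the toy batch below) are arbitrary choices for illustration, and the output shape is (batch_size, no. of queries, num_hiddens):

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))   # queries
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))   # keys and values
print(attention(X, Y, Y, valid_lens).shape)              # torch.Size([2, 4, 100])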