import math
import torch
from torch import nn
from d2l import torch as d2l
To allow multiple attention heads to be computed in parallel, the MultiHeadAttention class implemented below relies on the two transposition functions defined here. Specifically, transpose_output reverses the operation of transpose_qkv.
def transpose_qkv(X, num_heads):
    # Shape of input `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
    # Shape of output `X`:
    # (`batch_size`, no. of queries or key-value pairs, `num_heads`,
    # `num_hiddens` / `num_heads`)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # Shape of output `X`:
    # (`batch_size`, `num_heads`, no. of queries or key-value pairs,
    # `num_hiddens` / `num_heads`)
    X = X.permute(0, 2, 1, 3)

    # Shape of `output`:
    # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
    # `num_hiddens` / `num_heads`)
    return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
    """Reverse the operation of `transpose_qkv`."""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
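As a quick sanity check that the two functions invert each other, we can round-trip a tensor and inspect its shape (a minimal sketch; the sizes below are chosen only for illustration).

# Illustrative sizes: batch_size=2, 4 queries, num_hiddens=100, num_heads=5
X = torch.ones((2, 4, 100))
Y = transpose_qkv(X, num_heads=5)
print(Y.shape)                       # torch.Size([10, 4, 20])
print(transpose_output(Y, 5).shape)  # torch.Size([2, 4, 100])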
10.5.1. Model
Each attention head uses scaled dot-product attention, implemented by the d2l.DotProductAttention class from Section 10.3.3.
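For reference, scaled dot-product attention computes softmax(QKᵀ/√d)V. The sketch below (a hypothetical helper, omitting the masking of invalid positions and the dropout that d2l.DotProductAttention also applies) only recalls that core computation.

def scaled_dot_product_attention(queries, keys, values):
    # `queries`: (batch, no. of queries, d)
    # `keys`, `values`: (batch, no. of key-value pairs, d)
    d = queries.shape[-1]
    scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
    attention_weights = torch.softmax(scores, dim=-1)
    return torch.bmm(attention_weights, values)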
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of `queries`, `keys`, or `values`:
        # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
        # Shape of `valid_lens`:
        # (`batch_size`,) or (`batch_size`, no. of queries)
        # After transposing, shape of output `queries`, `keys`, or `values`:
        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,
        # `num_hiddens` / `num_heads`)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) `num_heads`
            # times, then the second item, and so on.
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
        # `num_hiddens` / `num_heads`)
        output = self.attention(queries, keys, values, valid_lens)

        # Shape of `output_concat`:
        # (`batch_size`, no. of queries, `num_hiddens`)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
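A small smoke test (the sizes below are illustrative) confirms that the output has shape (batch_size, no. of queries, num_hiddens).

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)  # torch.Size([2, 4, 100])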
