import math
import torch
from torch import nn
from d2l import torch as d2l
为了能够使多个头并行计算,
上面的 MultiHeadAttention 类使用了下面定义的两个转置函数。
具体来说,transpose_output 函数反转了 transpose_qkv 函数的操作
def transpose_qkv(X, num_heads):
# 输入 `X` 的形状: (`batch_size`, 查询或者“键-值”对的个数, `num_hiddens`).
# 输出 `X` 的形状: (`batch_size`, 查询或者“键-值”对的个数, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# 输出 `X` 的形状: (`batch_size`, `num_heads`, 查询或者“键-值”对的个数,
# `num_hiddens` / `num_heads`)
X = X.permute(0, 2, 1, 3)
# `output` 的形状: (`batch_size` * `num_heads`, 查询或者“键-值”对的个数,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
def transpose_output(X, num_heads):
"""逆转 `transpose_qkv` 函数的操作"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
10.5.1. 模型
d2l.DotProductAttention()函数在10.3.3小节
class MultiHeadAttention(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# `queries`, `keys`, or `values` 的形状:
# (`batch_size`, 查询或者“键-值”对的个数, `num_hiddens`)
# `valid_lens` 的形状:
# (`batch_size`,) or (`batch_size`, 查询的个数)
# 经过变换后,输出的 `queries`, `keys`, or `values` 的形状:
# (`batch_size` * `num_heads`, 查询或者“键-值”对的个数,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# 在轴 0,将第一项(标量或者矢量)复制 `num_heads` 次,
# 然后如此复制第二项,然后诸如此类。
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# `output` 的形状: (`batch_size` * `num_heads`, 查询的个数,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# `output_concat` 的形状: (`batch_size`, 查询的个数, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)