In the previous section, we modeled the interaction between queries and keys with a Gaussian kernel. We treated the exponent of the Gaussian kernel as an attention scoring function, fed its outputs into a softmax operation to obtain a probability distribution over the key-value pairs (the attention weights), and finally computed the output of the attention mechanism as a weighted sum of the values using these attention weights.
More generally, the attention mechanism can be abstracted as the process shown in the figure below:
[Figure 1: Computing the output of attention pooling from an attention scoring function]
The attention pooling function $f$ can then be written abstractly as

$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i,$$

where the attention weight for query $\mathbf{q}$ and key $\mathbf{k}_i$ is the softmax of a scoring function $a$:

$$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))}.$$

Here $a$ is the attention scoring function, which can be implemented in several ways.

Masked Softmax Operation

Recall the masking operation in Seq2Seq: since sequences in a minibatch are not all the same length, we pad them and mask out the positions beyond each sequence's valid length. The idea here is much the same, except that the mask is applied to the softmax *inputs* rather than its outputs: positions beyond the valid length are replaced with a very large negative value, so that their softmax outputs are effectively zero.

```python
import math
import torch
from torch import nn
from d2l import torch as d2l

def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large
        # negative value, whose exponentiation outputs 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
```
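As a quick check of the masking idea, the following self-contained sketch inlines the sequence mask with `masked_fill` instead of calling `d2l.sequence_mask`, but is otherwise equivalent:

```python
import torch
from torch import nn

# Scores for a batch of 2, each with 2 queries over 4 key positions
X = torch.rand(2, 2, 4)
valid_lens = torch.tensor([2, 3])  # keep 2 positions in batch 0, 3 in batch 1

# Boolean mask: position index < valid length, broadcast to (2, 1, 4)
mask = torch.arange(4)[None, :] < valid_lens[:, None, None]
scores = X.masked_fill(~mask, -1e6)
weights = nn.functional.softmax(scores, dim=-1)
print(weights)  # masked positions get (near-)zero weight; each row sums to 1
```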

Additive Attention

When queries and keys are vectors of different lengths, we can use additive attention as the scoring function:

$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}),$$

where $\mathbf{W}_q$, $\mathbf{W}_k$, and $\mathbf{w}_v$ are all learnable parameters, which can be implemented with fully connected layers. Suppose the queries, keys, and values each have shape (batch size, number of steps or sequence length in tokens, feature size). When the number of queries and the number of keys differ, broadcasting is used to sum every query with every key, finally producing a score tensor of shape (batch size, number of queries, number of key-value pairs).

```python
#@save
class AdditiveAttention(nn.Module):
    """Additive attention."""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of `queries`: (`batch_size`, no. of
        # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
        # no. of key-value pairs, `num_hiddens`). Sum them up with
        # broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # There is only one output of `self.w_v`, so we remove the last
        # one-dimensional entry from the shape. Shape of `scores`:
        # (`batch_size`, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)
```
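The broadcasting trick in `forward` can be seen in isolation: unsqueezing the query and key tensors on different axes makes their sum enumerate every (query, key) pair. A minimal sketch with made-up sizes:

```python
import torch

batch_size, num_queries, num_kv, num_hiddens = 2, 3, 5, 8
q = torch.randn(batch_size, num_queries, num_hiddens)
k = torch.randn(batch_size, num_kv, num_hiddens)

# (2, 3, 1, 8) + (2, 1, 5, 8) broadcasts to (2, 3, 5, 8):
# one hidden vector per (query, key) combination
features = q.unsqueeze(2) + k.unsqueeze(1)
print(features.shape)  # torch.Size([2, 3, 5, 8])
```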

Scaled Dot-Product Attention

Dot products give a computationally more efficient scoring function, but they require the query and the key to be vectors of the same length $d$. Assume that all elements of the query and the key are independent and identically distributed with zero mean and unit variance; then the dot product of a query and a key has mean 0 and variance $d$. To keep the variance at 1 regardless of the vector length, we divide the dot product by $\sqrt{d}$, giving the scaled dot-product scoring function

$$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}}.$$
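Under the zero-mean i.i.d. assumption (with unit element variance), the variance of the dot product grows linearly in $d$, which is why the scores are divided by $\sqrt{d}$. A quick empirical check:

```python
import torch

torch.manual_seed(0)
d = 256
# 10,000 query/key pairs with i.i.d. standard normal elements
q = torch.randn(10000, d)
k = torch.randn(10000, d)

dots = (q * k).sum(dim=1)
print(dots.mean().item())            # close to 0
print(dots.var().item() / d)         # close to 1: variance grows linearly in d
print((dots / d ** 0.5).var().item())  # after scaling, variance is close to 1
```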

Given $n$ queries and $m$ key-value pairs, where queries and keys have length $d$ and values have length $v$, stack the queries into $\mathbf{Q} \in \mathbb{R}^{n \times d}$, the keys into $\mathbf{K} \in \mathbb{R}^{m \times d}$, and the values into $\mathbf{V} \in \mathbb{R}^{m \times v}$. The scaled dot-product attention output is then

$$\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v},$$

which has shape $(n, v)$.

```python
#@save
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value
    # dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of `keys` with `transpose`
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
```
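A minimal standalone sketch of the same forward pass (masking and dropout omitted) makes the shapes concrete; with identical keys the attention weights are uniform, so the output is just the mean of the values:

```python
import math
import torch
from torch import nn

def scaled_dot_product(queries, keys, values):
    # Mirrors `DotProductAttention.forward` without masking or dropout
    d = queries.shape[-1]
    scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
    weights = nn.functional.softmax(scores, dim=-1)
    return torch.bmm(weights, values)

queries = torch.normal(0, 1, (2, 1, 2))  # n = 1 query of length d = 2
keys = torch.ones((2, 10, 2))            # m = 10 identical keys
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)

out = scaled_dot_product(queries, keys, values)
print(out.shape)  # torch.Size([2, 1, 4]), i.e. (batch, no. of queries, v)
```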
- When the query and the key are vectors of different lengths, we can use additive attention. When they have the same length, scaled dot-product attention is computationally more efficient.