In the previous lecture, we used a Gaussian kernel to model the interaction between queries and keys. We treated the exponent of the Gaussian kernel as an attention scoring function; its result was then fed into a softmax operation, yielding a probability distribution over the key-value pairs (the attention weights). Finally, the output of the attention mechanism is the sum of the values weighted by these attention weights.
In fact, the attention mechanism can be abstracted as the process shown in the figure below:
The attention pooling function f can likewise be written abstractly as

$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i, \qquad \alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}\big(a(\mathbf{q}, \mathbf{k}_i)\big) = \frac{\exp\big(a(\mathbf{q}, \mathbf{k}_i)\big)}{\sum_{j=1}^m \exp\big(a(\mathbf{q}, \mathbf{k}_j)\big)}.$$
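As a toy illustration of this pooling formula, the snippet below uses a plain dot product as a stand-in for the score function a (an assumption for illustration only; the actual scoring functions are the subject of this section):

```python
import torch

# Toy attention pooling with a single query, two key-value pairs, and a
# plain dot product standing in for the score function a.
q = torch.tensor([1.0, 0.0])
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
values = torch.tensor([[10.0], [20.0]])

scores = keys @ q                        # a(q, k_i) for each key
weights = torch.softmax(scores, dim=0)   # alpha(q, k_i), sums to 1
output = (weights.unsqueeze(-1) * values).sum(dim=0)
print(output)  # closer to 10 than to 20, since q matches the first key
```

Because the query aligns with the first key, the first value receives the larger attention weight and dominates the pooled output.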
Masked Softmax Operation
Recall the mask operation in Seq2Seq: since inputs and outputs are not necessarily of equal length, we need to mask out the positions beyond the valid sequence length so that they contribute nothing (their weights become zero). The idea here is much the same: we build the mask into the softmax itself, replacing the scores at invalid positions with a very large negative value so that their softmax outputs are (numerically) zero.
```python
import torch
from torch import nn
from d2l import torch as d2l

def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
```
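To see the effect without the `d2l.sequence_mask` dependency, here is a self-contained sketch of the same idea using a plain boolean mask (the shapes and valid lengths are made up for illustration):

```python
import torch
from torch import nn

# Self-contained sketch of the masking idea: scores at positions at or
# beyond each row's valid length are replaced by -1e6 before the softmax,
# so their attention weights come out as (numerically) zero.
X = torch.rand(2, 2, 4)            # (batch_size, no. of queries, no. of keys)
valid_lens = torch.tensor([2, 3])  # hypothetical valid lengths per batch item

# Boolean mask of shape (2, 1, 4), broadcast over the query dimension
mask = torch.arange(4)[None, :] < valid_lens[:, None, None]
weights = nn.functional.softmax(X.masked_fill(~mask, -1e6), dim=-1)
print(weights[0])  # columns 2 and 3 are zero for the first batch item
```

Each row of `weights` still sums to 1, but all of the probability mass sits on the valid positions.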
Additive Attention
When queries and keys are vectors of different lengths, we can use additive attention as the scoring function:

$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R}$$
where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^{h}$ are all learnable parameters, which can be implemented with fully connected layers. Suppose the queries, keys, and values all have the shape (batch size, number of steps or sequence length in tokens, feature size). When the numbers of query and key tokens differ, broadcasting is used to pair every query with every key, finally producing one scalar score per query-key pair.
```python
#@save
class AdditiveAttention(nn.Module):
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of `queries`: (`batch_size`, no. of
        # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
        # no. of key-value pairs, `num_hiddens`). Sum them up with
        # broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # There is only one output of `self.w_v`, so we remove the last
        # one-dimensional entry from the shape. Shape of `scores`:
        # (`batch_size`, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)
```
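The broadcasting sum is the subtle part of this implementation. A minimal shape walk-through, with hypothetical sizes and a random tensor standing in for the learned `self.w_v`:

```python
import torch

# Shape walk-through of the broadcasting inside `AdditiveAttention`, with
# hypothetical sizes: batch_size=2, 3 queries, 5 key-value pairs, num_hiddens=8.
queries = torch.rand(2, 3, 8)  # as if already projected by `self.W_q`
keys = torch.rand(2, 5, 8)     # as if already projected by `self.W_k`

# (2, 3, 1, 8) + (2, 1, 5, 8) broadcasts to (2, 3, 5, 8):
# every query is summed with every key
features = torch.tanh(queries.unsqueeze(2) + keys.unsqueeze(1))

w_v = torch.rand(8, 1)                 # random stand-in for `self.w_v`
scores = (features @ w_v).squeeze(-1)  # one scalar per query-key pair
print(scores.shape)  # torch.Size([2, 3, 5])
```

The resulting `scores` tensor has shape (batch size, number of queries, number of key-value pairs), exactly what `masked_softmax` expects.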
Scaled Dot-Product Attention
Dot products provide a more computationally efficient scoring function, but they require the query and key vectors to have the same length d. Assume all elements of the query and the key are independent and identically distributed with zero mean and unit variance. Then the dot product of a query and a key has mean 0 and variance d, so we divide by √d to keep the variance at 1 regardless of the vector length:

$$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}}.$$
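If the elements are additionally assumed to have unit variance, each of the d element-wise products has variance 1, so their sum has variance d. A quick empirical check (vector length and sample count chosen arbitrarily):

```python
import torch

torch.manual_seed(0)
d, n = 256, 10000  # hypothetical vector length and number of samples

# Elements i.i.d. with mean 0 and unit variance
q = torch.randn(n, d)
k = torch.randn(n, d)

dots = (q * k).sum(dim=-1)   # raw dot products: variance grows like d
scaled = dots / d ** 0.5     # scaled dot products: variance stays near 1
print(dots.var().item(), scaled.var().item())
```

Without the scaling, large d pushes the scores into regions where softmax saturates and gradients vanish; dividing by √d keeps the score scale independent of d.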
With n queries and m key-value pairs, where queries and keys have length d and values have length v, stack them into matrices $\mathbf{Q} \in \mathbb{R}^{n \times d}$, $\mathbf{K} \in \mathbb{R}^{m \times d}$, and $\mathbf{V} \in \mathbb{R}^{m \times v}$. The scaled dot-product attention is then

$$\mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v},$$

whose shape is (n, v).
```python
import math

#@save
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value
    # dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Use `transpose` to swap the last two dimensions of `keys`
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
```
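As a sanity check, the forward pass can be replayed step by step outside the class. This is a minimal sketch with no masking (`valid_lens=None`) and hypothetical sizes:

```python
import math
import torch
from torch import nn

# Replaying the scaled dot-product forward pass with hypothetical sizes:
# batch_size=2, n=3 queries, m=4 key-value pairs, d=6, value dimension v=5.
queries = torch.rand(2, 3, 6)
keys = torch.rand(2, 4, 6)
values = torch.rand(2, 4, 5)

d = queries.shape[-1]
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
weights = nn.functional.softmax(scores, dim=-1)  # (2, 3, 4), rows sum to 1
output = torch.bmm(weights, values)
print(output.shape)  # torch.Size([2, 3, 5]), i.e. (n, v) per batch item
```

Per batch item, the output shape is (n, v), matching the matrix formula above.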
- When queries and keys are vectors of different lengths, we can use additive attention. When their lengths match, scaled dot-product attention is more computationally efficient.
