# Attention Scoring Functions
:label:`sec_attention-scoring-functions`
In :numref:`sec_nadaraya-waston`,
we used a Gaussian kernel to model
interactions between queries and keys.
Treating the exponent of the Gaussian kernel
in :eqref:`eq_nadaraya-waston-gaussian`
as an attention scoring function (or scoring function for short),
the results of this function were
essentially fed into
a softmax operation.
As a result,
we obtained
a probability distribution (attention weights)
over values that are paired with keys.
In the end,
the output of the attention pooling
is simply a weighted sum of the values
based on these attention weights.
At a high level,
we can use the above algorithm
to instantiate the framework of attention mechanisms
in :numref:`fig_qkv`.
Denoting an attention scoring function by $a$,
:numref:`fig_attention_output`
illustrates how the output of attention pooling
can be computed as a weighted sum of values.
Since attention weights are
a probability distribution,
the weighted sum is essentially
a weighted average.
![Computing the output of attention pooling as a weighted average of values.](../img/attention-output.svg)
:label:`fig_attention_output`
Mathematically, suppose that we have a query $\mathbf{q} \in \mathbb{R}^q$ and $m$ key-value pairs $(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)$, where any $\mathbf{k}_i \in \mathbb{R}^k$ and any $\mathbf{v}_i \in \mathbb{R}^v$. The attention pooling $f$ is instantiated as a weighted sum of the values:
$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,$$
:eqlabel:`eq_attn-pooling`
where the attention weight (scalar) for the query $\mathbf{q}$ and key $\mathbf{k}_i$ is computed by the softmax operation of an attention scoring function $a$ that maps two vectors to a scalar:
$$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.$$
:eqlabel:`eq_attn-scoring-alpha`
As we can see, different choices of the attention scoring function $a$ lead to different behaviors of attention pooling. In this section, we introduce two popular scoring functions that we will use to develop more sophisticated attention mechanisms later.
```{.python .input}
import math
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import nn

npx.set_np()
```
```{.python .input}
#@tab pytorch
from d2l import torch as d2l
import math
import torch
from torch import nn
```
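Before turning to specific scoring functions, it may help to see :eqref:`eq_attn-pooling` and :eqref:`eq_attn-scoring-alpha` spelled out directly. The following is a minimal sketch in plain PyTorch, separate from the classes developed below; the tensor sizes and the dot-product scorer are arbitrary choices for illustration. It scores one query against $m$ keys, softmaxes the scores into attention weights, and takes the weighted sum of the values.

```{.python .input}
#@tab pytorch
import torch

# Arbitrary sizes for illustration: m key-value pairs, query/key length 3,
# value length 2
m, dim_qk, dim_v = 4, 3, 2
query = torch.randn(dim_qk)
keys = torch.randn(m, dim_qk)
values = torch.randn(m, dim_v)

scores = keys @ query                              # a(q, k_i) for i = 1, ..., m
attention_weights = torch.softmax(scores, dim=0)   # eq_attn-scoring-alpha
output = attention_weights @ values                # eq_attn-pooling: weighted sum
attention_weights.sum(), output.shape              # weights sum to 1
```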
## Masked Softmax Operation
As we just mentioned,
a softmax operation is used to
output a probability distribution as attention weights.
In some cases,
not all the values should be fed into attention pooling.
For instance,
for efficient minibatch processing in :numref:`sec_machine_translation`,
some text sequences are padded with
special tokens that do not carry meaning.
To get an attention pooling
over
only meaningful tokens as values,
we can specify a valid sequence length (in number of tokens)
to filter out those beyond this specified range
when computing softmax.
In this way,
we can implement such a *masked softmax operation*
in the following `masked_softmax` function,
where any value beyond the valid length
is masked as zero.
```{.python .input}
#@save
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return npx.softmax(X)
    else:
        shape = X.shape
        if valid_lens.ndim == 1:
            valid_lens = valid_lens.repeat(shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
                              value=-1e6, axis=1)
        return npx.softmax(X).reshape(shape)
```
```{.python .input}
#@tab pytorch
#@save
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
```
To demonstrate how this function works, consider a minibatch of two $2 \times 4$ matrix examples, where the valid lengths for these two examples are two and three, respectively. As a result of the masked softmax operation, values beyond the valid lengths are all masked as zero.
```{.python .input}
masked_softmax(np.random.uniform(size=(2, 2, 4)), d2l.tensor([2, 3]))
```
```{.python .input}
#@tab pytorch
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
```
Similarly, we can also use a two-dimensional tensor to specify valid lengths for every row in each matrix example.
```{.python .input}
masked_softmax(np.random.uniform(size=(2, 2, 4)),
               d2l.tensor([[1, 3], [2, 4]]))
```
```{.python .input}
#@tab pytorch
masked_softmax(torch.rand(2, 2, 4), d2l.tensor([[1, 3], [2, 4]]))
```
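For readers who prefer to avoid the `d2l.sequence_mask` helper, the same masking idea can be written with plain PyTorch primitives. The `naive_masked_softmax` function below is a hypothetical hand-rolled variant (handling only a 3D input and a 1D tensor of valid lengths, for illustration only): it builds a boolean mask from the valid lengths and fills the invalid positions with a very large negative value before the softmax.

```{.python .input}
#@tab pytorch
import torch

def naive_masked_softmax(X, valid_lens):
    """A hand-rolled variant of `masked_softmax` for illustration only.

    `X` is a 3D tensor and `valid_lens` is a 1D tensor of per-example
    valid lengths."""
    num_cols = X.shape[-1]
    # mask[i, 0, l] is True iff position l is within the valid length of
    # example i; it broadcasts over the row dimension of `X`
    mask = torch.arange(num_cols)[None, None, :] < valid_lens[:, None, None]
    # Exponentiating -1e6 yields 0, so masked positions get zero weight
    return torch.softmax(X.masked_fill(~mask, -1e6), dim=-1)

naive_masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
```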
## Additive Attention
:label:`subsec_additive-attention`
In general, when queries and keys are vectors of different lengths, we can use *additive attention* as the scoring function. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the *additive attention* scoring function
$$a(\mathbf q, \mathbf k) = \mathbf w_v^\top \tanh(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},$$
:eqlabel:`eq_additive-attn`
where
$\mathbf W_q\in\mathbb R^{h\times q}$, $\mathbf W_k\in\mathbb R^{h\times k}$, and $\mathbf w_v\in\mathbb R^{h}$
are learnable parameters.
Equivalent to :eqref:`eq_additive-attn`,
the query and the key are concatenated
and fed into an MLP with a single hidden layer
whose number of hidden units is $h$, a hyperparameter.
By using $\tanh$ as the activation function and disabling
bias terms,
we implement additive attention in the following.
```{.python .input}
#@save
class AdditiveAttention(nn.Block):
    """Additive attention."""
    def __init__(self, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        # Use `flatten=False` to only transform the last axis so that the
        # shapes for the other axes are kept the same
        self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.w_v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of `queries`: (`batch_size`, no. of
        # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
        # no. of key-value pairs, `num_hiddens`). Sum them up with
        # broadcasting
        features = np.expand_dims(queries, axis=2) + np.expand_dims(
            keys, axis=1)
        features = np.tanh(features)
        # There is only one output of `self.w_v`, so we remove the last
        # one-dimensional entry from the shape. Shape of `scores`:
        # (`batch_size`, no. of queries, no. of key-value pairs)
        scores = np.squeeze(self.w_v(features), axis=-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        return npx.batch_dot(self.dropout(self.attention_weights), values)
```
```{.python .input}
#@tab pytorch
#@save
class AdditiveAttention(nn.Module):
    """Additive attention."""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of `queries`: (`batch_size`, no. of
        # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1,
        # no. of key-value pairs, `num_hiddens`). Sum them up with
        # broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # There is only one output of `self.w_v`, so we remove the last
        # one-dimensional entry from the shape. Shape of `scores`:
        # (`batch_size`, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)
```
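As a quick sanity check on the claim that :eqref:`eq_additive-attn` is equivalent to feeding the concatenation of the query and the key into an MLP with a single hidden layer, the sketch below (plain PyTorch with arbitrary illustrative sizes, separate from the class above) verifies numerically that summing the two projections $\mathbf W_q \mathbf q + \mathbf W_k \mathbf k$ gives the same score as applying the horizontally concatenated weight $[\mathbf W_q \ \mathbf W_k]$ to $[\mathbf q; \mathbf k]$.

```{.python .input}
#@tab pytorch
import torch

q_dim, k_dim, h = 20, 2, 8            # arbitrary sizes for illustration
W_q, W_k, w_v = torch.randn(h, q_dim), torch.randn(h, k_dim), torch.randn(h)
q, k = torch.randn(q_dim), torch.randn(k_dim)

# Score from two separate projections that are summed, as in eq_additive-attn
score_sum = w_v @ torch.tanh(W_q @ q + W_k @ k)
# Score from one projection of the concatenated input [q; k]
W_cat = torch.cat([W_q, W_k], dim=1)  # shape: (h, q_dim + k_dim)
score_cat = w_v @ torch.tanh(W_cat @ torch.cat([q, k]))
torch.allclose(score_sum, score_cat)  # True up to floating point error
```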
Let us demonstrate the above `AdditiveAttention` class
with a toy example,
where shapes (batch size, number of steps or sequence length in tokens, feature size)
of queries, keys, and values
are ($2$, $1$, $20$), ($2$, $10$, $2$),
and ($2$, $10$, $4$), respectively.
The attention pooling output
has a shape of (batch size, number of steps for queries, feature size for values).
```{.python .input}
queries, keys = d2l.normal(0, 1, (2, 1, 20)), d2l.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
values = np.arange(40).reshape(1, 10, 4).repeat(2, axis=0)
valid_lens = d2l.tensor([2, 6])

attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.initialize()
attention(queries, keys, values, valid_lens)
```
```{.python .input}
#@tab pytorch
queries, keys = d2l.normal(0, 1, (2, 1, 20)), d2l.ones((2, 10, 2))
# The two value matrices in the `values` minibatch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = d2l.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
```
Although additive attention contains learnable parameters, since every key is the same in this example, the attention weights are uniform, determined by the specified valid lengths.
```{.python .input}
#@tab all
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
```
## Scaled Dot-Product Attention

A more computationally efficient
design for the scoring function can be
simply the dot product.
However,
the dot product operation
requires that both the query and the key
have the same vector length, say $d$.
Assume that
all the elements of the query and the key
are independent random variables
with zero mean and unit variance.
The dot product of
both vectors has zero mean and a variance of $d$.
To ensure that the variance of the dot product
still remains one regardless of vector length,
the *scaled dot-product attention* scoring function

$$a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k}  /\sqrt{d}$$

divides the dot product by $\sqrt{d}$.
In practice,
we often think in minibatches
for efficiency,
such as computing attention
for $n$ queries and $m$ key-value pairs,
where queries and keys are of length $d$
and values are of length $v$.
The scaled dot-product attention
of queries $\mathbf Q\in\mathbb R^{n\times d}$,
keys $\mathbf K\in\mathbb R^{m\times d}$,
and values $\mathbf V\in\mathbb R^{m\times v}$
is

$$ \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.$$
:eqlabel:`eq_softmax_QK_V`

In the following implementation of the scaled dot product attention, we use dropout for model regularization.

```{.python .input}
#@save
class DotProductAttention(nn.Block):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value
    # dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Set `transpose_b=True` to swap the last two dimensions of `keys`
        scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return npx.batch_dot(self.dropout(self.attention_weights), values)
```
```{.python .input}
#@tab pytorch
#@save
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of `queries`: (`batch_size`, no. of queries, `d`)
    # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
    # Shape of `values`: (`batch_size`, no. of key-value pairs, value
    # dimension)
    # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of `keys` via `transpose` before the
        # batch matrix multiplication
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
```
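To back up the variance argument above with numbers, here is a quick hedged check (plain PyTorch, with an arbitrary dimension and sample count): when the entries of queries and keys are i.i.d. with zero mean and unit variance, the raw dot products have a variance close to $d$, while dividing by $\sqrt{d}$ brings the variance back to roughly one.

```{.python .input}
#@tab pytorch
import math
import torch

d, num_samples = 64, 100000          # arbitrary choices for illustration
q = torch.randn(num_samples, d)      # entries with zero mean, unit variance
k = torch.randn(num_samples, d)
dots = (q * k).sum(dim=1)            # one dot product per sample
dots.var(), (dots / math.sqrt(d)).var()  # roughly (d, 1)
```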
To demonstrate the above `DotProductAttention` class,
we use the same keys, values, and valid lengths from the earlier toy example
for additive attention.
For the dot product operation,
we make the feature size of queries
the same as that of keys.

```{.python .input}
queries = d2l.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.initialize()
attention(queries, keys, values, valid_lens)
```
```{.python .input}
#@tab pytorch
queries = d2l.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)
```
As in the additive attention demonstration,
since `keys` contains identical elements
that cannot be differentiated by any query,
uniform attention weights are obtained.

```{.python .input}
#@tab all
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
```
## Summary
- We can compute the output of attention pooling as a weighted average of values, where different choices of the attention scoring function lead to different behaviors of attention pooling.
- When queries and keys are vectors of different lengths, we can use the additive attention scoring function. When they are the same, the scaled dot-product attention scoring function is more computationally efficient.
## Exercises
- Modify `keys` in the toy example and visualize attention weights. Do additive attention and scaled dot-product attention still output the same attention weights? Why or why not?
- Using matrix multiplications only, can you design a new scoring function for queries and keys with different vector lengths?
- When queries and keys have the same vector length, is vector summation a better design than dot product for the scoring function? Why or why not?
:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab:
