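In biology, attention is generally directed by two kinds of cues: nonvolitional and volitional. Nonvolitional cues are based on the saliency of objects in the environment (people tend to notice whatever stands out); volitional cues are based on one's own intent (you want coffee, so you notice the coffee in the environment).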

Queries, Keys, Values

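[Figure 1: Attention Pooling]
With volitional and nonvolitional cues in mind, we can now look at queries, keys, and values (QKV): a query corresponds to a volitional cue, a key corresponds to a nonvolitional cue, and a value corresponds to a sensory input.
This is what sets attention mechanisms apart from ordinary fully connected layers and pooling layers: those layers work only with nonvolitional cues (keys) and have no volitional cue (query).
Keys and values come in pairs.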

Nadaraya-Watson Kernel Regression

This classic method, proposed in 1964, illustrates how attention mechanisms are realized in machine learning. It is in fact one implementation of attention pooling.

Data Generation

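Consider a regression problem: given a set of key-value pairs $\{(x_1, y_1), \ldots, (x_n, y_n)\}$, how do we fit a function $f$ that predicts $\hat{y} = f(x)$ for a new input $x$?

We construct a synthetic dataset by the rule $y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon$, where $\epsilon$ is a noise term drawn from a normal distribution with mean 0 and standard deviation 0.5.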
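The rule above can be sketched without any dependencies, as a way to see exactly what each training pair contains. This is only an illustrative sketch: the helper name `make_dataset`, the fixed seed, and the use of the standard-library `random` module are assumptions made here, and the input range of 0 to 5 is taken from the training code later in these notes.

```python
import math
import random

def f(x):
    """Ground-truth function behind the synthetic dataset."""
    return 2 * math.sin(x) + x ** 0.8

def make_dataset(n, noise_std=0.5, seed=0):
    """Return sorted inputs in [0, 5] and noisy outputs y_i = f(x_i) + eps.

    `make_dataset` and `seed` are hypothetical names used for illustration.
    """
    rng = random.Random(seed)
    xs = sorted(rng.uniform(0, 5) for _ in range(n))
    # eps ~ Normal(mean=0, std=noise_std)
    ys = [f(x) + rng.gauss(0.0, noise_std) for x in xs]
    return xs, ys

x_train, y_train = make_dataset(50)
print(len(x_train), len(y_train))
```

Sorting the inputs is not required by the rule itself; it just makes later plots and weight inspections easier to read.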

Average Pooling

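Start with a very naive estimator and see how well it fits: simply average all the training values, $f(x) = \frac{1}{n}\sum_{i=1}^{n} y_i$.

The result is unsurprising:
[Figure: fit produced by average pooling]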
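To make the weakness of average pooling concrete, here is a minimal sketch on four made-up key-value pairs (the numbers are hypothetical and not from the dataset above). Whatever the query is, the prediction is the same number.

```python
def average_pooling(values):
    """Build a predictor that ignores the query and returns the mean value."""
    mean = sum(values) / len(values)
    return lambda query: mean

# Hypothetical toy key-value pairs (x_i, y_i), for illustration only.
keys = [0.5, 1.5, 2.5, 3.5]
values = [1.2, 2.9, 2.6, 1.9]

predict = average_pooling(values)
# Both calls print the same number: the keys never enter the computation.
print(predict(0.0), predict(4.0))
```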

Nonparametric Attention Pooling

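Average pooling ignores the inputs $x_i$. Nadaraya and Watson came up with a better fit, which weights each $y_i$ by how close its $x_i$ is to the query $x$:

$$f(x) = \sum_{i=1}^{n} \frac{K(x - x_i)}{\sum_{j=1}^{n} K(x - x_j)} y_i$$

where $K$ is a kernel function; its details do not matter here.
Recalling the queries, keys, and values of attention, we can write this in a more general form:

$$f(x) = \sum_{i=1}^{n} \alpha(x, x_i) y_i$$

Here $x$ is the query and $(x_i, y_i)$ is a key-value pair. $\alpha(x, x_i)$ is the attention weight, and the attention weights over all key-value pairs sum to 1.
Taking $K$ to be the Gaussian kernel $K(u) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{u^2}{2}\right)$, this simplifies as follows:

$$f(x) = \sum_{i=1}^{n} \frac{\exp\left(-\frac{1}{2}(x - x_i)^2\right)}{\sum_{j=1}^{n} \exp\left(-\frac{1}{2}(x - x_j)^2\right)} y_i = \sum_{i=1}^{n} \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i$$

So the closer a key $x_i$ is to the query $x$, the more attention the corresponding value $y_i$ receives.
In short, this is a nonparametric attention pooling mechanism. The fit is much better:
[Figure: fit produced by nonparametric attention pooling]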
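The simplification from normalized Gaussian kernel values to a softmax can be checked numerically. The sketch below uses four hypothetical key-value pairs and plain Python; the function names are illustrative assumptions, not part of any library.

```python
import math

def gaussian_kernel(u):
    return math.exp(-u * u / 2) / math.sqrt(2 * math.pi)

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def nw_weights_kernel(query, keys):
    """Attention weights as normalized kernel values K(x - x_i) / sum_j K(x - x_j)."""
    k = [gaussian_kernel(query - key) for key in keys]
    total = sum(k)
    return [v / total for v in k]

def nw_weights_softmax(query, keys):
    """The same weights written as softmax(-(x - x_i)^2 / 2)."""
    return softmax([-(query - key) ** 2 / 2 for key in keys])

def nw_predict(query, keys, values):
    """Weighted average of the values under the attention weights."""
    weights = nw_weights_softmax(query, keys)
    return sum(w * v for w, v in zip(weights, values))

# Hypothetical toy key-value pairs, for illustration only.
keys = [0.5, 1.5, 2.5, 3.5]
values = [1.2, 2.9, 2.6, 1.9]
query = 1.4

w_kernel = nw_weights_kernel(query, keys)
w_softmax = nw_weights_softmax(query, keys)
print(w_kernel)
print(w_softmax)
print(nw_predict(query, keys, values))
```

The constant $\frac{1}{\sqrt{2\pi}}$ cancels between numerator and denominator, which is why the two weight vectors agree. The key nearest the query (1.5 here) receives the largest weight.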

Parametric Attention Pooling

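We can refine the regression function further, for example by adding a parameter $w$:

$$f(x) = \sum_{i=1}^{n} \mathrm{softmax}\left(-\frac{1}{2}\big((x - x_i)\,w\big)^2\right) y_i$$

The parameter $w$ is learnable.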
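To get a feel for what the extra parameter does, the sketch below computes the attention weights for a fixed query at a few hand-picked values of w. The keys are hypothetical and w is set by hand here rather than learned; the function name is an assumption for illustration.

```python
import math

def attention_weights(query, keys, w):
    """Softmax over the scores -((query - key) * w)^2 / 2."""
    scores = [-((query - key) * w) ** 2 / 2 for key in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical keys and query, for illustration only.
keys = [0.5, 1.5, 2.5, 3.5]
query = 1.4

for w in (0.0, 1.0, 5.0):
    print(w, [round(a, 3) for a in attention_weights(query, keys, w)])
```

With w = 0 every score is zero, so the weights are uniform and the model falls back to average pooling. With w = 1 it matches the nonparametric form, and a larger magnitude of w narrows the kernel so that attention concentrates on the nearest key.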

Training

import torch
from torch import nn
from d2l import torch as d2l

n_train = 50  # No. of training examples
x_train, _ = torch.sort(torch.rand(n_train) * 5)  # Training inputs


def f(x):
    return 2 * torch.sin(x) + x ** 0.8


y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # Training outputs
x_test = torch.arange(0, 5, 0.1)  # Testing examples
y_truth = f(x_test)  # Ground-truth outputs for the testing examples
n_test = len(x_test)  # No. of testing examples
print(n_test)


def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5)
    d2l.plt.show()


class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The single learnable parameter w
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # Shape of the output `queries` and `attention_weights`:
        # (no. of queries, no. of key-value pairs)
        queries = queries.repeat_interleave(keys.shape[1]).reshape(
            (-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # Shape of `values`: (no. of queries, no. of key-value pairs)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)


# Shape of `X_tile`: (`n_train`, `n_train`), where each row contains the
# same training inputs (a copy of `x_train`)
X_tile = x_train.repeat((n_train, 1))
# Shape of `Y_tile`: (`n_train`, `n_train`), where each row contains the
# same training outputs (a copy of `y_train`)
Y_tile = y_train.repeat((n_train, 1))
# Mask out the diagonal so that each training input uses every key-value
# pair except its own.
# Shape of `keys`: (`n_train`, `n_train` - 1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape(
    (n_train, -1))
# Shape of `values`: (`n_train`, `n_train` - 1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape(
    (n_train, -1))

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(5):
    trainer.zero_grad()
    # Note: L2 Loss = 1/2 * MSE Loss. PyTorch has MSE Loss which is slightly
    # different from MXNet's L2Loss by a factor of 2. Hence we halve the loss
    l = loss(net(x_train, keys, values), y_train) / 2
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

# Shape of `keys`: (`n_test`, `n_train`), where each row contains the same
# training inputs (i.e., same keys)
keys = x_train.repeat((n_test, 1))
# Shape of `values`: (`n_test`, `n_train`)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

The result is very good:
[Figure: fit produced by parametric attention pooling]
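The training code masks out the diagonal, so each training input attends to every key-value pair except its own. The notes do not say why; one plausible reading, offered here as an assumption, is that otherwise a large w would let each prediction simply copy its own noisy label. The dependency-free sketch below compares the two setups on four hypothetical points, with w fixed by hand rather than learned.

```python
import math

def predict(query, keys, values, w):
    """Parametric attention pooling for a single scalar query."""
    scores = [-((query - k) * w) ** 2 / 2 for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return sum(e / total * v for e, v in zip(exps, values))

def training_loss(xs, ys, w, leave_one_out):
    """Sum over training examples of squared error / 2."""
    loss = 0.0
    for i, (x, y) in enumerate(zip(xs, ys)):
        if leave_one_out:
            # Drop the i-th pair, mirroring the diagonal mask in the training code.
            keys = xs[:i] + xs[i + 1:]
            values = ys[:i] + ys[i + 1:]
        else:
            keys, values = xs, ys
        loss += (predict(x, keys, values, w) - y) ** 2 / 2
    return loss

# Hypothetical toy data and a hand-picked large w, for illustration only.
xs = [0.5, 1.5, 2.5, 3.5]
ys = [1.2, 2.9, 2.6, 1.9]
w = 20.0

print(training_loss(xs, ys, w, leave_one_out=False))  # essentially zero
print(training_loss(xs, ys, w, leave_one_out=True))   # clearly positive
```

When an example may attend to itself, a very sharp kernel drives the training loss to essentially zero without learning anything about neighboring points. Leaving the example out keeps the loss informative about how well nearby pairs predict it.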