参考

  • AI SUMMER这篇文章写的很好,很直观,很清晰:https://theaisummer.com/positional-embeddings/

    前言

    这里讨论的相对位置编码的实现策略来自于Music Transformer
    这里有一篇介绍性的文章:https://gudgud96.github.io/2020/04/01/annotated-music-transformer/,图例非常清晰。
    首先理解下相对位置自注意力中关于位置嵌入的一些细节。
    image.png相对注意力的一些相关概念。摘自Music Transformer。在不考虑head维度时:

  • 相对位置编码的理解 - 图2: relative position embedding,大小为相对位置编码的理解 - 图3

  • 相对位置编码的理解 - 图4: 来自Shaw论文中引入的相对位置嵌入的中间表示,大小为相对位置编码的理解 - 图5
  • 相对位置编码的理解 - 图6: 表示相对位置编码与query的交互结果,大小为相对位置编码的理解 - 图7,即在相对位置编码的理解 - 图8维度上进行了累加
  • Music Transformer的一点工作就是将这个会占用较大存储空间的中间表示相对位置编码的理解 - 图9去掉,直接得到相对位置编码的理解 - 图10,如下图所示

image.png
要注意这里的相对位置编码的理解 - 图12表示的是针对相对位置相对位置编码的理解 - 图13的嵌入,最小相对位置为相对位置编码的理解 - 图14,最大为0(因为需要考虑因果关系,前面的i看不到后面的j),所以有相对位置编码的理解 - 图15个位置。
而对于我们这里将要讨论的不考虑因果关系的情况,最小相对位置为相对位置编码的理解 - 图16,最大为相对位置编码的理解 - 图17。所以我们的位置嵌入相对位置编码的理解 - 图18形状为相对位置编码的理解 - 图19

代码分析

首先找份代码来看看,https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/main/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py实现的相对位置编码涉及到几个关键的组件:

  1. import torch
  2. import torch.nn as nn
  3. from einops import rearrange
  4. def relative_to_absolute(q):
  5. """
  6. Converts the dimension that is specified from the axis
  7. from relative distances (with length 2*tokens-1) to absolute distance (length tokens)
  8. borrowed from lucidrains:
  9. https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/main/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py#L21
  10. Input: [bs, heads, length, 2*length - 1]
  11. Output: [bs, heads, length, length]
  12. """
  13. b, h, l, _, device, dtype = *q.shape, q.device, q.dtype
  14. dd = {'device': device, 'dtype': dtype}
  15. col_pad = torch.zeros((b, h, l, 1), **dd)
  16. x = torch.cat((q, col_pad), dim=3) # zero pad 2l-1 to 2l
  17. flat_x = rearrange(x, 'b h l c -> b h (l c)')
  18. flat_pad = torch.zeros((b, h, l - 1), **dd)
  19. flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
  20. final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
  21. final_x = final_x[:, :, :l, (l - 1):]
  22. return final_x
  23. def rel_pos_emb_1d(q, rel_emb, shared_heads):
  24. """
  25. Same functionality as RelPosEmb1D
  26. Args:
  27. q: a 4d tensor of shape [batch, heads, tokens, dim]
  28. rel_emb: a 2D or 3D tensor
  29. of shape [ 2*tokens-1 , dim] or [ heads, 2*tokens-1 , dim]
  30. """
  31. if shared_heads:
  32. emb = torch.einsum('b h t d, r d -> b h t r', q, rel_emb)
  33. else:
  34. emb = torch.einsum('b h t d, h r d -> b h t r', q, rel_emb)
  35. return relative_to_absolute(emb)
  36. class RelPosEmb1DAISummer(nn.Module):
  37. def __init__(self, tokens, dim_head, heads=None):
  38. """
  39. Output: [batch head tokens tokens]
  40. Args:
  41. tokens: the number of the tokens of the seq
  42. dim_head: the size of the last dimension of q
  43. heads: if None representation is shared across heads.
  44. else the number of heads must be provided
  45. """
  46. super().__init__()
  47. scale = dim_head ** -0.5
  48. self.shared_heads = heads if heads is not None else True
  49. if self.shared_heads:
  50. self.rel_pos_emb = nn.Parameter(torch.randn(2 * tokens - 1, dim_head) * scale)
  51. else:
  52. self.rel_pos_emb = nn.Parameter(torch.randn(heads, 2 * tokens - 1, dim_head) * scale)
  53. def forward(self, q):
  54. return rel_pos_emb_1d(q, self.rel_pos_emb, self.shared_heads)

可以看到:

  • RelPosEmb1DAISummer初始化了相对位置编码的理解 - 图20
  • rel_pos_emb_1drelative_to_absolute提供相对位置编码的理解 - 图21(为了便于书写,我们将其设为相对位置编码的理解 - 图22),通过在relative_to_absolute中各种形变和padding,从而得到了

理解的难点在relative_to_absolute中的实现过程。
这里会把相对位置编码的理解 - 图23从一个相对位置编码的理解 - 图24tensor转化为一个相对位置编码的理解 - 图25的tensor。这个过程实际上就是一个从表中查找的过程。
这里的实现其实有些晦涩,直接阅读代码是很难明白其中的意义。接下来会重点说这个。
需要注意的是,下面的分析都是按照1D的token序列来解释的,实际上2D的也是将H和W分别基于1D的策略处理的。也就是将H或者W合并到头索引那一维度,即这里的heads,结果就和1D的一致了,只是还会多一个额外的广播的过程。如下代码:

  1. import torch.nn as nn
  2. from einops import rearrange
  3. from self_attention_cv.pos_embeddings.relative_embeddings_1D import RelPosEmb1D
  4. class RelPosEmb2DAISummer(nn.Module):
  5. def __init__(self, feat_map_size, dim_head, heads=None):
  6. """
  7. Based on Bottleneck transformer paper
  8. paper: https://arxiv.org/abs/2101.11605 . Figure 4
  9. Output: qr^T [batch head tokens tokens]
  10. Args:
  11. tokens: the number of the tokens of the seq
  12. dim_head: the size of the last dimension of q
  13. heads: if None representation is shared across heads.
  14. else the number of heads must be provided
  15. """
  16. super().__init__()
  17. self.h, self.w = feat_map_size # height , width
  18. self.total_tokens = self.h * self.w
  19. self.shared_heads = heads if heads is not None else True
  20. self.emb_w = RelPosEmb1D(self.h, dim_head, heads)
  21. self.emb_h = RelPosEmb1D(self.w, dim_head, heads)
  22. def expand_emb(self, r, dim_size):
  23. # Decompose and unsqueeze dimension
  24. r = rearrange(r, 'b (h x) i j -> b h x () i j', x=dim_size)
  25. expand_index = [-1, -1, -1, dim_size, -1, -1] # -1 indicates no expansion
  26. r = r.expand(expand_index)
  27. return rearrange(r, 'b h x1 x2 y1 y2 -> b h (x1 y1) (x2 y2)')
  28. def forward(self, q):
  29. """
  30. Args:
  31. q: [batch, heads, tokens, dim_head]
  32. Returns: [ batch, heads, tokens, tokens]
  33. """
  34. assert self.total_tokens == q.shape[2], f'Tokens {q.shape[2]} of q must \
  35. be equal to the product of the feat map size {self.total_tokens} '
  36. # out: [batch head*w h h]
  37. r_h = self.emb_w(rearrange(q, 'b h (x y) d -> b (h x) y d', x=self.h, y=self.w))
  38. r_w = self.emb_h(rearrange(q, 'b h (x y) d -> b (h y) x d', x=self.h, y=self.w))
  39. q_r = self.expand_emb(r_h, self.h) + self.expand_emb(r_w, self.w)
  40. return q_r

提前的思考

首先我们要明确,为什么对于每个维度为相对位置编码的理解 - 图26的token 相对位置编码的理解 - 图27,其对应的整体相对位置编码的理解 - 图28会有相对位置编码的理解 - 图29这样一个缩减的过程?
因为对于长为相对位置编码的理解 - 图30的序列中的每一个元素相对位置编码的理解 - 图31,实际上与之可能有关的元素相对位置编码的理解 - 图32最多只有相对位置编码的理解 - 图33个(废话,O(∩_∩)O哈哈~)。
所以对于每个元素,实际上这里的相对位置编码的理解 - 图34并不会都用到。这里的相对位置编码的理解 - 图35只是所有可能会用到的情形(分别对应于各种相对距离相对位置编码的理解 - 图36)。

这里需要说明的一点是,有些相对注意力的策略中,会使用固定的窗口。 即对于窗口之外的j,和窗口边界上的j的相对距离认为是一样的,即相对位置编码的理解 - 图37,我们这里介绍的可以看做是相对位置编码的理解 - 图38。 例如这个实现:https://github.com/TensorUI/relative-position-pytorch/blob/master/relative_position.py

所以这里前面展示的这个函数relative_to_absolute实际上就是在做这样一件事:从相对位置编码的理解 - 图39中抽取对应于各个token相对位置编码的理解 - 图40真实存在的相对距离相对位置编码的理解 - 图41的位置嵌入集合相对位置编码的理解 - 图42来得到最终的相对位置编码的理解 - 图43

背后的动机

为了便于展示这个代码描述的过程的动机,我们首先构造一个简单的序列相对位置编码的理解 - 图44,包含5个元素,则相对位置编码的理解 - 图45。这里嵌入的维度为相对位置编码的理解 - 图46。则位置相对位置编码的理解 - 图47对应的相对距离矩阵可以表示为:
image.png
图1
这里红色标记表示各个位置上的相对距离。
我们再看下假定已经得到的相对位置编码的理解 - 图49
image.png
图2
这里对各个相对位置编码的理解 - 图51都提供了独立的一套嵌入相对位置编码的理解 - 图52。为了直观的展示,这里我们也展示了对于这相对位置编码的理解 - 图53个相对位置的相对距离,同时也标注了对应于嵌入矩阵各列的绝对索引。
接下来我们就需要提取想要的那部分嵌入的tensor了。
这个时候,我们需要明白,我们要获取的是哪部分结果:
image.png
图3
这里实际上就是结合了图1中已经得到的相对距离和图2中的相对位置编码的理解 - 图55,从而就可以明白,红色的这部分区域正是我们想要的那部分合理索引对应的位置编码。
稍微整理下,也就是如下的绝对索引对应的嵌入信息相对位置编码的理解 - 图56(形状与相对位置编码的理解 - 图57一致,可以直接元素级相加):
image.png
图4
而前面的代码relative_to_absolute正是在做这样一件事。就是通过不断的paddingreshape来从图3中获得图4中这些绝对索引对应的嵌入。

对应的流程

关于代码的流程,参考链接中的图例非常直观:

  1. col_pad = torch.zeros((b, h, l, 1), **dd)
  2. x = torch.cat((q, col_pad), dim=3) # zero pad 2l-1 to 2l

image.png

  1. flat_x = rearrange(x, 'b h l c -> b h (l c)')

image.png

  1. flat_pad = torch.zeros((b, h, l - 1), **dd)
  2. flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)

image.png

  1. final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
  2. final_x = final_x[:, :, :l, (l - 1):]

image.png
将提取的内容对应于原始的相对位置编码的理解 - 图63中,可以看到是如下区域,正如前面的分析所示。
image.png