本文重点在讨论Transformer的Positional embedding的形式。

笔者碎碎念:本文需要读者有基本的Transformer的知识。

参考链接:
1. Rotary Embeddings: A Relative Revolution
2. 让研究人员绞尽脑汁的Transformer位置编码

codebase:
TF版本:https://github.com/ZhuiyiTechnology/roformer
Pytorch版本:https://github.com/JunnYu/RoFormer_pytorch

各种不同的形式

总得来说,Transformer中的位置编码有下面几种:

  1. 绝对位置编码(可学习、三角、递归、相乘)
  2. 相对位置编码(经典、XLNET、T5、DeBERTa)
  3. 其他(CNN、复数)

各种位置编码的数学表达可以看参考链接2的文章,在本文不过多赘述。

本文在于讨论如何融合绝对位置和相对位置。

Rotary Postional Embedding(RoPE)

位置编码的选择通常需要在简单、灵活和效率之间做取舍。一个绝对位置编码很简单,但是可能泛化能力并不好,比如预训练的模型的句子长度与下游任务不一致。

Intuition

我们想要找到一个位置编码函数Rotary Positional Embedding - 图1Rotary Positional Embedding - 图2表示位置,Rotary Positional Embedding - 图3可以为Rotary Positional Embedding - 图4或者Rotary Positional Embedding - 图5(分别表示query和key),其位置分别为Rotary Positional Embedding - 图6。那么在计算attention的时候,他们之间的内积Rotary Positional Embedding - 图7应该只与Rotary Positional Embedding - 图8的值以及他们的相对位置Rotary Positional Embedding - 图9有关。

image.png

做法

RoPE将query表示为Rotary Positional Embedding - 图11
一旦位置编码被嵌入,内积的形式如下:
image.png
我们将位置编码函数写为复数的形式:
image.png
计算内积的相等形式为:
image.png
假设Rotary Positional Embedding - 图15,应用初始条件Rotary Positional Embedding - 图16,代入第一个式子得到:
image.png
那我们可以假设Rotary Positional Embedding - 图18

对于第二个式子得到Rotary Positional Embedding - 图19,那么Rotary Positional Embedding - 图20可以分解为Rotary Positional Embedding - 图21。对于Rotary Positional Embedding - 图22有:
image.png
可以看出右边的式子与Rotary Positional Embedding - 图24无关,因此设置Rotary Positional Embedding - 图25,并且有Rotary Positional Embedding - 图26
那么结合上述式子可以得到
image.png
对于Rotary Positional Embedding - 图28也是一样的形式。
在实现的过程中,需要写为矩阵形式而不是复数:
image.png
其中Rotary Positional Embedding - 图30,也即:
image.png
也就是说给位置为Rotary Positional Embedding - 图32Rotary Positional Embedding - 图33乘上矩阵Rotary Positional Embedding - 图34,位置为Rotary Positional Embedding - 图35Rotary Positional Embedding - 图36乘上矩阵Rotary Positional Embedding - 图37,再做Attention就可以包含相对位置信息了:
image.png
并且可以看出矩阵很稀疏,所以直接进行矩阵乘法消耗很大,可以直接由下面算式实现:
image.png

拓展到高维

image.png
相似地,该公式可以拓展到任意维度。

代码实现

看代码Rotary Positional Embedding - 图41的实现是按照google的余弦位置嵌入实现,即
image.png

  1. # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer
  2. class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
  3. """This module produces sinusoidal positional embeddings of any length."""
  4. def __init__(
  5. self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
  6. ):
  7. super().__init__(num_positions, embedding_dim)
  8. self.weight = self._init_weight(self.weight)
  9. @staticmethod
  10. def _init_weight(out: nn.Parameter):
  11. """
  12. Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
  13. the 2nd half of the vector. [dim // 2:]
  14. """
  15. n_pos, dim = out.shape
  16. position_enc = np.array(
  17. [
  18. [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
  19. for pos in range(n_pos)
  20. ]
  21. )
  22. out.requires_grad = False # set early to avoid an error in pytorch-1.8+
  23. sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
  24. # 前0到dim//2为sin
  25. # 后dim//2为cos
  26. out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  27. out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
  28. out.detach_()
  29. return out
  30. @torch.no_grad()
  31. def forward(self, seq_len: int, past_key_values_length: int = 0):
  32. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  33. positions = torch.arange(
  34. past_key_values_length,
  35. past_key_values_length + seq_len,
  36. dtype=torch.long,
  37. device=self.weight.device,
  38. )
  39. return super().forward(positions)
  40. # .....
  41. # google sinusoidal
  42. # [sequence_length, embed_size_per_head] -> sin & cos [batch_size, num_heads, sequence_length, embed_size_per_head // 2]
  43. sinusoidal_pos = self.embed_positions(hidden_states.shape[1], past_key_values_length)[
  44. None, None, :, :
  45. ].chunk(2, dim=-1)# [1(bs),1(num_head),seq_len,embed_sz//2]

然后才是旋转

    @staticmethod
    def apply_rotary(x, sinusoidal_pos):
        sin, cos = sinusoidal_pos # 这里直接拿出sin,cos
        # rotary matrix M=[cos(m\theta)  -sin(m\theta)\\
        #                   sin(m\theta)  cos(m\theta)]

        x1, x2 = x[..., 0::2], x[..., 1::2] # 分奇偶拿出x1,x2

        # 以query为例,x1=[q0,q2,...,qd], x2=[q1,q3,...,q(d-1)]
        # 如果是旋转query key的话,下面这个直接cat就行,因为要进行矩阵乘法,最终会在这个维度求和。(只要保持query和key的最后一个dim的每一个位置对应上就可以)
        # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

        # 如果是旋转value的话,下面这个stack后再flatten才可以,因为训练好的模型最后一个dim是两两之间交替的。
        return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)