• References

    Self-Attention with Relative Position Representations (Shaw et al., 2018): https://arxiv.org/pdf/1803.02155.pdf

    • Formulas

    Standard self-attention output:

    $$e_{ij} = \frac{(x_i W^Q)(x_j W^K)^T}{\sqrt{d_z}}, \qquad \alpha_{ij} = \frac{\exp e_{ij}}{\sum_{k=1}^{n} \exp e_{ik}}, \qquad z_i = \sum_{j=1}^{n} \alpha_{ij}\,(x_j W^V)$$

    Introduce two vectors that depend on the relative position, $a_{ij}^K, a_{ij}^V \in \mathbb{R}^{d_a}$, added to the keys and values respectively:

    $$e_{ij} = \frac{x_i W^Q\,(x_j W^K + a_{ij}^K)^T}{\sqrt{d_z}}, \qquad z_i = \sum_{j=1}^{n} \alpha_{ij}\,(x_j W^V + a_{ij}^V)$$

    Assume that once the distance between two elements of the sequence exceeds $k$, the positional information between them is no longer meaningful. Relative positions are therefore clipped to $[-k, k]$, and the representations are defined as trainable vectors $w^K = (w_{-k}^K, \ldots, w_{k}^K)$ and $w^V = (w_{-k}^V, \ldots, w_{k}^V)$:

    $$a_{ij}^K = w^K_{\mathrm{clip}(j-i,\,k)}, \qquad a_{ij}^V = w^V_{\mathrm{clip}(j-i,\,k)}, \qquad \mathrm{clip}(x, k) = \max(-k, \min(k, x))$$
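    The clipping and lookup can be made concrete with a short sketch (this is not the paper's released code; the function name and arguments are illustrative): build the distance matrix $j - i$, clip it to $[-k, k]$, and gather each $a_{ij}$ from a trainable table of $2k + 1$ vectors.

      import tensorflow as tf

      def relative_position_embeddings(length, depth, max_relative_position):
        """Returns a [length, length, depth] tensor of clipped relative embeddings."""
        k = max_relative_position
        # One trainable vector per clipped distance in [-k, k]: 2k + 1 rows.
        table = tf.Variable(tf.random.normal([2 * k + 1, depth]))
        pos = tf.range(length)
        # distance[i, j] = j - i; clip to [-k, k] and shift to [0, 2k] so it
        # indexes into the table.
        distance = pos[tf.newaxis, :] - pos[:, tf.newaxis]
        clipped = tf.clip_by_value(distance, -k, k)
        return tf.gather(table, clipped + k)  # a_ij = w_clip(j - i, k)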
    The $e_{ij}$ formula admits an efficient implementation by splitting the numerator into two terms (and likewise for $z_i$):

    $$e_{ij} = \frac{x_i W^Q\,(x_j W^K)^T + x_i W^Q\,(a_{ij}^K)^T}{\sqrt{d_z}}$$

    $$z_i = \sum_{j=1}^{n} \alpha_{ij}\,(x_j W^V) + \sum_{j=1}^{n} \alpha_{ij}\,a_{ij}^V$$

    The first term of each is the ordinary attention matmul; the second shares the $[n, n, d_a]$ relative-embedding tensor across batch and heads, so after folding the batch and head axes together it too can be computed as a single matmul, avoiding a broadcast to a 5-D tensor. The code below handles both cases with one function.

    • Code

      import tensorflow as tf

      def _relative_attention_inner(x, y, z, transpose):
        """Relative position-aware dot-product attention inner calculation.

        This batches matrix multiply calculations to avoid unnecessary broadcasting.

        Args:
          x: Tensor with shape [batch_size, heads, length or 1, length or depth].
          y: Tensor with shape [batch_size, heads, length or 1, depth].
          z: Tensor with shape [length or 1, length, depth].
          transpose: Whether to transpose inner matrices of y and z. Should be true
            if the last dimension of x is depth, not length.

        Returns:
          A Tensor with shape [batch_size, heads, length, length or depth].
        """
        batch_size = tf.shape(x)[0]
        heads = x.get_shape().as_list()[1]
        length = tf.shape(x)[2]

        # xy_matmul is [batch_size, heads, length or 1, length or depth]
        xy_matmul = tf.matmul(x, y, transpose_b=transpose)
        # x_t is [length or 1, batch_size, heads, length or depth]
        x_t = tf.transpose(x, [2, 0, 1, 3])
        # x_t_r is [length or 1, batch_size * heads, length or depth]
        x_t_r = tf.reshape(x_t, [length, batch_size * heads, -1])
        # x_tz_matmul is [length or 1, batch_size * heads, length or depth]
        x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
        # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]
        x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
        # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]
        x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
        return xy_matmul + x_tz_matmul_r_t
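
    A hypothetical usage sketch follows (the shapes and the relative_position_embeddings helper above are assumptions, not part of the original code). With transpose=True the function computes the logit numerator $x_i W^Q (x_j W^K)^T + x_i W^Q (a_{ij}^K)^T$; with transpose=False it computes $\sum_j \alpha_{ij}(x_j W^V + a_{ij}^V)$.

      # Hypothetical shapes; q, k, v stand for the already-projected
      # x W^Q, x W^K, x W^V tensors.
      batch, heads, length, depth = 2, 4, 5, 8
      q = tf.random.normal([batch, heads, length, depth])
      k = tf.random.normal([batch, heads, length, depth])
      v = tf.random.normal([batch, heads, length, depth])
      a_k = relative_position_embeddings(length, depth, max_relative_position=3)
      a_v = relative_position_embeddings(length, depth, max_relative_position=3)

      # Logits: the last dimension of q is depth, so transpose=True.
      logits = _relative_attention_inner(q, k, a_k, transpose=True)
      weights = tf.nn.softmax(logits / tf.sqrt(float(depth)))  # [batch, heads, length, length]

      # Output: the last dimension of weights is length, so transpose=False.
      output = _relative_attention_inner(weights, v, a_v, transpose=False)
      print(output.shape)  # (2, 4, 5, 8)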