广播语义

原文: https://pytorch.org/docs/stable/notes/broadcasting.html

许多 PyTorch 操作都支持NumPy Broadcasting Semantics

简而言之,如果 PyTorch 操作支持广播,则其 Tensor 参数可以自动扩展为相等大小(无需复制数据)。

一般语义

如果满足以下规则,则两个张量是“可广播的”:

  • 每个张量具有至少一个维度。

  • 从尾随尺寸开始迭代尺寸尺寸时,尺寸尺寸必须相等,其中之一为 1,或者不存在其中之一。

例如:

  1. >>> x=torch.empty(5,7,3)
  2. >>> y=torch.empty(5,7,3)
  3. # same shapes are always broadcastable (i.e. the above rules always hold)
  4. >>> x=torch.empty((0,))
  5. >>> y=torch.empty(2,2)
  6. # x and y are not broadcastable, because x does not have at least 1 dimension
  7. # can line up trailing dimensions
  8. >>> x=torch.empty(5,3,4,1)
  9. >>> y=torch.empty( 3,1,1)
  10. # x and y are broadcastable.
  11. # 1st trailing dimension: both have size 1
  12. # 2nd trailing dimension: y has size 1
  13. # 3rd trailing dimension: x size == y size
  14. # 4th trailing dimension: y dimension doesn't exist
  15. # but:
  16. >>> x=torch.empty(5,2,4,1)
  17. >>> y=torch.empty( 3,1,1)
  18. # x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

如果两个张量xy是“可广播的”,则所得张量大小的计算如下:

  • 如果xy的维数不相等,则在张量的维数前面加 1,以使其长度相等。

  • 然后,对于每个尺寸尺寸,所得尺寸尺寸是该尺寸上xy尺寸的最大值。

For Example:

  1. # can line up trailing dimensions to make reading easier
  2. >>> x=torch.empty(5,1,4,1)
  3. >>> y=torch.empty( 3,1,1)
  4. >>> (x+y).size()
  5. torch.Size([5, 3, 4, 1])
  6. # but not necessary:
  7. >>> x=torch.empty(1)
  8. >>> y=torch.empty(3,1,7)
  9. >>> (x+y).size()
  10. torch.Size([3, 1, 7])
  11. >>> x=torch.empty(5,2,4,1)
  12. >>> y=torch.empty(3,1,1)
  13. >>> (x+y).size()
  14. RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

就地语义

一个复杂之处在于,就地操作不允许就地张量由于广播而改变形状。

For Example:

  1. >>> x=torch.empty(5,3,4,1)
  2. >>> y=torch.empty(3,1,1)
  3. >>> (x.add_(y)).size()
  4. torch.Size([5, 3, 4, 1])
  5. # but:
  6. >>> x=torch.empty(1,3,1)
  7. >>> y=torch.empty(3,1,7)
  8. >>> (x.add_(y)).size()
  9. RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

向后兼容

只要每个张量中的元素数量相等,以前的 PyTorch 版本都可以在具有不同形状的张量上执行某些逐点函数。 然后,通过将每个张量视为一维来执行逐点操作。 PyTorch 现在支持广播,并且“一维”按点行为被认为已弃用,并且在张量不可广播但具有相同数量元素的情况下会生成 Python 警告。

注意,在两个张量不具有相同形状但可广播且具有相同元素数量的情况下,广播的引入会导致向后不兼容的更改。 例如:

  1. >>> torch.add(torch.ones(4,1), torch.randn(4))

以前会产生一个具有大小:torch.Size([4,1])的张量,但现在会产生一个具有以下大小:torch.Size([4,4])的张量。 为了帮助确定代码中可能存在广播引入的向后不兼容的情况,可以将 <cite>torch.utils.backcompat.broadcast_warning.enabled</cite> 设置为 <cite>True</cite> ,这将生成一个 python 在这种情况下发出警告。

For Example:

  1. >>> torch.utils.backcompat.broadcast_warning.enabled=True
  2. >>> torch.add(torch.ones(4,1), torch.ones(4))
  3. __main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
  4. Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.