以数组A和数组B的相加为例, 其余数学运算同理
核心:如果相加的两个数组的shape不同, 就会触发广播机制, 1)程序会自动执行操作使得A.shape==B.shape, 2)对应位置进行相加
有两种情况能够进行广播
- A.ndim > B.ndim, 并且A.shape最后几个元素包含B.shape, 比如下面三种情况, 注意不要混淆ndim和shape这两个基本概念
- A.shape=(2,3,4,5), B.shape=(3,4,5)
- A.shape=(2,3,4,5), B.shape=(4,5)
- A.shape=(2,3,4,5), B.shape=(5)
- A.ndim == B.ndim, 并且A.shape和B.shape对应位置的元素要么相同要么其中一个是1, 比如
- A.shape=(1,9,4), B.shape=(15,1,4)
- A.shape=(1,9,4), B.shape=(15,1,1)
import torch
a = torch.arange(1,25).reshape((2,3,4))
b = torch.arange(1,13).reshape((3,4))
a+b
Out[7]:
tensor([[[ 2, 4, 6, 8],
[10, 12, 14, 16],
[18, 20, 22, 24]],
[[14, 16, 18, 20],
[22, 24, 26, 28],
[30, 32, 34, 36]]])
# 将(3,4)扩展为 (2,3,4)再对应相加,同理对应所有四则运算