Tensor 维度变换

维度变换是深度学习中对 Tensor 的一个非常重要的操作,通过改变 Tensor 数据的维度,便于后期的计算和处理。对于一个 shape = Pytorch 进阶操作 - 图1,shape 就相当于 Tensor 数据的视图,我们可以对它做一个不同理解上的变换,比如将 shape 变成 Pytorch 进阶操作 - 图2,这个操作叫做 reshape 操作。主要的维度变换操作还有挤压维度(squeeze) / 增加维度(unsqueeze),维度交换(transpose / t / permute),维度扩展(expand / repeat),维度自动扩展(broadcasting)。

Reshape / view

假如我们有 4 张 28 * 28 RGB 像素彩色 3 通道的图片,我们可以将它们初始化为一个形状为 Pytorch 进阶操作 - 图3 的 Tensor,这个形状的含义是 Pytorch 进阶操作 - 图4。如果我们调用 Tensor.reshape() 将它们的形状变换成 Pytorch 进阶操作 - 图5,含义就变成 Pytorch 进阶操作 - 图6,p 是指像素 pixel。这样我们就将图片像素有 28 行 28 列的信息抹掉了,变换成了图片总共由 784 个像素组成。我们要注意的是,变换前后,数据的大小是不变的,如果将数据的大小改变了,比如改成 Pytorch 进阶操作 - 图7,这样就会报错。

Tensor.reshape() 接收 shape 参数,如果你懒得计算若干个维度合并后是多少,可以直接写为 -1,即 Pytorch 进阶操作 - 图8。如果我们这样写 Pytorch 进阶操作 - 图9,就相当于将形状改为 Pytorch 进阶操作 - 图10,这样就把图片的颜色信息也抹掉了。当然,也可以调用 tf.reshape 将修改后的形状改回原形,但这样做的前提是我们要记住最初的形状信息。如果我们记错初始信息,比如我们误将图片像素的行记成了列,列记成了行,也就是把 Pytorch 进阶操作 - 图11 记成 Pytorch 进阶操作 - 图12,这样恢复的图片信息是有误的。

  1. a = torch.rand(4, 28, 28, 3)
  2. print(a.shape) # torch.Size([4, 28, 28, 3])
  3. b = a.reshape(4, 28 * 28, 3)
  4. print(b.shape) # torch.Size([4, 784, 3])
  5. c = a.reshape(4, -1, 3)
  6. print(c.shape) # torch.Size([4, 784, 3])

Squeeze / unsqueeze

挤压维度或者增加维度的意思是不改变原数据,减少或增加一个长度为 1 的维度。同样以上方的图片为例,如果我们在其初始形状的第 0 维度的位置上增加一个维度,它的形状就从 4 维的 Pytorch 进阶操作 - 图13 变成了 5 维的 Pytorch 进阶操作 - 图14,我们对这个 Tensor 的理解方式也发生了转变。新增加的这个维度表示我们将之前的 4 张图片分到了一个组别里,它们属于同一个组别。如果我们使用 Tensor.squeeze() 就可以将增加维度后的 Tensor 挤压回初始形状的 Tensor。

Unsqueeze 接口都接收维度的位置,该位置的取值范围为 Pytorch 进阶操作 - 图15,表示增加维度后,增加的维度在 Tensor 中的位置。Squeeze 接口只能接收维度长度为 1 的位置,若不接受参数,默认挤压掉 tensor 中的所有长度为 1 的维度。

  1. d = a.unsqueeze(0)
  2. print(d.shape) # torch.Size([1, 4, 28, 28, 3])
  3. e = d.squeeze(0)
  4. print(e.shape) # torch.Size([4, 28, 28, 3])
  5. f = d.squeeze()
  6. print(f.shape) # torch.Size([4, 28, 28, 3])

Expand / repeat

如果我们现在又拿到了一张 28 * 28 灰度单通道的图片,并且需要与前面四张图片的 Tensor 作相加运算,怎么办呢?首先,该灰度图片可以转化为形状为 Pytorch 进阶操作 - 图16 的 Tensor,然后使用 Tensor.unsqueeze() 将其维度增加到四维,形状变为 Pytorch 进阶操作 - 图17。因为前面 4 张图片的 Tensor 形状为 Pytorch 进阶操作 - 图18,所以它们是不能直接相加的。我们需要将新图片的 Tensor 扩展到与 4 张图片的 Tensor 相同的形状,这里就用到了 expand 或 repeat 接口。两者的不同在于 expand 接口接收扩展维度后的形状信息,它只改变我们的理解方式,并不消耗内存,repeat 是实实在在地将数据复制了多遍,接收的参数是对应维度要复制的次数。

  1. g = torch.rand(28, 28)
  2. g = g.unsqueeze(2)
  3. g = g.unsqueeze(0)
  4. print(g.shape) # torch.Size([1, 28, 28, 1])
  5. h = g.expand([4, 28, 28, 3])
  6. print(h.shape) # torch.Size([4, 28, 28, 3])
  7. i = g.repeat(4, 1, 1, 3)
  8. print(i.shape) # torch.Size([4, 28, 28, 3])

Transpose / t / permute

transpose 在数学上的意义是转置,对于一个矩阵,将它按照一个轴翻转。transpose 在 pytorch 中代表着维度的交换,比如图片 Pytorch 进阶操作 - 图19,我么要把它的长和宽对调,变成 Pytorch 进阶操作 - 图20,就需要用到 Tensor.transpose()。对于维度为 2 的 Tensor,也就是矩阵,我们可以使用 Tensor.t() 对其转置,该函数也仅限于 2 维的 Tensor。

Tensor.transpose() 只支持维度的两两交换,但这样是有缺陷的,例如上述的 4 张图片的 Tensor 的形状,其原含义为 Pytorch 进阶操作 - 图21。现在我们将位置为 1 和 3 的维度对调,Tensor 形状的含义就变成了 Pytorch 进阶操作 - 图22,w 和 h 的相对位置也变化了,如果想要 w 和 h 的相对位置不变,还需要将对调维度后的 Tensor 中位置为 2 和 3 的维度再对调一次,这样就稍显麻烦。此时我们需要 Tensor.permute() 接口将维度交换一步到位。

  1. j = a.transpose(1, 3)
  2. print(j.shape) # torch.Size([4, 3, 28, 28])
  3. j = j.transpose(2, 3)
  4. print(j.shape) # torch.Size([4, 3, 28, 28])
  5. k = a.permute(0, 3, 1, 2)
  6. print(k.shape) # torch.Size([4, 3, 28, 28])
  7. print(torch.all(torch.eq(j, k))) # tensor(True)

Broadcast

Broadcast 在英文中是广播的意思,我们可以把它翻作自动扩展。顾名思义,它可以把 Tensor 的维度自动扩展到一定的形状,并且同 expand 一样不消耗内存,它不是一个接口,而是一种机制,满足一定条件,它会自动被调用。使用 broadcast 机制的前提是小维度(也就是位置靠后的维度)要对齐,我们还是拿上面灰度图片的 Pytorch 进阶操作 - 图23和 4 张彩图的 Pytorch 进阶操作 - 图24 为例,如果想要使用 broadcast,我们首先必须使用 unsqueeze 接口手动将 Pytorch 进阶操作 - 图25 扩展为 Pytorch 进阶操作 - 图26,使小维度对齐,然后再通过 broadcast 机制将 Pytorch 进阶操作 - 图27 扩展为 Pytorch 进阶操作 - 图28

为了体现 broadcast 的简便性,我们再举一个例子。拿学校班级的例子来说,某个年级有 4 个班级,每个班级有 36 名学生,每名学生有 8 门课,那么将它们转化为一个形状为 Pytorch 进阶操作 - 图29 的 Tensor。现在进行了一次考试,由于考试的难度较高,学生的成绩普遍较低,年级主任要求给每个同学加 5 分。这里 5 分是一个标量,它的形状是 [1],如果要和上面的 Tensor 做运算,我们需要先使用 unsqueeze 接口将 Pytorch 进阶操作 - 图30 增加为 Pytorch 进阶操作 - 图31,然后使用 expand 接口将 Pytorch 进阶操作 - 图32 扩展为 Pytorch 进阶操作 - 图33,但如果通过 broadcast 机制,两个 Tensor 相加时,它会直接把标量 Pytorch 进阶操作 - 图34 扩展成 Pytorch 进阶操作 - 图35,省去增加维度的步骤。

合并与切割

Pytorch 中用于数据合并和切割的接口分别有一对,合并的接口是 cat 和 stack,切割的接口是 split 和 chunk。

Cat / stack

还是拿上面班级成绩的例子来说,假设现在有另一所学校同年级的 5 个班级,每个班级 36 名学生,每名学生 8 门课,那么该学校学生成绩单的 Tensor 形状为 Pytorch 进阶操作 - 图36。现在要将这两个学校该年级的成绩单合并,我们就需要用到 cat 接口将两个 Tensor 合并为 Pytorch 进阶操作 - 图37。此时 torch.cat() 接收两个参数,第一个参数是要合并的所有 Tensor 的列表,第二个参数是要合并的维度位置。使用 cat 接口的前提是除了要合并的维度外,其他维度的形状相同。

如果现在第一所学校的校长要求该年级与另一个年级的学生成绩单合并,另一个年级的学生和课程组成与该年级相同,这两个年级合并后的成绩单要求按照年级分组,那么 cat 接口就不适用了。我们想要将两个形状为 Pytorch 进阶操作 - 图38 的 Tensor 合并为 Pytorch 进阶操作 - 图39,可以通过 stack 接口实现,此时 torch.stack() 接收两个参数,与 cat 接口接收的参数相同,但其第二个参数代指的含义变为合并后新出现的维度所在的位置。

  1. a = torch.rand(4, 36, 8)
  2. b = torch.rand(5, 36, 8)
  3. c = torch.cat([a, b], dim=0)
  4. print(c.shape) # torch.Size([9, 36, 8])
  5. d = torch.rand(4, 36, 8)
  6. e = torch.stack([a, d], dim=0)
  7. print(e.shape) # torch.Size([2, 4, 36, 8])

Split / chunk

Split 和 chunk 接口的区别在于前者是按照长度拆分,后者是按照数量拆分。比如列表 Pytorch 进阶操作 - 图40,如果我们按照长度拆分,可以拆成四个长度为 1 的列表 Pytorch 进阶操作 - 图41Pytorch 进阶操作 - 图42Pytorch 进阶操作 - 图43Pytorch 进阶操作 - 图44,或者两个长度为 2 的列表 Pytorch 进阶操作 - 图45Pytorch 进阶操作 - 图46,或者长度分别为 3 和 1 的列表 Pytorch 进阶操作 - 图47Pytorch 进阶操作 - 图48,或 Pytorch 进阶操作 - 图49Pytorch 进阶操作 - 图50。如果我们按照数量拆分,可以拆成 2 个列表 Pytorch 进阶操作 - 图51Pytorch 进阶操作 - 图52 或 4 个列表 Pytorch 进阶操作 - 图53Pytorch 进阶操作 - 图54Pytorch 进阶操作 - 图55Pytorch 进阶操作 - 图56。两个接口的区别就在于此。

以最初的四个班级的例子来说,如果我们按照长度 1 将班级拆分开,可以通过 Tensor.split(1, dim=0) 实现。如果分别按照长度为 3 和 1 将班级拆分,通过 Tensor.split([3, 1], dim=0)。按照数量拆分只能等分原来的数据,如 Tensor.chunk(2, dim=0)

  1. f1, f2, f3, f4 = a.split(1, dim=0)
  2. print(f1.shape, f2.shape, f3.shape, f4.shape)
  3. # torch.Size([1, 36, 8]) torch.Size([1, 36, 8]) torch.Size([1, 36, 8]) torch.Size([1, 36, 8])
  4. g1, g2 = a.split([3, 1], dim=0)
  5. print(g1.shape, g2.shape)
  6. # torch.Size([3, 36, 8]) torch.Size([1, 36, 8])
  7. h1, h2 = a.chunk(2, dim=0)
  8. print(h1.shape, h2.shape)
  9. # torch.Size([2, 36, 8]) torch.Size([2, 36, 8])

数据统计

Pytorch 中常用的统计属性有求 Tensor 的范数(norm),Tensor 中数据的均值(mean)、求和(sum)、累乘(prod)、最大值(max)、最小值(min)、最大值的位置(argmax)、最小值的位置(argmin)、前 k 大/小的值(topk)、第 k 大/小的值(kthvalue),Tensor 的比较(gt,eq,equal),括号内为接口名称。

Norm-p

关于范数的计算,这里我们只考虑向量的范数,第一范数是绝对值的和 Pytorch 进阶操作 - 图57,二范数是平方和开根号 Pytorch 进阶操作 - 图58,无穷范数是最大值的绝对值 Pytorch 进阶操作 - 图59Tensor.norm() 可以接收两个参数 ord 和 dim,ord 指的是范数的级别,dim 是求解范数要抹掉的维度,如果不填 dim,默认是全维求解。

  1. a = torch.full([2, 2, 2], 1)
  2. print(a)
  3. # tensor([[[1., 1.],
  4. # [1., 1.]],
  5. #
  6. # [[1., 1.],
  7. # [1., 1.]]])
  8. print(a.norm(2)) # tensor(2.8284)
  9. # 使用第二范数原公式验证,结果相同。
  10. print(torch.sqrt(torch.sum(torch.sqrt(a))))
  11. # tensor(2.8284)
  12. print(a.norm(2, dim=1))
  13. # tensor([[1.4142, 1.4142],
  14. # [1.4142, 1.4142]])
  15. print(torch.sqrt(torch.sum(torch.sqrt(a[0][0]))), torch.sqrt(torch.sum(torch.sqrt(a[0][1]))),
  16. torch.sqrt(torch.sum(torch.sqrt(a[1][0]))), torch.sqrt(torch.sum(torch.sqrt(a[1][1]))))
  17. # tensor(1.4142) tensor(1.4142) tensor(1.4142) tensor(1.4142)
  18. # 第一范数
  19. print(a.norm(1))
  20. # tensor(8.)
  21. # 使用第一范数原公式验证,结果相同。
  22. print(torch.sum(torch.abs(a)))
  23. # tensor(8.)

Mean / sum / prod / min / argmax

这些统计接口就相对简单了,要注意的是 argmax 和 argmax 会返回最大值和最小值在 Tensor 中的位置,也就是索引,这对于我们确定预测值非常有帮助。比如我们的手写数字体问题,假如我们训练出一个可以识别手写数字体的模型,当我们给模型喂入一张新的手写数字体图片,它会返回给我们一个结果。因为模型会根据新的图片产生一个单位向量,向量中的每个元素代表是正确结果的概率,所以返回结果的依据就来源于产生的向量中最大值的索引。

  1. b = torch.rand(2, 3)
  2. print(b)
  3. # tensor([[0.0090, 0.6813, 0.5821],
  4. # [0.5943, 0.5482, 0.9192]])
  5. print(b.mean(), b.prod(), b.sum())
  6. # tensor(0.5557) tensor(0.0011) tensor(3.3341)
  7. print(b.min(), b.max(), b.argmin(), b.argmax())
  8. # tensor(0.0090) tensor(0.9192) tensor(0) tensor(5)

Topk / kthvalue

Topk 接口可以获得比 max 和 min 接口更多的信息,通过设置参数 largest = Pytorch 进阶操作 - 图60 可以获得最大的 k 个数据的值和位置或最小的 k 个数据的值和位置。Kthvalue 则可以返回第 k 大的数据的值和位置。

  1. c = torch.rand(3, 4)
  2. print(c)
  3. # tensor([[0.3411, 0.0850, 0.2546, 0.3397],
  4. # [0.1620, 0.2101, 0.8018, 0.6463],
  5. # [0.6258, 0.0545, 0.0312, 0.5660]])
  6. print(c.topk(2))
  7. # torch.return_types.topk(
  8. # values=tensor([[0.3411, 0.3397],
  9. # [0.8018, 0.6463],
  10. # [0.6258, 0.5660]]),
  11. # indices=tensor([[0, 3],
  12. # [2, 3],
  13. # [0, 3]]))
  14. print(c.topk(2, largest=False))
  15. # torch.return_types.topk(
  16. # values=tensor([[0.0850, 0.2546],
  17. # [0.1620, 0.2101],
  18. # [0.0312, 0.0545]]),
  19. # indices=tensor([[1, 2],
  20. # [0, 1],
  21. # [2, 1]]))
  22. print(c.kthvalue(3))
  23. # torch.return_types.kthvalue(
  24. # values=tensor([0.3397, 0.6463, 0.5660]),
  25. # indices=tensor([3, 3, 3]))

Gt / eq / equal

Gt 代表 great,torch.gt(Tensor, n) 相当于 Tensor > n,Tensor 中的每个元素都会跟 n 作比较,并返回 Pytorch 进阶操作 - 图61 列表。Eq 接口的作用是在两个 Tensor 中的每个对应元素之间作比较,返回 Pytorch 进阶操作 - 图62 列表,equal 接口则是比较两个 Tensor 是否相同,返回 Pytorch 进阶操作 - 图63

d = torch.arange(0, 8).reshape(2, 4)
e = torch.stack([torch.arange(0, 4), torch.arange(0, 4)], dim=0).reshape(2, 4)
print(d, '\n', e)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])
#  tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3]])
print(d > 0)
# tensor([[False,  True,  True,  True],
#         [ True,  True,  True,  True]])
print(torch.gt(d, 0))
# tensor([[False,  True,  True,  True],
#         [ True,  True,  True,  True]])
print(torch.eq(d, e))
# tensor([[ True,  True,  True,  True],
#         [False, False, False, False]])
print(torch.equal(d, e))
# False