参考来源:
CSDN:【Pytorch】对比 expand 和 repeat 函数

**expand()****repeat()** 函数是 pytorch 中常用于进行张量数据复制和维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1. expand()

  1. tensor.expand(*sizes)

**expand()** 函数用于将张量中单数维的数据扩展到指定的 size
首先解释下什么叫单数维singleton dimensions),张量在某个维度上的 size1 ,则称为单数维。比如 **zeros(2,3,4)** 不存在单数维,而 **zeros(2,1,4)** 在第二个维度(即维度 1 )上为单数维。**expand()** 函数仅仅能作用于这些单数维的维度上。
参数 ***sizes** 用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作 **-1**
**expand()** 函数可能导致原始张量的升维,其作用在张量前面的维度上,因此通过 **expand()** 函数可将张量数据复制多份(可理解为沿着第一个 **batch** 的维度上)。
另一个值得注意的点是:**expand()** 函数并不会重新分配内存,返回结果仅仅是原始张量上的一个视图。

  1. import torch
  2. a = torch.tensor([1, 0, 2])
  3. b = a.expand(2, -1) # 第一个维度为升维,第二个维度保持原样
  4. print('b:\n', b)
  5. """
  6. b:
  7. tensor([[1, 0, 2],
  8. [1, 0, 2]])
  9. """
  1. import torch
  2. a = torch.tensor([[1], [0], [2]])
  3. print('a.shape:', a.shape)
  4. b = a.expand(-1, 2) # 保持第一个维度,第二个维度只有一个元素,可扩展
  5. print('b:\n', b)
  6. """
  7. a.shape: torch.Size([3, 1])
  8. b:
  9. tensor([[1, 1],
  10. [0, 0],
  11. [2, 2]])
  12. """
  1. import torch
  2. a = torch.tensor([[1], [0], [2]])
  3. print('a.shape:', a.shape)
  4. b = a.expand(2, -1) # 第一个维度有三个元素,不可扩展
  5. print('b:\n', b)
  6. """
  7. Traceback (most recent call last):
  8. File "test.py", line 13, in <module>
  9. b = a.expand(2, -1) # 保持第一个维度,第二个维度只有一个元素,可扩展
  10. RuntimeError: The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 0. Target sizes: [2, -1]. Tensor sizes: [3, 1]
  11. Process finished with exit code 1
  12. """

2. expand_as()

expand_as() 函数可视为 expand() 的另一种表达,其 size 通过函数传递的目标张量的 size 来定义。

  1. import torch
  2. a = torch.tensor([1, 0, 2])
  3. b = torch.zeros(2, 3)
  4. c = a.expand_as(b) # a照着b的维度大小进行拓展
  5. print('a.shape:', a.shape)
  6. print('b.shape:', b.shape)
  7. print('c:\n', c)
  8. print('c.shape:', c.shape)
  9. """
  10. a.shape: torch.Size([3])
  11. b.shape: torch.Size([2, 3])
  12. c:
  13. tensor([[1, 0, 2],
  14. [1, 0, 2]])
  15. c.shape: torch.Size([2, 3])
  16. """

3. repeat()

前文提及 expand() 仅能作用于**单数维**,那对于非单数维的拓展,那就需要借助于 repeat() 函数了。

  1. tensor.repeat(*sizes)

参数 ***sizes** 指定了原始张量在各维度上复制的次数。整个原始张量作为一个整体进行复制,这与 Numpy 中的 **repeat()** 函数截然不同,而更接近于 tile() 函数的效果。
**expand()**不同,**repeat()**函数会真正的复制数据并存放于内存中。
下面是一个简单的例子:

  1. import torch
  2. a = torch.tensor([1, 0, 2])
  3. b = a.repeat(3, 2) # 在轴0上复制3份,在轴1上复制2份
  4. print('b:\n', b)
  5. """
  6. b:
  7. tensor([[1, 0, 2, 1, 0, 2],
  8. [1, 0, 2, 1, 0, 2],
  9. [1, 0, 2, 1, 0, 2]])
  10. """

4. repeat_intertile()

Pytorch 中,与 **Numpy****repeat()** 函数相类似的函数为 **torch.repeat_interleave()**

  1. torch.repeat_interleave(input, repeats, dim=None)

参数 **input** 为原始张量,**repeats** 为指定轴上的复制次数,而 **dim** 为复制的操作轴,若取值为 **None** 则默认将所有元素进行复制,并会返回一个 flatten(压平) 之后的一维张量。
**repeat()** 将整个原始张量作为整体不同,**repeat_interleave()** 操作是逐元素的。
下面是一个简单的例子:

  1. import torch
  2. a = torch.tensor([[1], [0], [2]])
  3. b = torch.repeat_interleave(a, repeats=3) # 结果flatten
  4. print('b:\n', b)
  5. """
  6. b:
  7. tensor([1, 1, 1, 0, 0, 0, 2, 2, 2])
  8. """
  1. import torch
  2. a = torch.tensor([[1], [0], [2]])
  3. c = torch.repeat_interleave(a, repeats=3, dim=1) # 沿着axis=1逐元素复制
  4. print('c:\n', c)
  5. """
  6. c:
  7. tensor([[1, 1, 1],
  8. [0, 0, 0],
  9. [2, 2, 2]])
  10. """