参考来源:
博客园:PyTorch笔记之 squeeze() 和 unsqueeze()
1. squeeze() 函数
squeeze()
用来去掉向量的一个维度,只有维度为 **1**
的那一维才能去掉。
Example:
初始化 1
个向量 shape
为 (1,2,3)
的向量
import torch
a = torch.rand((1,2,3))
去掉第 0
维,第 0
维的大小是 1
,所以可以去掉第 0
维,去掉后向量的 shape
是(2,3)
。
去掉最后一维,最后一维的大小是 3
,所以不会操作成功,向量的 shape
仍然是 (1,2,3)
。
2. unsqueeze() 函数
从函数名字就可以看出,unsqueeze()
和 squeeze()
的功能是相反的,squeeze
是去掉 1
维,那 unsqueeze()
就是增加 1
维。
Example:
增加第 4
维,此时向量的 shape
是(1,2,3,1)
。
插入第 0
维,我们初始化一个 shape
为(2,3)
的向量,然后在插入第 0
维,插入后向量的 shape
为(1,2,3)
。