最近读到一些论文,涉及到了该函数的使用:
- -
另外,在曾经的关于深度模型设计里空间偏移操作的实现的讨论(Pytorch中Spatial-Shift-Operation的5种实现策略)中,也同样提到了这一操作的使用。
虽然该函数的作用比较直观,就是一个基于提供的grid,调整对应于输出位置上的输入位置的操作。但是由于涉及到坐标表示,所以这里不同维度的顺序就格外重要了。
针对多维数组的索引,往往会有两种形式:
- 一种是深度学习模型特征处理中常用的,从高维到低维的思路。按照从左到右的顺序的索引;
- 一种则是延续数学中坐标系的思路,分为x、y、z等轴向维度。按照从右到左的顺序索引。
那么,这个函数是按照什么顺序呢?
为此,我们需要简单进行一下测试。
首先定义下新的采样网格:
>>> grid = torch.meshgrid(torch.linspace(-1, 1, 4), torch.linspace(-1, 1, 2), indexing='ij')
>>> grid
(tensor([[-1.0000, -1.0000],
[-0.3333, -0.3333],
[ 0.3333, 0.3333],
[ 1.0000, 1.0000]]), tensor([[-1., 1.],
[-1., 1.],
[-1., 1.],
[-1., 1.]]))
再定义数据:
>>> data = torch.arange(8).reshape(1, 1, 4, 2)
>>> data
tensor([[[[0, 1],
[2, 3],
[4, 5],
[6, 7]]]])
采样实验:
>>> F.grid_sample(data.float(), torch.stack(grid, dim=-1).unsqueeze(0).float(), mode='bilinear', align_corners=True)
tensor([[[[0.0000, 6.0000],
[0.3333, 6.3333],
[0.6667, 6.6667],
[1.0000, 7.0000]]]])
>>> F.grid_sample(data.float(), torch.stack(grid[::-1], dim=-1).unsqueeze(0).float(), mode='bilinear', align_corners=True)
tensor([[[[0.0000, 1.0000],
[2.0000, 3.0000],
[4.0000, 5.0000],
[6.0000, 7.0000]]]]
可以知道,这里的grid要保证通最后一维中:
- 0表示w方向的索引,即x方向的索引。
- 1表示h方向的索引,即y方向的索引。
grid最后一维的元素,[[z,]y,x]从左到右分别对应多维数组的最后三维/二维的从左到右的索引