Transposing high-dimensional tensors:
>>> tensor = torch.randn(1, 2, 3)
>>> tensor # shape of tensor is [1, 2, 3]
tensor([[[ 0.1264, -0.7503,  0.5522],
         [ 0.0680,  1.0128,  0.1585]]])
>>> tensor = torch.transpose(tensor, dim0=1, dim1=2)
>>> tensor # shape of tensor is [1, 3, 2]
tensor([[[ 0.1264,  0.0680],
         [-0.7503,  1.0128],
         [ 0.5522,  0.1585]]])
>>> tensor = torch.transpose(tensor, dim0=0, dim1=2)
>>> tensor # shape of tensor is [2, 3, 1]
tensor([[[ 0.1264],
         [-0.7503],
         [ 0.5522]],

        [[ 0.0680],
         [ 1.0128],
         [ 0.1585]]])
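First, we need to learn how to read the nested structure that a printed tensor shows.
1. Count the opening brackets at the start of the output: the number of brackets is the number of dimensions (the order) of the tensor.
tensor([[[ 0.1264, -0.7503,  0.5522],
         [ 0.0680,  1.0128,  0.1585]]])
There are three opening brackets here, so this is a three-dimensional tensor.
Read the shape from the outside in. The outermost bracket contains a single element (one 2x3 block), so the first dimension is 1.
The second level contains 2 elements (two rows), so the second dimension is 2.
The third level is the innermost one, holding the actual values; each row has 3 numbers, so the third dimension is 3.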
>>> tensor = torch.randn(1, 2, 3)
>>> tensor # shape of tensor is [1, 2, 3]
tensor([[[ 0.1264, -0.7503,  0.5522],
         [ 0.0680,  1.0128,  0.1585]]])
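The size of each dimension is simply the number of elements directly inside the bracket at that level.
Indexing follows the same outside-in order.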
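The bracket-counting rule can be checked directly. The following is a minimal sketch, assuming a tensor built with torch.arange instead of the random values above so that the result is reproducible; the variable names are illustrative only.

```python
import torch

# Deterministic stand-in for torch.randn(1, 2, 3), so the values are reproducible.
t = torch.arange(6).reshape(1, 2, 3)

# The number of nested bracket levels equals the number of dimensions.
ndim = t.dim()

# Size of each level, read from the outside in, using len() on successive levels.
sizes = (len(t), len(t[0]), len(t[0][0]))

# Indexing follows the same outside-in order: block 0, row 1, element 2.
value = t[0, 1, 2].item()

print(ndim)            # 3
print(sizes)           # (1, 2, 3)
print(tuple(t.shape))  # (1, 2, 3)
print(value)           # 5
```

Reading len() level by level gives the same tuple that t.shape reports, which is the bracket-counting rule in code form.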
Now let's look at how to transpose a high-dimensional tensor.
>>> tensor = torch.transpose(tensor, dim0=1, dim1=2)
>>> tensor # shape of tensor is [1, 3, 2]
tensor([[[ 0.1264,  0.0680],
         [-0.7503,  1.0128],
         [ 0.5522,  0.1585]]])
>>> tensor = torch.transpose(tensor, dim0=0, dim1=2)
>>> tensor # shape of tensor is [2, 3, 1]
tensor([[[ 0.1264],
         [-0.7503],
         [ 0.5522]],

        [[ 0.0680],
         [ 1.0128],
         [ 0.1585]]])
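Before this second transpose the shape is [1, 3, 2]; swapping dimensions 0 and 2 turns it into [2, 3, 1].
Notice that the middle dimension, of size 3, does not change.
Let's look at the data from a different angle.
Because the last dimension is now 1, every innermost bracket holds a single value, and each of the two outer blocks groups three of these one-element vectors: the two columns of the 3x2 matrix have become the two blocks.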
>>> tensor # shape of tensor is [2, 3, 1]
tensor([[[ 0.1264],
         [-0.7503],
         [ 0.5522]],

        [[ 0.0680],
         [ 1.0128],
         [ 0.1585]]])
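One way to make this concrete is to state the transpose as an index rule: after swapping dimensions 0 and 2, the element at position [k, j, i] of the result is the element at [i, j, k] of the input. The sketch below checks that rule on a deterministic stand-in for the [1, 3, 2] tensor above (built with torch.arange, which is an assumption made for reproducibility; the names x and y are illustrative).

```python
import torch

# Deterministic stand-in for the [1, 3, 2] tensor above.
x = torch.arange(6).reshape(1, 3, 2)

# Swap dimension 0 and dimension 2; dimension 1 (size 3) keeps its position.
y = torch.transpose(x, 0, 2)
print(tuple(y.shape))  # (2, 3, 1)

# Check the index rule y[k, j, i] == x[i, j, k] for every element.
rule_holds = all(
    bool(y[k, j, i] == x[i, j, k])
    for i in range(1)
    for j in range(3)
    for k in range(2)
)
print(rule_holds)  # True

# Each block y[k] collects column k of the original 3x2 matrix
# as three one-element rows.
print(y[0].flatten().tolist())  # [0, 2, 4]
print(y[1].flatten().tolist())  # [1, 3, 5]
```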
import torch
import numpy as np

# Build a (1, 2, 3, 4) tensor holding 0..23 so every element is easy to track.
tensor_ = torch.from_numpy(np.arange(24).reshape(1, 2, 3, 4))
print(tensor_)
# tensor([[[[ 0,  1,  2,  3],
#           [ 4,  5,  6,  7],
#           [ 8,  9, 10, 11]],
#
#          [[12, 13, 14, 15],
#           [16, 17, 18, 19],
#           [20, 21, 22, 23]]]])

# Swap dimensions 1 and 3: shape (1, 2, 3, 4) -> (1, 4, 3, 2).
tensor_t = tensor_.transpose(1, 3)
print(tensor_t)
# tensor([[[[ 0, 12],
#           [ 4, 16],
#           [ 8, 20]],
#
#          [[ 1, 13],
#           [ 5, 17],
#           [ 9, 21]],
#
#          [[ 2, 14],
#           [ 6, 18],
#           [10, 22]],
#
#          [[ 3, 15],
#           [ 7, 19],
#           [11, 23]]]])

# Swapping the same two dimensions again restores the original layout.
tensor_tt = tensor_t.transpose(1, 3)
print(tensor_tt)
# tensor([[[[ 0,  1,  2,  3],
#           [ 4,  5,  6,  7],
#           [ 8,  9, 10, 11]],
#
#          [[12, 13, 14, 15],
#           [16, 17, 18, 19],
#           [20, 21, 22, 23]]]])
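The original shape is (1, 2, 3, 4). transpose(1, 3) swaps the dimensions of size 2 and size 4, so the dimensions of size 1 and size 3 stay where they are and the result has shape (1, 4, 3, 2).
The dimension to watch is the one of size 3.
Values that sit together along that dimension must still sit together after the transpose; for example, 0, 4 and 8 remain the first entries of three consecutive rows.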
# tensor([[[[ 0, 12],
#           [ 4, 16],
#           [ 8, 20]],
#
#          [[ 1, 13],
#           [ 5, 17],
#           [ 9, 21]],
#
#          [[ 2, 14],
#           [ 6, 18],
#           [10, 22]],
#
#          [[ 3, 15],
#           [ 7, 19],
#           [11, 23]]]])
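The observation about the size-3 dimension can be verified in code. The sketch below rebuilds the same 0..23 tensor with torch.arange (an assumption made to keep the example self-contained) and checks that, for every fixed position along dimension 2, the swap of dimensions 1 and 3 is just an ordinary 2-D matrix transpose.

```python
import torch

x = torch.arange(24).reshape(1, 2, 3, 4)
y = x.transpose(1, 3)
print(tuple(y.shape))  # (1, 4, 3, 2)

# Index rule: y[b, i, j, k] == x[b, k, j, i]; b and j keep their positions.
# Walk along dimension 2 at the same (dim 1, dim 3) corner before and after.
before = x[0, 0, :, 0].tolist()
after = y[0, 0, :, 0].tolist()
print(before, after)  # [0, 4, 8] [0, 4, 8]

# For each fixed j, the (dim 1, dim 3) slice is a plain 2-D transpose.
slices_match = all(
    torch.equal(y[0, :, j, :], x[0, :, j, :].T) for j in range(3)
)
print(slices_match)  # True
```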
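If the outermost dimension were 2 or larger (and is not one of the dimensions being swapped), each outer group would be transposed independently: values never move from one group to another.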
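A quick way to see this is to compare transposing a whole batch with transposing each group separately. The batch below, with an outermost size of 2, is a hypothetical example and not part of the original data.

```python
import torch

# Hypothetical batch of two groups, each shaped like the (2, 3, 4) block above.
batch = torch.arange(48).reshape(2, 2, 3, 4)

# Transpose dimensions 1 and 3 for the whole batch at once.
whole = batch.transpose(1, 3)

# Transpose each group on its own (dimensions 0 and 2 of the 3-D group),
# then stack the results back along a new outermost dimension.
per_group = torch.stack([group.transpose(0, 2) for group in batch])

print(torch.equal(whole, per_group))  # True

# Group 0 still holds only the values 0..23, group 1 only 24..47.
print(whole[0].min().item(), whole[0].max().item())  # 0 23
print(whole[1].min().item(), whole[1].max().item())  # 24 47
```

Because the outermost dimension is not one of the swapped dimensions, the two results are identical and no value crosses from one group to the other.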