张量操作

一、张量拼接与切分

1. torch.cat()

功能:将张量按维度dim进行降维
tensors:张量序列
dim:要拼接的维度
image.png

2. torch.stack()

功能:在新创建的维度dim进行降维(创建一个新的维度)
tensors:张量序列
dim:要拼接的维度
image.png

3. torch.chunk()

功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其它张量,前面的张量长度等于整除后向上取整。
input:要切分的张量
chunks:要切分的份数
dim:要切分的维度
image.png

4. torch.split()

功能:将张量按维度dim进行切分,并且可以指定切分的长度
返回值:张量列表
tensor:要切分的张量
split_size_or_sections:为int时,表示每一份长度;为list时,按list元素切分;
dim:要切分的维度
image.png

二、张量的索引

1. torch.index_select()

功能:在指定的维度dim上,根据指定的index索引数据
返回值:根据index索引的数据按dim拼接的张量
input:要索引的张量
dim:要索引的维度
index:要索引数据的序号,必须是个张量,dtype=torch.long
image.png

2. torch.masked_select()

功能:按mask中的True进行索引,通常用来**筛选**数据
返回值:一维张量(因为不能确定mask中True的个数,所以无法确定shape)
input:要索引的张量
mask:与input同形状的布尔类型张量
image.png
image.png

三、张量的变换

1. torch.reshape()

功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存
input:要变换的张量
shape:新张量的形状
image.png当shape中维度为-1时,表示该维度由其它维度计算自动得到

2. torch.transpose()

功能:交换张量的两个维度,通常在图片预处理时进行变换,将channel交换到前面。
input:要交换的张量
dim0:要交换的维度
dim1:要交换的维度
image.png

3. torch.t()

功能:2维张量转置,对矩阵而言,等价于torch.transpose(input, 0, 1)

4.torch.squeeze()

功能:压缩长度为1的维度(轴)
dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该指定的轴长度为1时,才可以被移除;

5.torch.unsqueeze()

功能:依据dim扩展维度
dim:扩展的维度

张量的运算

一、加减乘除

1. torch.add()

功能:逐元素计算 input+alpha*other,除了加法还可以实现先乘后加,代码更简洁。
input:第一个张量
alpha:乘法因子
other:第二个张量

2. torch.addcdiv()

功能:逐元素计算 image.png

3. torch.addcmul()

功能:逐元素计算image.png

4. torch.sub()

5. torch.div()

6. torch.mul()

二、对数、指数、幂函数

1. torch.log(input,out=None)

2. torch.log10(input,out=None)

3. torch.log2(input,out=None)

4. torch.exp(input,out=None)

5. torch.pow()

三、三角函数

1. torch.abs(input, out=None)

2. torch.acos(input, out=None)

3. torch.cosh(input, out=None)

4. torch.cos(input, out=None)

5. torch.asin(input, out=None)

6. torch.atan(input, out=None)

7. torch.atan2(input, other out=None)

线性回归

概念

定义:线性回归是分析一个变量与另外一(多)个变量之间关系的方法
y = wx + b
因变量:y 自变量:x 关系:线性
分析:求解w,b

求解步骤

  1. 确定模型; 如Model:y = wx + b
  2. 选择损失函数,如MSE: image.png
  3. 求解梯度并更新w,b: w = w - LR w.grad b = b - LR w.grad (LR为学习率learning rate)

image.pngimage.pngimage.pngimage.pngimage.pngimage.png