12.Linear
https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear
torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
以vgg16网络举例,图中这一步骤就用到了线性层,in_fearure = 4096 , out_feature = 10000
如下图是一个线性神经网络
in_feature就是输入的x的个数
out_feature就是输出的个数
线性层会根据神经网络计算公式:Ax+bias得到output,在pytorch框架下使用时只需要设置好in_feature,out_feature 和 bias
torch.reshape()函数的使用
https://pytorch.org/docs/1.2.0/torch.html#torch.reshape
#返回一个张量,内容和原来的张量相同,但是具有不同形状.#并且尽可能返回视图,否则才会返回拷贝,#因此,需要注意内存共享问题.#传入的参数可以有一个-1,#表示其具体值由其他维度信息和元素总个数推断出来.
torch.flatten()
https://pytorch.org/docs/stable/generated/torch.flatten.html?highlight=flatten#torch.flatten
用于将数据展平
torch.flatten(input, start_dim=0, end_dim=-1)
import torchimport torchvisionfrom torch import nnfrom torch.nn import Linearfrom torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10('./dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)#这里设置drop_last = True的原因是最后一个Batch_size不足64,展平后不是196608不符合设置的线性层输入 因此设置一下dataloader = DataLoader(dataset,batch_size=64,drop_last=True)class DEMO(nn.Module):def __init__(self):super(DEMO, self).__init__()self.linear1 = Linear(in_features=196608, out_features=10, bias=False)def forward(self,x):x = self.linear1(x)return xdemo = DEMO()#torch.reshape()函数官方文档# 返回一个张量,内容和原来的张量相同,但是具有不同形状.# 并且尽可能返回视图,否则才会返回拷贝,# 因此,需要注意内存共享问题.# 传入的参数可以有一个-1,# 表示其具体值由其他维度信息和元素总个数推断出来.for data in dataloader:imgs,targets = dataprint(imgs.shape)# output = torch.reshape(imgs,(1,1,1,-1))output = torch.flatten(imgs)print(output.shape)output_linear = demo(output)print(output_linear.shape)
在实际使用时,torchvision中有很多已经搭建好的模型可以直接使用
