7.nn_Module

neural network

神经网络的很多工具都在torch.nn里面

containers会提供神经网络的骨架

containers中有很多模块,我们常用的是Module模块

Base class for all neural network modules

Module给所有神经网络提供骨架

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. #定义一个类Model ,其继承父类nn.Module
  4. class Model(nn.Module):
  5. #继承父类之后初始化,在创建实例时会进行初始化
  6. def __init__(self):
  7. #必须要的
  8. super(Model, self).__init__()
  9. #以下可以自己编写
  10. self.conv1 = nn.Conv2d(1, 20, 5)
  11. self.conv2 = nn.Conv2d(20, 20, 5)
  12. #前向传递 input x ,conv1 → relu → conv2 → relu → output
  13. def forward(self, x):
  14. x = F.relu(self.conv1(x))
  15. return F.relu(self.conv2(x))

例子

  1. from torch import nn
  2. import torch
  3. class DEMO(nn.Module):
  4. def __init__(self):
  5. super(DEMO,self).__init__()
  6. def forward(self, input):
  7. output = input +1
  8. return output
  9. # 创建一个实例demo
  10. demo = DEMO()
  11. x = torch.tensor(1.0)
  12. # x输入给实例demo
  13. output = demo(x)
  14. print(output)