nn.Module

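  • Subclass nn.Module and implement forward(): the input value passes through forward(), which returns the output. Calling the module instance (e.g. first(input)) runs forward() automatically.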
  • Example code:
```python
import torch
from torch import nn

class FirstTime(nn.Module):
    def __init__(self):
        super().__init__()  # initialize nn.Module's internal state

    def forward(self, input):
        output = input + 1  # the computation this module performs
        return output

first = FirstTime()
input = torch.tensor(1.0)
output = first(input)  # calling the instance runs forward()
print(output)  # tensor(2.)
```
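A small sketch of why the example calls first(input) rather than first.forward(input). It assumes a PyTorch version that provides Module.register_forward_hook. The vector input, the calls list, and the lambda hook are illustrative additions, not part of the original note. Both calls compute the same result, but only calling the instance goes through nn.Module.__call__, which also runs any registered hooks.

```python
import torch
from torch import nn

class FirstTime(nn.Module):
    """Same computation as above: adds 1 to every element of the input."""
    # No __init__ here: the inherited nn.Module.__init__ is used.
    def forward(self, input):
        return input + 1

first = FirstTime()
x = torch.tensor([1.0, 2.0, 3.0])  # works element-wise on any tensor shape

calls = []  # hypothetical log: records each time the forward hook fires
# The hook receives (module, positional_args, output) after forward() runs;
# returning None leaves the output unchanged.
handle = first.register_forward_hook(lambda module, args, out: calls.append(out))

y_call = first(x)            # via nn.Module.__call__: runs forward() and hooks
y_direct = first.forward(x)  # bypasses __call__: hooks are skipped

print(y_call)                         # tensor([2., 3., 4.])
print(torch.equal(y_call, y_direct))  # True
print(len(calls))                     # 1: only the instance call fired the hook
handle.remove()  # detach the hook when it is no longer needed
```

In practice, calling the instance is the idiomatic form, so hooks and other module machinery behave as expected.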