43.pdf

    1. import torch
    2. from torch import nn
    3. from torch import optim
    class MyLinear(nn.Module):
        """Hand-rolled fully connected layer computing ``x @ w.T + b``.

        Args:
            inp: number of input features (size of the last dim of ``x``).
            outp: number of output features.
        """

        def __init__(self, inp, outp):
            super().__init__()
            # Draw weight before bias so the RNG consumption order (and thus the
            # initial values under a fixed seed) matches registration order.
            weight = torch.randn(outp, inp)
            bias = torch.randn(outp)
            # nn.Parameter registers the tensor with the module (parameters(),
            # state_dict()) and sets requires_grad=True.
            self.w = nn.Parameter(weight)
            self.b = nn.Parameter(bias)

        def forward(self, x):
            """Project ``x`` (last dim == ``inp``) to ``outp`` features."""
            projected = torch.matmul(x, self.w.t())
            return projected + self.b
    class Flatten(nn.Module):
        """Collapse every dimension after the first into a single dimension.

        An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
        Implemented with ``Tensor.view``, so the result shares storage with the
        input and the call raises if the input's layout is not view-compatible.
        """

        def __init__(self):
            super().__init__()

        def forward(self, input):
            batch_size = input.shape[0]
            return input.view(batch_size, -1)
    class TestNet(nn.Module):
        """Small CNN: 3x3 conv -> 2x2 max-pool -> flatten -> linear to 10 outputs.

        The classifier width assumes single-channel 28x28 inputs (the original
        ``14*14`` is 28 halved by the pool); TODO confirm against the real data.
        """

        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                # kernel_size is a required argument of nn.Conv2d; omitting it
                # raised TypeError at construction. A 3x3 kernel with stride=1
                # and padding=1 keeps the spatial size unchanged.
                nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                nn.MaxPool2d(2, 2),  # halves H and W: 28x28 -> 14x14
                Flatten(),           # (N, 16, 14, 14) -> (N, 16 * 14 * 14)
                # in_features must count all 16 conv output channels, not 1.
                nn.Linear(16 * 14 * 14, 10),
            )

        def forward(self, x):
            """Run the conv/pool/flatten/linear stack on ``x``."""
            return self.net(x)
    class BasicNet(nn.Module):
        """Single affine layer mapping 4 input features to 3 output features."""

        def __init__(self):
            super().__init__()
            self.net = nn.Linear(in_features=4, out_features=3)

        def forward(self, x):
            out = self.net(x)
            return out
    class Net(nn.Module):
        """BasicNet (4 -> 3), then ReLU, then a linear 3 -> 2 head.

        The layers sit in ``self.net`` at indices 0, 1, 2. That fixes the
        state_dict keys, e.g. ``net.0.net.weight`` and ``net.2.bias``.
        """

        def __init__(self):
            super().__init__()
            # Listed (and therefore constructed) in the same order as before, so
            # random initialization under a fixed seed is unchanged.
            layers = [
                BasicNet(),
                nn.ReLU(),
                nn.Linear(3, 2),
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            out = self.net(x)
            return out
    def main():
        """Build a Net, move it to a device, and print its parameters and submodules.

        Uses CUDA when it is available and falls back to CPU otherwise.
        """
        # An unconditional torch.device('cuda') made net.to(device) raise on
        # machines without a CUDA build or GPU; fall back to CPU instead.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        net = Net()
        net.to(device)

        # Demonstrates switching modes: train() is immediately overridden by
        # eval(), so the module ends up in eval mode.
        net.train()
        net.eval()

        # Checkpoint round-trip (disabled). map_location lets a checkpoint saved
        # on GPU load on a CPU-only machine:
        #   net.load_state_dict(torch.load('ckpt.mdl', map_location=device))
        #   torch.save(net.state_dict(), 'ckpt.mdl')

        for name, t in net.named_parameters():
            print('parameters:', name, t.shape)
        for name, m in net.named_children():
            print('children:', name, m)
        for name, m in net.named_modules():
            print('modules:', name, m)


    if __name__ == '__main__':
        main()