PyTorch

    # Build a simple fully connected neural network with PyTorch
    # and use it to recognize handwritten digits (MNIST)
    import torch
    import torchvision
    # Functional interface (activation functions such as relu)
    import torch.nn.functional as F
    # Standard datasets
    import torchvision.datasets as datasets
    # Transforms that can be applied to the dataset samples
    import torchvision.transforms as transforms
    from torch import optim # parameter optimizers
    from torch import nn # all neural network modules
    # Gives easier dataset management by creating mini-batches etc.
    from torch.utils.data import DataLoader
    # For a nice progress bar!
    from tqdm import tqdm
    #--------------------------------------------------------------------
    # A two-layer fully connected network: input_size -> 50 -> num_classes
    class NN(nn.Module):
        def __init__(self, input_size, num_classes):
            super(NN, self).__init__()
            self.fc1 = nn.Linear(input_size, 50)
            self.fc2 = nn.Linear(50, num_classes)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return x
    #--------------------------------------------------------------------
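    # Optional sanity check (illustrative, not part of the notes above): pass a random
    # batch through an untrained network and confirm the output shape is
    # [batch_size, num_classes]; the underscore names are throwaway placeholders.
    _model_check = NN(784, 10)
    _x_check = torch.randn(64, 784)
    print(_model_check(_x_check).shape)  # expected: torch.Size([64, 10])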
    # Set device to cuda if a GPU is available, otherwise run on the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Hyperparameters
    input_size = 784  # 28 * 28 pixels per MNIST image
    num_classes = 10
    learning_rate = 0.001
    batch_size = 64
    num_epochs = 3
    #--------------------------------------------------------------------
    # Load training and test data
    train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=True)
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
    test_loader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=batch_size)
    #--------------------------------------------------------------------
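    # Optional peek at one mini-batch (illustrative, not part of the notes above): before
    # flattening, images come out as [batch_size, 1, 28, 28] and labels as [batch_size].
    _images, _labels = next(iter(train_loader))
    print(_images.shape, _labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])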
    # Initialize the network
    model = NN(input_size=input_size, num_classes=num_classes).to(device)
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Train the network
    for epoch in range(num_epochs):
        for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
            data = data.to(device)
            targets = targets.to(device)
            # Flatten the images: [64, 1, 28, 28] -> [64, 1*28*28]
            data = data.reshape(data.shape[0], -1)
            # Forward pass and loss
            outputs = model(data)
            loss = criterion(outputs, targets)
            # Zero the gradients left over from the previous iteration
            optimizer.zero_grad()
            # Backpropagate
            loss.backward()
            # Gradient descent step
            optimizer.step()
    #--------------------------------------------------------------------
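    # Optional refactor (a sketch with made-up names, not part of the notes above): the
    # same training step wrapped in a helper that returns the mean loss per epoch, which
    # is an easier signal to watch than raw tensor shapes.
    def train_one_epoch(model, loader, criterion, optimizer):
        model.train()
        running_loss = 0.0
        for data, targets in loader:
            data = data.to(device).reshape(data.shape[0], -1)
            targets = targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        return running_loss / len(loader)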
    # Save and load the model
    def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
        print("=> Saving checkpoint")
        torch.save(state, filename)

    def load_checkpoint(checkpoint, model, optimizer):
        print("=> Loading checkpoint")
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

    checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
    # Try saving a checkpoint
    save_checkpoint(checkpoint)
    # Try loading the checkpoint back
    load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)
    #--------------------------------------------------------------------
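    # Note (an assumption about a later session, not shown in the run above): if the
    # checkpoint was saved on a GPU machine and later loaded on a CPU-only machine, pass
    # map_location to torch.load so the tensors are remapped to the available device:
    #   checkpoint_cpu = torch.load("my_checkpoint.pth.tar", map_location=torch.device("cpu"))
    #   load_checkpoint(checkpoint_cpu, model, optimizer)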
    # Check accuracy on the training and test sets to see how good the model is
    def check_accuracy(loader, model):
        num_correct = 0
        num_samples = 0
        model.eval()
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device=device)
                y = y.to(device=device)
                x = x.reshape(x.shape[0], -1)
                outputs = model(x)
                # The predicted class is the index with the highest score
                _, indexes = outputs.max(1)
                num_correct += (indexes == y).sum()
                num_samples += indexes.size(0)  # batch size
        model.train()
        return num_correct / num_samples
    #--------------------------------------------------------------------
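    # Optional single-sample prediction (illustrative, not part of the notes above): take
    # the first test image, flatten it, and read off the class with the highest score.
    _img, _label = test_dataset[0]
    _pred = model(_img.reshape(1, -1).to(device)).argmax(dim=1).item()
    print(f"Predicted: {_pred}, ground truth: {_label}")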
    print(f"Accuracy on training set: {check_accuracy(train_loader, model)*100:.2f}")
    print(f"Accuracy on test set: {check_accuracy(test_loader, model)*100:.2f}")
    # Example output: 15.65 (train) / 15.53 (test). Those numbers come from a run where
    # optimizer.zero_grad() was called between backward() and step(), which discards the
    # gradients before the update; with zero_grad() called before backward(), as in the
    # loop above, a 3-epoch run should reach well above 90% on both sets.