Further simplified yesterday's code by using PyTorch's own functions and modules (nn.Module, nn.MSELoss, torch.optim):

    import torch
    import torch.nn as nn
    import time

    time_start = time.time()

    # N: batch size; D_in: input dim; H: hidden dim; D_out: output dim
    N, D_in, H, D_out = 64, 1000, 100, 10
    x = torch.randn(N, D_in)  # random data; requires_grad defaults to False
    y = torch.randn(N, D_out)

    class TwoLayerNet(torch.nn.Module):
        def __init__(self, D_in, H, D_out):
            # define the model architecture
            super(TwoLayerNet, self).__init__()
            self.linear1 = torch.nn.Linear(D_in, H, bias=False)
            self.linear2 = torch.nn.Linear(H, D_out, bias=False)

        def forward(self, x):
            # clamp(min=0) acts as a ReLU between the two linear layers
            y_pred = self.linear2(self.linear1(x).clamp(min=0))
            return y_pred

    # Equivalent model built with torch.nn.Sequential:
    # model = torch.nn.Sequential(
    #     torch.nn.Linear(D_in, H),
    #     torch.nn.ReLU(),
    #     torch.nn.Linear(H, D_out),
    # )
    model = TwoLayerNet(D_in, H, D_out)

    learning_rate = 1e-4
    # torch.nn.init.normal_(model[0].weight)  # only valid for the Sequential version
    # torch.nn.init.normal_(model[2].weight)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if torch.cuda.is_available():
        model = model.cuda()
        x = x.cuda()
        y = y.cuda()

    loss_func = nn.MSELoss(reduction='sum')

    for t in range(1000):
        # forward pass
        y_pred = model(x)
        # compute the squared-error loss
        loss = loss_func(y_pred, y)
        print(t, loss.item())
        # backward pass: compute gradients, then update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    time_end = time.time()
    print('total time cost:', time_end - time_start)
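As an aside, the scattered `if torch.cuda.is_available()` guards above can be collapsed into the device-agnostic pattern. A minimal sketch, assuming PyTorch's `torch.device` API (0.4+) and reusing `TwoLayerNet`, `x`, and `y` from the code above:

    import torch

    # Pick the GPU when available, otherwise fall back to the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = TwoLayerNet(D_in, H, D_out).to(device)  # moves all parameters at once
    x = x.to(device)
    y = y.to(device)
    # From here on, forward and backward passes run on `device`
    # with no further .cuda() calls needed.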

Also defined a simple game (FizzBuzz) and trained a model on the data without ever telling it the rules. In testing, 100 epochs of training reaches about 61% accuracy, 1000 epochs about 94%, and 10000 epochs about 98%.

    import numpy as np
    import torch

    def fizz_buzz_encode(i):
        # map a number to its FizzBuzz class:
        # 0 = the number itself, 1 = fizz, 2 = buzz, 3 = fizzbuzz
        if i % 15 == 0:
            return 3
        elif i % 5 == 0:
            return 2
        elif i % 3 == 0:
            return 1
        else:
            return 0

    def fizz_buzz_decode(i, prediction):
        return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

    def helper(i):
        print(fizz_buzz_decode(i, fizz_buzz_encode(i)))

    NUM_DIGITS = 10

    # Represent each input by an array of its binary digits,
    # e.g. binary_encode(3, 4) -> [0, 0, 1, 1]
    def binary_encode(i, num_digits):
        return np.array([i >> d & 1 for d in range(num_digits)][::-1])

    # Train on 101..1023; hold out 1..100 for testing
    trX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
    trY = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])

    # Define the model
    NUM_HIDDEN = 100
    model = torch.nn.Sequential(
        torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
        torch.nn.ReLU(),
        torch.nn.Linear(NUM_HIDDEN, 4)
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

    # Start training it
    BATCH_SIZE = 128
    if torch.cuda.is_available():
        model = model.cuda()
        trX = trX.cuda()
        trY = trY.cuda()

    for epoch in range(10000):
        for start in range(0, len(trX), BATCH_SIZE):
            end = start + BATCH_SIZE
            batchX = trX[start:end]
            batchY = trY[start:end]
            y_pred = model(batchX)
            loss = loss_fn(y_pred, batchY)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Report loss on the full training set once per epoch
        loss = loss_fn(model(trX), trY).item()
        print('Epoch:', epoch, 'Loss:', loss)

    # Output the predictions for 1..100
    testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
    if torch.cuda.is_available():
        testX = testX.cuda()
    with torch.no_grad():
        testY = model(testX)
    predictions = zip(range(1, 101), testY.max(1)[1].tolist())
    print([fizz_buzz_decode(i, x) for (i, x) in predictions])
    # Number of correct predictions out of 100
    print(np.sum(testY.cpu().max(1)[1].numpy() == np.array([fizz_buzz_encode(i) for i in range(1, 101)])))
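To reproduce the 61% / 94% / 98% accuracy-vs-epochs comparison quoted above, the evaluation step can be pulled out into a helper. A minimal sketch (the name `test_accuracy` is made up here; it reuses `binary_encode`, `fizz_buzz_encode`, and `NUM_DIGITS` from the code above):

    def test_accuracy(model):
        # Fraction of 1..100 assigned to the correct FizzBuzz class
        testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
        if torch.cuda.is_available():
            testX = testX.cuda()
        with torch.no_grad():
            pred = model(testX).cpu().max(1)[1].numpy()
        truth = np.array([fizz_buzz_encode(i) for i in range(1, 101)])
        return (pred == truth).mean()

    # Calling test_accuracy(model) after 100 / 1000 / 10000 training epochs
    # should give roughly 0.61 / 0.94 / 0.98, matching the figures above.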