49.pdf

    1. import numpy as np
    2. import torch
    3. import torch.nn as nn
    4. import torch.optim as optim
    5. from matplotlib import pyplot as plt
    # Hyperparameters shared by the model definition and the training script.
    num_time_steps = 50  # number of points sampled on each sine segment
    input_size = 1       # features per time step fed to the RNN
    hidden_size = 16     # width of the RNN hidden state
    output_size = 1      # features per time step produced by the linear layer
    lr = 0.01            # learning rate handed to the Adam optimizer
    class Net(nn.Module):
        """Single-layer RNN followed by a per-time-step linear readout.

        Layer sizes are read from the module-level ``input_size``,
        ``hidden_size`` and ``output_size`` constants when the model is built.
        """

        def __init__(self):
            super().__init__()
            self.rnn = nn.RNN(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=1,
                batch_first=True,
            )
            # Start the recurrent weights and biases very close to zero.
            for p in self.rnn.parameters():
                nn.init.normal_(p, mean=0.0, std=0.001)
            self.linear = nn.Linear(hidden_size, output_size)

        def forward(self, x, hidden_prev):
            """Run the RNN over ``x`` and project every step to the output size.

            Args:
                x: input of shape ``[batch, seq, input_size]`` (the RNN is
                    built with ``batch_first=True``).
                hidden_prev: initial hidden state passed straight to the RNN.

            Returns:
                A tuple ``(out, hidden)`` where ``out`` has shape
                ``[batch, seq, output_size]`` and ``hidden`` is the RNN's
                final hidden state.
            """
            out, hidden_prev = self.rnn(x, hidden_prev)
            # nn.Linear acts on the last dimension, so [batch, seq, hidden] can
            # be projected directly. This gives the same result as the former
            # flatten + unsqueeze for a batch of one, and keeps the batch axis
            # intact for larger batches instead of folding it into the sequence.
            out = self.linear(out)
            return out, hidden_prev
    # Train the network to predict the next sample of a sine wave.
    model = Net()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr)
    # Hidden state for one layer and a batch of one; it is carried across
    # training iterations and later reused by the prediction loop.
    hidden_prev = torch.zeros(1, 1, hidden_size)

    for step in range(6000):
        # Draw a sine segment of length 10 starting at a random integer in {0, 1, 2}.
        start = np.random.randint(3, size=1)[0]
        time_steps = np.linspace(start, start + 10, num_time_steps)
        data = np.sin(time_steps).reshape(num_time_steps, 1)
        # Inputs are samples 0..n-2, targets are the same series shifted by one.
        x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
        y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)

        output, hidden_prev = model(x, hidden_prev)
        # Detach so the next iteration does not backpropagate into this graph.
        hidden_prev = hidden_prev.detach()

        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping is disabled. To enable it, call
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 10) at this point.
        optimizer.step()

        if step % 100 == 0:
            print("Iteration: {} loss {}".format(step, loss.item()))
    # Roll the trained model forward on a fresh sine segment and plot the result.
    start = np.random.randint(3, size=1)[0]
    time_steps = np.linspace(start, start + 10, num_time_steps)
    data = np.sin(time_steps).reshape(num_time_steps, 1)
    x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
    y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)

    predictions = []
    # Seed with the first true sample; afterwards the model consumes its own
    # previous prediction (free-running generation).
    step_input = x[:, 0, :]
    # Inference only: skip autograd bookkeeping for the whole rollout.
    with torch.no_grad():
        for _ in range(x.shape[1]):
            step_input = step_input.view(1, 1, 1)
            pred, hidden_prev = model(step_input, hidden_prev)
            step_input = pred
            predictions.append(pred.numpy().ravel()[0])

    x = x.numpy().ravel()
    y = y.numpy()
    # Large markers and a line for the true inputs, small markers for the
    # predictions, which are plotted one time step ahead of their inputs.
    plt.scatter(time_steps[:-1], x, s=90)
    plt.plot(time_steps[:-1], x)
    plt.scatter(time_steps[1:], predictions)
    plt.show()