# 49.pdf
"""Next-step prediction of a sine wave with a single-layer RNN (PyTorch).

The script trains ``Net`` to map ``sin(t_k)`` to ``sin(t_{k+1})`` on randomly
shifted sine segments, then feeds the network its own predictions to roll out
a full segment and plots the result against the true inputs.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Hyper-parameters (names kept as in the original notebook).
num_time_steps = 50  # points sampled per sine segment
input_size = 1
hidden_size = 16
output_size = 1
lr = 0.01
num_iters = 6000  # number of training iterations
seed = 0  # fixed so that "Restart & Run All" reproduces the same run


class Net(nn.Module):
    """Single-layer ``nn.RNN`` followed by a per-time-step linear read-out.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the recurrent hidden state.
        output_size: number of features predicted per time step.

    All three default to the module-level hyper-parameters, so ``Net()``
    builds the same architecture as before.
    """

    def __init__(self, input_size=input_size, hidden_size=hidden_size,
                 output_size=output_size):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        # Re-initialise every RNN weight and bias with small Gaussian values.
        for p in self.rnn.parameters():
            nn.init.normal_(p, mean=0.0, std=0.001)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_prev):
        """Run the RNN over ``x`` and project each hidden state.

        Args:
            x: tensor of shape ``[batch, seq, input_size]``.
            hidden_prev: initial hidden state, ``[1, batch, hidden_size]``.

        Returns:
            ``(out, hidden)`` where ``out`` is ``[batch, seq, output_size]``
            and ``hidden`` is the final hidden state returned by ``nn.RNN``.
        """
        out, hidden_prev = self.rnn(x, hidden_prev)  # [b, seq, h]
        # nn.Linear acts on the last dimension, so the batch axis is kept
        # intact instead of being flattened into a fake batch of size 1.
        out = self.linear(out)  # [b, seq, output_size]
        return out, hidden_prev


def make_sine_sequence(start, num_points=num_time_steps):
    """Sample a sine segment and build next-step input/target tensors.

    Args:
        start: left end of the segment; the segment spans ``[start, start + 10]``.
        num_points: number of evenly spaced sample points.

    Returns:
        ``(time_steps, x, y)``: the numpy array of sample times, and float
        tensors of shape ``[1, num_points - 1, 1]`` where ``y`` is ``x``
        shifted one step into the future.
    """
    time_steps = np.linspace(start, start + 10, num_points)
    data = np.sin(time_steps).reshape(num_points, 1)
    x = torch.tensor(data[:-1]).float().view(1, num_points - 1, 1)
    y = torch.tensor(data[1:]).float().view(1, num_points - 1, 1)
    return time_steps, x, y


def train(model, optimizer, criterion, hidden_prev, iterations=num_iters,
          log_every=100):
    """Train ``model`` on randomly shifted sine segments.

    The hidden state returned for one segment is used as the initial state of
    the next iteration; it is detached first so that no gradient flows back
    into earlier iterations.

    Args:
        model: network called as ``model(x, hidden) -> (output, hidden)``.
        optimizer: optimizer holding the model's parameters.
        criterion: loss called as ``criterion(output, target)``.
        hidden_prev: initial hidden state for the first iteration.
        iterations: number of optimisation steps.
        log_every: print the loss every this many iterations.

    Returns:
        The detached hidden state after the last iteration.
    """
    for step in range(iterations):
        start = np.random.randint(3)  # segment starts at 0, 1 or 2
        _, x, y = make_sine_sequence(start)

        output, hidden_prev = model(x, hidden_prev)
        hidden_prev = hidden_prev.detach()

        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % log_every == 0:
            print("Iteration: {} loss {}".format(step, loss.item()))
    return hidden_prev


def rollout(model, first_input, hidden_prev, steps):
    """Generate ``steps`` predictions by feeding each output back as input.

    Args:
        model: network called as ``model(x, hidden) -> (output, hidden)``;
            each call must yield a single-element output.
        first_input: single-element tensor used as the first input.
        hidden_prev: hidden state to start from.
        steps: number of predictions to generate.

    Returns:
        ``(predictions, hidden)``: a list of ``steps`` Python floats and the
        hidden state after the last step.
    """
    predictions = []
    step_input = first_input
    # Inference only: without no_grad an autograd graph would be built and
    # kept alive across every step of the rollout.
    with torch.no_grad():
        for _ in range(steps):
            step_input = step_input.view(1, 1, 1)
            pred, hidden_prev = model(step_input, hidden_prev)
            step_input = pred
            predictions.append(pred.item())
    return predictions, hidden_prev


if __name__ == "__main__":
    # Imported here so the model/training helpers above can be imported
    # on machines without matplotlib; only the plot needs it.
    from matplotlib import pyplot as plt

    np.random.seed(seed)
    torch.manual_seed(seed)

    model = Net()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    hidden_prev = torch.zeros(1, 1, hidden_size)
    hidden_prev = train(model, optimizer, criterion, hidden_prev)

    # Evaluate on a fresh segment: only the first true value is given to the
    # network, every later input is the network's own previous prediction.
    start = np.random.randint(3)
    time_steps, x, y = make_sine_sequence(start)
    predictions, hidden_prev = rollout(model, x[:, 0, :], hidden_prev, x.shape[1])

    inputs_np = x.numpy().ravel()
    fig, ax = plt.subplots()
    ax.scatter(time_steps[:-1], inputs_np, s=90, label="true input")
    ax.plot(time_steps[:-1], inputs_np)
    ax.scatter(time_steps[1:], predictions, label="RNN rollout")
    ax.set(xlabel="t", ylabel="sin(t)", title="RNN sine-wave rollout")
    ax.legend()
    plt.show()