22 Himmelblau.pdf

    1. import numpy as np
    2. from mpl_toolkits.mplot3d import Axes3D
    3. from matplotlib import pyplot as plt
    4. import torch
    5. def himmelblau(x):
    6. return (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2
    7. x = np.arange(-6, 6, 0.1)
    8. y = np.arange(-6, 6, 0.1)
    9. print('x,y range:', x.shape, y.shape)
    10. X, Y = np.meshgrid(x, y)
    11. print('X,Y maps:', X.shape, Y.shape)
    12. Z = himmelblau([X, Y])
    13. fig = plt.figure('himmelblau')
    14. ax = fig.gca(projection='3d')
    15. ax.plot_surface(X, Y, Z)
    16. ax.view_init(60, -30)
    17. ax.set_xlabel('x')
    18. ax.set_ylabel('y')
    19. plt.show()
    20. # [1., 0.], [-4, 0.], [4, 0.]
    21. x = torch.tensor([-4., 0.], requires_grad=True)
    22. optimizer = torch.optim.Adam([x], lr=1e-3)
    23. for step in range(20000):
    24. pred = himmelblau(x)
    25. optimizer.zero_grad()
    26. pred.backward()
    27. optimizer.step()
    28. if step % 2000 == 0:
    29. print ('step {}: x = {}, f(x) = {}'
    30. .format(step, x.tolist(), pred.item()))