Abstract
:::info Shows how to use PyTorch to find the local minima of a function of two variables. :::
Introduction
:::info
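A function of two variables has only one global minimum value, but it can have many local minima. The example function used in this article is Himmelblau's function, a two-variable polynomial of degree four:

$$f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2$$

This function has four local minima, and all four happen to be global minima as well, each with f = 0.
They are: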
- f(3.0, 2.0) = 0.0
- f(-2.805118, 3.131312) ≈ 0.0
- f(-3.779310, -3.283186) ≈ 0.0
- f(3.584428, -1.848126) ≈ 0.0
This article demonstrates how to use PyTorch's automatic differentiation (autograd) to search for these minima. :::
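As a quick numerical check of the list above, the sketch below evaluates the function at the four minima. The coordinates are the commonly cited values rounded to six decimals, so apart from (3, 2) they evaluate to a tiny positive number rather than exactly 0.

```python
def himmelblau(x, y):
    # Same formula as the himmelblau function defined in this article
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2

# The four minima of Himmelblau's function, rounded to six decimals
minima = [
    (3.0, 2.0),
    (-2.805118, 3.131312),
    (-3.779310, -3.283186),
    (3.584428, -1.848126),
]

values = [himmelblau(x, y) for x, y in minima]
for (x, y), v in zip(minima, values):
    print(f"f({x}, {y}) = {v:.2e}")
```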
Importing the libraries
```python
import numpy as np
import matplotlib.pyplot as plt
import torch
```
Defining the himmelblau function
```python
def himmelblau(x, y):
    # f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2
```
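The training loop in this article relies on `pred.backward()` to obtain the gradient of f. As a sanity check, the partial derivatives can also be derived by hand with the chain rule, $\partial f/\partial x = 4x(x^2 + y - 11) + 2(x + y^2 - 7)$ and $\partial f/\partial y = 2(x^2 + y - 11) + 4y(x + y^2 - 7)$, and compared with what autograd returns. `himmelblau_grad` is a hypothetical helper introduced only for this comparison.

```python
import torch

def himmelblau(x, y):
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2

def himmelblau_grad(x, y):
    # Hypothetical helper: hand-derived partial derivatives (chain rule)
    a = x ** 2 + y - 11  # inner term of the first square
    b = x + y ** 2 - 7   # inner term of the second square
    return 4 * x * a + 2 * b, 2 * a + 4 * y * b

# Let autograd compute the gradient at the first starting point (4, 0)
p = torch.tensor([4.0, 0.0], requires_grad=True)
himmelblau(p[0], p[1]).backward()

autograd_grad = p.grad.tolist()
manual_grad = himmelblau_grad(4.0, 0.0)
print(autograd_grad, manual_grad)  # both give 74 and 10
```

Both components are positive at (4, 0). Adam's first update moves each coordinate by roughly `lr` against the sign of its gradient, regardless of the gradient's magnitude, which is consistent with the logged x=3.999, y=-0.001 after step 0.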
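Plotting the function in 3D with matplotlib
```python
x = np.arange(-6, 6, 0.1)
y = np.arange(-6, 6, 0.1)
print('x, y range:', x.shape, y.shape)
X, Y = np.meshgrid(x, y)
print('X, Y maps:', X.shape, Y.shape)
Z = himmelblau(X, Y)  # function values on the 120 x 120 grid

fig = plt.figure('himmelblau')
# fig.gca(projection='3d') no longer works in recent Matplotlib releases;
# add_subplot with projection='3d' creates the 3D axes instead
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z)
ax.view_init(45, 45)
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()
```
```
x, y range: (120,) (120,)
X, Y maps: (120, 120) (120, 120)
```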
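The surface has four basins, one around each minimum, and each minimum lies in a different quadrant of the plotted square. That means the same coarse grid `Z` can already locate them roughly by taking the smallest grid value in each quadrant. This is a sketch that relies on that quadrant layout; it is only accurate to about the 0.1 grid spacing and serves as a rough check, not a replacement for the gradient-based search.

```python
import numpy as np

def himmelblau(x, y):
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2

# Same grid as the 3D plot
x = np.arange(-6, 6, 0.1)
y = np.arange(-6, 6, 0.1)
X, Y = np.meshgrid(x, y)
Z = himmelblau(X, Y)

# Relies on each minimum sitting in its own quadrant
quadrants = {
    "x>0, y>0": (X > 0) & (Y > 0),
    "x<0, y>0": (X < 0) & (Y > 0),
    "x<0, y<0": (X < 0) & (Y < 0),
    "x>0, y<0": (X > 0) & (Y < 0),
}

coarse = {}
for name, mask in quadrants.items():
    zq = np.where(mask, Z, np.inf)  # ignore grid points outside this quadrant
    i, j = np.unravel_index(np.argmin(zq), zq.shape)
    coarse[name] = (X[i, j], Y[i, j], Z[i, j])
    print(f"{name}: x={X[i, j]:.1f}, y={Y[i, j]:.1f}, f={Z[i, j]:.4f}")
```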
Writing the training loop to find a minimum
Initializing with (x, y) = (4., 0.)
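```python
# Start from (x, y) = (4, 0); requires_grad=True lets autograd track x
x = torch.tensor([4., 0.], requires_grad=True)
optimizer = torch.optim.Adam([x], lr=1e-3)
for step in range(200001):
    pred = himmelblau(x[0], x[1])  # f at the current point
    optimizer.zero_grad()          # clear gradients from the previous step
    pred.backward()                # write df/dx and df/dy into x.grad
    optimizer.step()               # update x in place with Adam
    # Stop once f is small enough (pred is the value before this update)
    if pred <= 1e-6:
        print('Minimum found (x, y, f(x))=({}, {}, {}) steps used: {}'.format(x[0].item(), x[1].item(), pred.item(), step))
        break
    if step % 1000 == 0:
        print('Logs: step {}: x={}, y={}, f(x) = {}'.format(step, x[0].item(), x[1].item(), pred.item()))
```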
```
Logs: step 0: x=3.999000072479248, y=-0.0009999999310821295, f(x) = 34.0
Logs: step 1000: x=3.5246353149414062, y=-1.026512861251831, f(x) = 6.0330986976623535
Logs: step 2000: x=3.5741987228393555, y=-1.764183521270752, f(x) = 0.09904692322015762
Logs: step 3000: x=3.584354877471924, y=-1.84756338596344, f(x) = 4.704379534814507e-06
Minimum found (x, y, f(x))=(3.584392786026001, -1.8478690385818481, 9.875047908280976e-07) steps used: 3114
```
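:::tips
`if pred <= 1e-6:`
This line effectively sets the precision of the result (loosely speaking): the loop stops as soon as f(x, y) drops to 1e-6 or below, so a smaller threshold usually gives a point closer to the true minimum but takes more steps.
:::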
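To see the effect of the threshold, the loop can be wrapped in a function and run with two tolerances from the same start (4, 0). `find_minimum` is a hypothetical helper name; its body is the same Adam loop as above, only parameterized by the threshold `tol`.

```python
import torch

def himmelblau(x, y):
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2

def find_minimum(x0, y0, tol=1e-6, lr=1e-3, max_steps=200_000):
    """Hypothetical helper wrapping the article's Adam loop.

    Returns (x, y, f, step) at the first step where f <= tol, or None if
    max_steps is reached first. As in the original loop, f is the value
    evaluated before that step's update.
    """
    p = torch.tensor([x0, y0], requires_grad=True)
    optimizer = torch.optim.Adam([p], lr=lr)
    for step in range(max_steps):
        pred = himmelblau(p[0], p[1])
        optimizer.zero_grad()
        pred.backward()
        optimizer.step()
        if pred.item() <= tol:
            return p[0].item(), p[1].item(), pred.item(), step
    return None

loose = find_minimum(4.0, 0.0, tol=1e-3)
tight = find_minimum(4.0, 0.0, tol=1e-6)
print("tol=1e-3:", loose)
print("tol=1e-6:", tight)
```

The looser run can never stop later than the tighter one, because any point with f ≤ 1e-6 also satisfies f ≤ 1e-3.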
Initializing with (x, y) = (2.5, 30.)
```python
# Same loop as before, started from (x, y) = (2.5, 30)
x = torch.tensor([2.5, 30.], requires_grad=True)
optimizer = torch.optim.Adam([x], lr=1e-3)
for step in range(200001):
    pred = himmelblau(x[0], x[1])
    optimizer.zero_grad()
    pred.backward()
    optimizer.step()
    if pred <= 1e-6:
        print('Minimum found (x, y, f(x))=({}, {}, {}) steps used: {}'.format(x[0].item(), x[1].item(), pred.item(), step))
        break
    if step % 1000 == 0:
        print('Logs: step {}: x={}, y={}, f(x) = {}'.format(step, x[0].item(), x[1].item(), pred.item()))
```
```
Logs: step 0: x=2.499000072479248, y=29.999000549316406, f(x) = 802557.8125
Logs: step 1000: x=1.5265045166015625, y=29.02056884765625, f(x) = 700609.3125
Logs: step 2000: x=0.5924956202507019, y=28.07658576965332, f(x) = 611734.5625
Logs: step 3000: x=-0.32081228494644165, y=27.154788970947266, f(x) = 533328.3125
Logs: step 4000: x=-1.2193603515625, y=26.24701690673828, f(x) = 463679.84375
Logs: step 5000: x=-2.099128484725952, y=25.348405838012695, f(x) = 401659.65625
Logs: step 6000: x=-2.9469895362854004, y=24.45630645751953, f(x) = 346479.0
Logs: step 7000: x=-3.7395787239074707, y=23.569429397583008, f(x) = 297534.625
Logs: step 8000: x=-4.438563823699951, y=22.687227249145508, f(x) = 254308.140625
Logs: step 9000: x=-4.982048034667969, y=21.809616088867188, f(x) = 216301.75
Logs: step 10000: x=-5.286879062652588, y=20.936437606811523, f(x) = 182983.15625
Logs: step 11000: x=-5.320369243621826, y=20.068103790283203, f(x) = 153842.796875
Logs: step 12000: x=-5.195002555847168, y=19.204631805419922, f(x) = 128442.03125
Logs: step 13000: x=-5.039945602416992, y=18.346664428710938, f(x) = 106432.109375
Logs: step 14000: x=-4.8903069496154785, y=17.494752883911133, f(x) = 87481.8515625
Logs: step 15000: x=-4.745782375335693, y=16.64938735961914, f(x) = 71275.6328125
Logs: step 16000: x=-4.604536056518555, y=15.811263084411621, f(x) = 57519.81640625
Logs: step 17000: x=-4.465677261352539, y=14.980981826782227, f(x) = 45936.625
Logs: step 18000: x=-4.328790187835693, y=14.159292221069336, f(x) = 36268.5859375
Logs: step 19000: x=-4.19378662109375, y=13.347014427185059, f(x) = 28276.595703125
Logs: step 20000: x=-4.060785293579102, y=12.545125961303711, f(x) = 21740.5078125
Logs: step 21000: x=-3.9300661087036133, y=11.75468635559082, f(x) = 16457.8359375
Logs: step 22000: x=-3.802055597305298, y=10.976983070373535, f(x) = 12244.3916015625
Logs: step 23000: x=-3.6773273944854736, y=10.213469505310059, f(x) = 8933.12890625
Logs: step 24000: x=-3.556602716445923, y=9.46587085723877, f(x) = 6374.06591796875
Logs: step 25000: x=-3.440774440765381, y=8.73620319366455, f(x) = 4433.5908203125
Logs: step 26000: x=-3.3309121131896973, y=8.026880264282227, f(x) = 2993.987548828125
Logs: step 27000: x=-3.2282721996307373, y=7.340791702270508, f(x) = 1952.7083740234375
Logs: step 28000: x=-3.1342809200286865, y=6.6814470291137695, f(x) = 1221.6729736328125
Logs: step 29000: x=-3.0504744052886963, y=6.053155899047852, f(x) = 726.4359130859375
Logs: step 30000: x=-2.9783740043640137, y=5.461299896240234, f(x) = 405.27423095703125
Logs: step 31000: x=-2.9192614555358887, y=4.912687301635742, f(x) = 208.15016174316406
Logs: step 32000: x=-2.873851776123047, y=4.416078090667725, f(x) = 95.5833740234375
Logs: step 33000: x=-2.84191632270813, y=3.982802629470825, f(x) = 37.411094665527344
Logs: step 34000: x=-2.821989059448242, y=3.6272058486938477, f(x) = 11.484293937683105
Logs: step 35000: x=-2.811401128768921, y=3.365464210510254, f(x) = 2.372067451477051
Logs: step 36000: x=-2.806835412979126, y=3.2084977626800537, f(x) = 0.24599333107471466
Logs: step 37000: x=-2.8053927421569824, y=3.1447927951812744, f(x) = 0.007369478698819876
Logs: step 38000: x=-2.8051443099975586, y=3.1320815086364746, f(x) = 2.396338095422834e-05
Minimum found (x, y, f(x))=(-2.805126667022705, 3.1314687728881836, 9.908628726407187e-07) steps used: 38394
```
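Conclusion
:::info Comparing the two initializations shows that the starting point (x, y) affects both how quickly the optimizer converges (3114 steps from (4, 0) versus 38394 steps from (2.5, 30)) and which minimum it reaches ((3.58, -1.85) versus (-2.81, 3.13)). For Himmelblau's function every local minimum is also a global minimum, so either result is acceptable; for a general function, however, a poor initialization can leave gradient-based optimization stuck in a local minimum that is not the global one, from which it cannot escape on its own. :::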
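The conclusion can be checked by starting the same Adam loop from several points and recording where each run ends. The starting points other than (4, 0) are arbitrary choices for illustration, and `find_minimum` and `nearest_minimum` are hypothetical helpers; which minimum each start reaches is whatever the run reports, not something asserted here.

```python
import torch

def himmelblau(x, y):
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2

# The four known minima (rounded), used only to label the results
MINIMA = [(3.0, 2.0), (-2.805118, 3.131312), (-3.779310, -3.283186), (3.584428, -1.848126)]

def find_minimum(x0, y0, tol=1e-6, lr=1e-3, max_steps=200_000):
    # Hypothetical helper: the article's Adam loop, parameterized by the start point
    p = torch.tensor([x0, y0], requires_grad=True)
    optimizer = torch.optim.Adam([p], lr=lr)
    for step in range(max_steps):
        pred = himmelblau(p[0], p[1])
        optimizer.zero_grad()
        pred.backward()
        optimizer.step()
        if pred.item() <= tol:
            return p[0].item(), p[1].item(), pred.item(), step
    return None

def nearest_minimum(x, y):
    # Label a converged point by the closest known minimum
    return min(MINIMA, key=lambda m: (m[0] - x) ** 2 + (m[1] - y) ** 2)

starts = [(4.0, 0.0), (0.0, 0.0), (-4.0, -4.0), (-3.0, 3.0)]
results = {}
for x0, y0 in starts:
    x, y, _, steps = find_minimum(x0, y0)
    results[(x0, y0)] = (x, y, steps)
    print(f"start ({x0}, {y0}) -> ({x:.4f}, {y:.4f}), "
          f"near {nearest_minimum(x, y)}, after {steps} steps")
```

Running a handful of such starts (multi-start) is a simple way to reduce dependence on any single initialization, at the cost of one optimization run per start.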
