Keras Tuner 使用


安装

pip install keras-tuner [--upgrade]

import keras_tuner as kt

from tensorflow import keras

from tensorflow.keras import layers

编写模型构建函数时,用参数 hp 来指代超参数(见下文的 build_model 例子)。

tuner = kt.RandomSearch(build_model, objective='val_loss', max_trials=5)

上面的代码使用随机搜索,以验证损失(val_loss)作为目标度量,尝试五组不同的超参数组合(即五个不同的模型)

  1. tuner.search(x_train, y_train, epochs=5, validation_data=(x_val, y_val))
  2. best_model = tuner.get_best_models()[0]

tuner.search 搜索模型超参;get_best_models() 返回按照目标度量排序后的模型,下标 0 即最优模型

  1. tuner.search_space_summary()

查看参数搜索空间

  1. 例子
  2. def build_model(hp):
  3. model = keras.Sequential()
  4. model.add(layers.Flatten())
  5. for i in range(hp.Int("num_layers", 2, 20)):
  6. model.add(
  7. layers.Dense(
  8. units=hp.Int("units_" + str(i), min_value=32, max_value=512, step=32),
  9. activation="relu",
  10. )
  11. )
  12. model.add(layers.Dense(10, activation="softmax"))
  13. model.compile(
  14. optimizer=keras.optimizers.Adam(hp.Choice("learning_rate", [1e-2, 1e-3, 1e-4])),
  15. loss="categorical_crossentropy",
  16. metrics=["accuracy"],
  17. )
  18. return model
  1. 使用HyperModel类的子类来定义模型
  2. from keras_tuner import HyperModel
  3. class MyHyperModel(HyperModel):
  4. def __init__(self, classes):
  5. self.classes = classes
  6. def build(self, hp):
  7. model = keras.Sequential()
  8. model.add(layers.Flatten())
  9. model.add(
  10. layers.Dense(
  11. units=hp.Int("units", min_value=32, max_value=512, step=32),
  12. activation="relu",
  13. )
  14. )
  15. model.add(layers.Dense(self.classes, activation="softmax"))
  16. model.compile(
  17. optimizer=keras.optimizers.Adam(
  18. hp.Choice("learning_rate", values=[1e-2, 1e-3, 1e-4])
  19. ),
  20. loss="categorical_crossentropy",
  21. metrics=["accuracy"],
  22. )
  23. return model
  24. hypermodel = MyHyperModel(classes=10)
  25. tuner = kt.RandomSearch(
  26. hypermodel,
  27. objective="val_accuracy",
  28. max_trials=3,
  29. overwrite=True,
  30. directory="my_dir",
  31. project_name="helloworld",
  32. )
  33. tuner.search(x_train, y_train, epochs=2, validation_data=(x_val, y_val))

只要重写build方法就能实现很好的模型共享和复用

主要是hp.Int(name,min_value,max_value,step,default)

和hp.Choice(name,values)(values 应该是一个由候选值组成的列表)

hp = kt.HyperParameters()