通过扩展keras.callbacks.Callback基类来创建一个自定义的回调函数

示例：训练时保留一个记录每个批次损失值的列表

  1. class LossHistory(keras.callbacks.Callback):
  2. def on_train_begin(self, logs={}):
  3. self.losses = []
  4. def on_batch_end(self, batch, logs={}):
  5. self.losses.append(logs.get('loss'))
  6. model = Sequential()
  7. model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
  8. model.add(Activation('softmax'))
  9. model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
  10. history = LossHistory()
  11. model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])
  12. print(history.losses)