TF Hub上有数百个训练好的模型,这里选择EfficientNet模型,TF Hub上链接地址为:
https://hub.tensorflow.google.cn/google/efficientnet/b0/classification/1
页面显示如下
:::info
注意:TF Hub模型页面上的两个按钮。“Copy URL”为在线调用此模型提供了链接地址;通过点击“Download”按钮,可以将模型保存到本地,离线使用。本节将详细讲解这两种使用方式。
:::
由于选择的EfficientNet模型要求输入的Tensor shape为(96, 96, 3)�,所以我们使用TRAIN_96_FILE和TEST_96_FILE作为训练集和测试集,代码如下
train_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_96_FILE);
test_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_96_FILE);
完整的模型结构如下,TF Hub的EfficientNet将作为整体模型的一个“层”,即为结构中的keras_layer。
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
tensor (InputLayer) [(None, 96, 96, 3)] 0
_________________________________________________________________
keras_layer (KerasLayer) (None, 1000) 5330564
_________________________________________________________________
flatten (Flatten) (None, 1000) 0
_________________________________________________________________
logits (Dense) (None, 1) 1001
=================================================================
Total params: 5,331,565
Trainable params: 1,001
Non-trainable params: 5,330,564
_________________________________________________________________
26.3.1 在线使用TF Hub模型
直接将模型的URL,作为hub.KerasLayer的参数,完整代码如下:
BatchOperator.setParallelism(1);
if (!new File(DATA_DIR + MODEL_EFNET_FILE).exists()) {
train_set
.link(
new KerasSequentialClassifierTrainBatchOp()
.setTensorCol("tensor")
.setLabelCol("label")
.setLayers(
"hub.KerasLayer('https://hub.tensorflow.google.cn/google/efficientnet/b0/classification/1')",
"Flatten()"
)
.setNumEpochs(5)
.setIntraOpParallelism(1)
.setSaveCheckpointsEpochs(0.5)
.setValidationSplit(0.1)
.setSaveBestOnly(true)
.setBestMetric("auc")
)
.link(
new AkSinkBatchOp()
.setFilePath(DATA_DIR + MODEL_EFNET_FILE)
);
BatchOperator.execute();
}
new KerasSequentialClassifierPredictBatchOp()
.setPredictionCol(PREDICTION_COL)
.setPredictionDetailCol(PREDICTION_DETAIL_COL)
.setReservedCols("relative_path", "label")
.linkFrom(
new AkSourceBatchOp().setFilePath(DATA_DIR + MODEL_EFNET_FILE),
test_set
)
.lazyPrint(10)
.lazyPrintStatistics()
.link(
new EvalBinaryClassBatchOp()
.setLabelCol("label")
.setPredictionDetailCol(PREDICTION_DETAIL_COL)
.lazyPrintMetrics()
);
BatchOperator.execute();
模型评估结果如下,明显优于前面的CNN模型。
-------------------------------- Metrics: --------------------------------
Auc:0.9886 Accuracy:0.9492 Precision:0.9612 Recall:0.9376 F1:0.9493 LogLoss:0.1357
|Pred\Real| dog| cat|
|---------|----|----|
| dog|1188| 48|
| cat| 79|1185|
26.3.2 离线使用TF Hub模型
在TF Hub页面上点击“Download”,将模型下载到本地,文件名为“1.tar”,解压到文件夹(名称为“1”),该文件夹下内容如下图所示,包含两个子文件夹和两个文件。该文件夹对应的路径为:DATA_DIR + “1”
离线使用TF Hub模型,只需将hub.KerasLayer的参数设为本地保存离线模型的文件夹路径。完整代码如下:
BatchOperator.setParallelism(1);
if (!new File(DATA_DIR + MODEL_EFNET_OFFLINE_FILE).exists()) {
train_set
.link(
new KerasSequentialClassifierTrainBatchOp()
.setTensorCol("tensor")
.setLabelCol("label")
.setLayers(
"hub.KerasLayer('" + DATA_DIR + "1')",
"Flatten()"
)
.setNumEpochs(5)
.setIntraOpParallelism(1)
.setSaveCheckpointsEpochs(0.5)
.setValidationSplit(0.1)
.setSaveBestOnly(true)
.setBestMetric("auc")
)
.link(
new AkSinkBatchOp()
.setFilePath(DATA_DIR + MODEL_EFNET_OFFLINE_FILE)
);
BatchOperator.execute();
}
new KerasSequentialClassifierPredictBatchOp()
.setPredictionCol(PREDICTION_COL)
.setPredictionDetailCol(PREDICTION_DETAIL_COL)
.setReservedCols("relative_path", "label")
.linkFrom(
new AkSourceBatchOp().setFilePath(DATA_DIR + MODEL_EFNET_OFFLINE_FILE),
test_set
)
.lazyPrint(10)
.lazyPrintStatistics()
.link(
new EvalBinaryClassBatchOp()
.setLabelCol("label")
.setPredictionDetailCol(PREDICTION_DETAIL_COL)
.lazyPrintMetrics()
);
BatchOperator.execute();
模型评估结果如下,与在线使用TF Hub模型的结果一致。
-------------------------------- Metrics: --------------------------------
Auc:0.9889 Accuracy:0.9492 Precision:0.9474 Recall:0.9526 F1:0.95 LogLoss:0.1308
|Pred\Real| dog| cat|
|---------|----|----|
| dog|1207| 67|
| cat| 60|1166|