TF Hub上有数百个训练好的模型,这里选择EfficientNet模型,TF Hub上链接地址为:
https://hub.tensorflow.google.cn/google/efficientnet/b0/classification/1
页面显示如下
20211119144417.jpg :::info 注意:TF Hub模型页面上的两个按钮。“Copy URL”为在线调用此模型提供了链接地址;通过点击“Download”按钮,可以将模型保存到本地,离线使用。本节将详细讲解这两种使用方式。 :::

由于选择的EfficientNet模型要求输入的Tensor shape为(96, 96, 3)�,所以我们使用TRAIN_96_FILE和TEST_96_FILE作为训练集和测试集,代码如下

  1. train_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TRAIN_96_FILE);
  2. test_set = new AkSourceBatchOp().setFilePath(DATA_DIR + TEST_96_FILE);

完整的模型结构如下,TF Hub的EfficientNet将作为整体模型的一个“层”,即为结构中的keras_layer。

  1. _________________________________________________________________
  2. Layer (type) Output Shape Param #
  3. =================================================================
  4. tensor (InputLayer) [(None, 96, 96, 3)] 0
  5. _________________________________________________________________
  6. keras_layer (KerasLayer) (None, 1000) 5330564
  7. _________________________________________________________________
  8. flatten (Flatten) (None, 1000) 0
  9. _________________________________________________________________
  10. logits (Dense) (None, 1) 1001
  11. =================================================================
  12. Total params: 5,331,565
  13. Trainable params: 1,001
  14. Non-trainable params: 5,330,564
  15. _________________________________________________________________

26.3.1 在线使用TF Hub模型

直接将模型的URL,作为hub.KerasLayer的参数,完整代码如下:

  1. def efficientnet(train_set, test_set) :
  2. if not(os.path.exists(DATA_DIR + MODEL_EFNET_FILE)):
  3. train_set\
  4. .link(
  5. KerasSequentialClassifierTrainBatchOp()\
  6. .setTensorCol("tensor")\
  7. .setLabelCol("label")\
  8. .setLayers([
  9. "hub.KerasLayer('https://hub.tensorflow.google.cn/google/efficientnet/b0/classification/1')",
  10. "Flatten()"
  11. ])\
  12. .setNumEpochs(5)\
  13. .setIntraOpParallelism(1)\
  14. .setSaveCheckpointsEpochs(0.5)\
  15. .setValidationSplit(0.1)\
  16. .setSaveBestOnly(True)
  17. .setBestMetric("auc")
  18. )\
  19. .link(
  20. AkSinkBatchOp()\
  21. .setFilePath(DATA_DIR + MODEL_EFNET_FILE)
  22. )
  23. BatchOperator.execute()
  24. KerasSequentialClassifierPredictBatchOp()\
  25. .setPredictionCol(PREDICTION_COL)\
  26. .setPredictionDetailCol(PREDICTION_DETAIL_COL)\
  27. .setReservedCols(["relative_path", "label"])\
  28. .linkFrom(
  29. AkSourceBatchOp().setFilePath(DATA_DIR + MODEL_EFNET_FILE),
  30. test_set
  31. )\
  32. .lazyPrint(10)\
  33. .lazyPrintStatistics()\
  34. .link(
  35. EvalBinaryClassBatchOp()\
  36. .setLabelCol("label")\
  37. .setPredictionDetailCol(PREDICTION_DETAIL_COL)\
  38. .lazyPrintMetrics()
  39. )
  40. BatchOperator.execute()

模型评估结果如下,明显优于前面的CNN模型。
image.png

  1. Summary:
  2. | colName|count|missing|sum|mean|variance|min|max|
  3. |-------------|-----|-------|---|----|--------|---|---|
  4. |relative_path| 2500| 0|NaN| NaN| NaN|NaN|NaN|
  5. | label| 2500| 0|NaN| NaN| NaN|NaN|NaN|
  6. | pred| 2500| 0|NaN| NaN| NaN|NaN|NaN|
  7. | pred_info| 2500| 0|NaN| NaN| NaN|NaN|NaN|
  8. -------------------------------- Metrics: --------------------------------
  9. Auc:0.9895 Accuracy:0.9496 Precision:0.9558 Recall:0.9401 F1:0.9478 LogLoss:0.1301
  10. |Pred\Real| dog| cat|
  11. |---------|----|----|
  12. | dog|1145| 53|
  13. | cat| 73|1229|

26.3.2 离线使用TF Hub模型

在TF Hub页面上点击“Download”,将模型下载到本地,文件名为“1.tar”,解压到文件夹(名称为“1”),该文件夹下内容如下图所示,包含两个子文件夹和两个文件。该文件夹对应的路径为:DATA_DIR + “1”
20211118205353.jpg

离线使用TF Hub模型,只需将hub.KerasLayer的参数设为本地保存离线模型的文件夹路径。完整代码如下:

  1. def efficientnet_offline(train_set, test_set) :
  2. if not(os.path.exists(DATA_DIR + MODEL_EFNET_OFFLINE_FILE)):
  3. train_set\
  4. .link(
  5. KerasSequentialClassifierTrainBatchOp()\
  6. .setTensorCol("tensor")\
  7. .setLabelCol("label")\
  8. .setLayers([
  9. "hub.KerasLayer('" + DATA_DIR + "1')",
  10. "Flatten()"
  11. ])\
  12. .setNumEpochs(5)\
  13. .setIntraOpParallelism(1)\
  14. .setSaveCheckpointsEpochs(0.5)\
  15. .setValidationSplit(0.1)\
  16. .setSaveBestOnly(True)\
  17. .setBestMetric("auc")
  18. )\
  19. .link(
  20. AkSinkBatchOp()\
  21. .setFilePath(DATA_DIR + MODEL_EFNET_OFFLINE_FILE)
  22. )
  23. BatchOperator.execute()
  24. KerasSequentialClassifierPredictBatchOp()\
  25. .setPredictionCol(PREDICTION_COL)\
  26. .setPredictionDetailCol(PREDICTION_DETAIL_COL)\
  27. .setReservedCols(["relative_path", "label"])\
  28. .linkFrom(
  29. AkSourceBatchOp().setFilePath(DATA_DIR + MODEL_EFNET_OFFLINE_FILE),
  30. test_set
  31. )\
  32. .lazyPrint(10)\
  33. .lazyPrintStatistics()\
  34. .link(
  35. EvalBinaryClassBatchOp()\
  36. .setLabelCol("label")\
  37. .setPredictionDetailCol(PREDICTION_DETAIL_COL)\
  38. .lazyPrintMetrics()
  39. )
  40. BatchOperator.execute()

模型评估结果如下,与在线使用TF Hub模型的结果一致。
image.png

  1. Summary:
  2. | colName|count|missing|sum|mean|variance|min|max|
  3. |-------------|-----|-------|---|----|--------|---|---|
  4. |relative_path| 2500| 0|NaN| NaN| NaN|NaN|NaN|
  5. | label| 2500| 0|NaN| NaN| NaN|NaN|NaN|
  6. | pred| 2500| 0|NaN| NaN| NaN|NaN|NaN|
  7. | pred_info| 2500| 0|NaN| NaN| NaN|NaN|NaN|
  8. -------------------------------- Metrics: --------------------------------
  9. Auc:0.9912 Accuracy:0.952 Precision:0.9645 Recall:0.936 F1:0.95 LogLoss:0.1209
  10. |Pred\Real| dog| cat|
  11. |---------|----|----|
  12. | dog|1140| 42|
  13. | cat| 78|1240|