Java 类名:com.alibaba.alink.operator.batch.tensorflow.TFSavedModelPredictBatchOp
Python 类名:TFSavedModelPredictBatchOp

功能介绍

该组件支持直接使用 SavedModel 进行预测。
模型路径需要时一个压缩文件,解压后能得到一个目录,目录内包含 SavedModel 的文件。

参数说明

名称 中文名称 描述 类型 是否必须? 取值范围 默认值
modelPath 模型的URL路径 模型的URL路径 String
outputSchemaStr Schema Schema。格式为”colname coltype[, colname2, coltype2[, …]]”,例如 “f0 string, f1 bigint, f2 double” String
graphDefTag graph标签 graph标签 String “serve”
inputSignatureDefs 输入 SignatureDef SavedModel 模型的输入 SignatureDef 名,用逗号分隔,需要与输入列一一对应,默认与选择列相同 String[] null
intraOpParallelism Op 间并发度 Op 间并发度 Integer 4
outputSignatureDefs TF 输出 SignatureDef 名 模型的输出 SignatureDef 名,多个输出时用逗号分隔,并且与输出 Schema 一一对应,默认与输出 Schema 中的列名相同 String[] null
reservedCols 算法保留列名 算法保留列 String[] null
selectedCols 选中的列名数组 计算列对应的列名列表 String[] null
signatureDefKey signature标签 signature标签 String “serving_default”

模型路径说明

模型路径可以是以下形式:

  • 本地文件:file:// 加绝对路径,例如 file:///tmp/dnn.py
  • Java 包中的资源文件:res:// 加路径,例如 res:///dnn.py
  • http/https 文件:http://https:// 路径;
  • OSS 文件:oss:// 加路径和 Endpoint 和 access key 等信息,例如oss://bucket/xxx/xxx/xxx.py?host=xxx&access_key_id=xxx&access_key_secret=xxx
  • HDFS 文件:hdfs:// 加路径;

    代码示例

    以下代码仅用于示意,可能需要修改部分代码或者配置环境后才能正常运行!

    Python 代码

    1. test = AkSourceBatchOp()\
    2. .setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_test_vector.ak");
    3. test = VectorToTensorBatchOp()\
    4. .setTensorDataType("float")\
    5. .setTensorShape([1, 28, 28, 1])\
    6. .setSelectedCol("vec")\
    7. .setOutputCol("tensor")\
    8. .setReservedCols(["label"])\
    9. .linkFrom(test)
    10. predictor = TFSavedModelPredictBatchOp()\
    11. .setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_model_tf.zip")\
    12. .setSelectedCols(["tensor"])\
    13. .setInputSignatureDefs(["input_1"])\
    14. .setOutputSignatureDefs(["output_1"])\
    15. .setOutputSchemaStr("probabilities FLOAT_TENSOR")
    16. test = predictor.linkFrom(test).select("label, probabilities")
    17. test.print()

    Java 代码

    1. import com.alibaba.alink.operator.batch.BatchOperator;
    2. import com.alibaba.alink.operator.batch.dataproc.VectorToTensorBatchOp;
    3. import com.alibaba.alink.operator.batch.source.AkSourceBatchOp;
    4. import com.alibaba.alink.operator.batch.tensorflow.TFSavedModelPredictBatchOp;
    5. import org.junit.Test;
    6. public class TFSavedModelPredictBatchOpTest {
    7. @Test
    8. public void testTFSavedModelPredictBatchOp() throws Exception {
    9. BatchOperator.setParallelism(1);
    10. BatchOperator <?> test = new AkSourceBatchOp()
    11. .setFilePath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_test_vector.ak");
    12. test = new VectorToTensorBatchOp()
    13. .setTensorDataType("float")
    14. .setTensorShape(1, 28, 28, 1)
    15. .setSelectedCol("vec")
    16. .setOutputCol("tensor")
    17. .setReservedCols("label")
    18. .linkFrom(test);
    19. BatchOperator <?> predictor = new TFSavedModelPredictBatchOp()
    20. .setModelPath("https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/mnist_model_tf.zip")
    21. .setSelectedCols("tensor")
    22. .setInputSignatureDefs(new String[] {"input_1"})
    23. .setOutputSignatureDefs(new String[] {"output_1"})
    24. .setOutputSchemaStr("probabilities FLOAT_TENSOR");
    25. test = predictor.linkFrom(test).select("label, probabilities");
    26. test.print();
    27. }
    28. }