Java 类名:com.alibaba.alink.pipeline.tensorflow.TFTableModelPredictor
Python 类名:TFTableModelPredictor

功能介绍

TFTableModelTrainer 或者 TF2TableModelTrainer 调用 fit 方法产生的模型,可以进行预测。

参数说明

名称 中文名称 描述 类型 是否必须? 取值范围 默认值
outputSchemaStr Schema Schema。格式为”colname coltype[, colname2, coltype2[, …]]”,例如 “f0 string, f1 bigint, f2 double” String
graphDefTag graph标签 graph标签 String “serve”
inputSignatureDefs 输入 SignatureDef SavedModel 模型的输入 SignatureDef 名,用逗号分隔,需要与输入列一一对应,默认与选择列相同 String[] null
intraOpParallelism Op 间并发度 Op 间并发度 Integer 4
modelFilePath 模型的文件路径 模型的文件路径 String null
outputSignatureDefs TF 输出 SignatureDef 名 模型的输出 SignatureDef 名,多个输出时用逗号分隔,并且与输出 Schema 一一对应,默认与输出 Schema 中的列名相同 String[] null
overwriteSink 是否覆写已有数据 是否覆写已有数据 Boolean false
reservedCols 算法保留列名 算法保留列 String[] null
selectedCols 选中的列名数组 计算列对应的列名列表 String[] null
signatureDefKey signature标签 signature标签 String “serving_default”
modelStreamFilePath 模型流的文件路径 模型流的文件路径 String null
modelStreamScanInterval 扫描模型路径的时间间隔 描模型路径的时间间隔,单位秒 Integer 10
modelStreamStartTime 模型流的起始时间 模型流的起始时间。默认从当前时刻开始读。使用yyyy-mm-dd hh:mm:ss.fffffffff格式,详见Timestamp.valueOf(String s) String null

脚本路径说明

脚本路径可以是以下形式:

  • 本地文件:file:// 加绝对路径,例如 file:///tmp/dnn.py
  • Java 包中的资源文件:res:// 加路径,例如 res:///dnn.py
  • http/https 文件:http://https:// 路径;
  • OSS 文件:oss:// 加路径和 Endpoint 和 access key 等信息,例如oss://bucket/xxx/xxx/xxx.py?host=xxx&access_key_id=xxx&access_key_secret=xxx
  • HDFS 文件:hdfs:// 加路径;

    代码示例

    以下代码仅用于示意,可能需要修改部分代码或者配置环境后才能正常运行!

    Python 代码

    ``` import json

source = RandomTableSourceBatchOp() \ .setNumRows(100) \ .setNumCols(10)

colNames = source.getColNames() source = source.select(“*, case when RAND() > 0.5 then 1. else 0. end as label”) label = “label”

userParams = { ‘featureCols’: json.dumps(colNames), ‘labelCol’: label, ‘batch_size’: 16, ‘num_epochs’: 1 }

trainer = TF2TableModelTrainer() \ .setUserFiles([“https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/tf_dnn_train.py“]) \ .setMainScriptFile(“https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/tf_dnn_train.py“) \ .setUserParams(json.dumps(userParams)) \ .setOutputSchemaStr(“logits double”) \ .setOutputSignatureDefs([“logits”]) \ .setSignatureDefKey(“predict”) \ .setInferSelectedCols(colNames) model = trainer.fit(source) model.transform(source).print()

  1. ### Java 代码

import com.alibaba.alink.common.utils.JsonConverter; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.RandomTableSourceBatchOp; import com.alibaba.alink.pipeline.tensorflow.TF2TableModelTrainer; import com.alibaba.alink.pipeline.tensorflow.TFTableModelPredictor; import org.junit.Test;

import java.util.HashMap; import java.util.Map;

public class TF2TableModelTrainerTest {

  1. @Test
  2. public void testTF2TableModelTrainer() throws Exception {
  3. BatchOperator.setParallelism(3);
  4. BatchOperator<?> source = new RandomTableSourceBatchOp()
  5. .setNumRows(100L)
  6. .setNumCols(10);
  7. String[] colNames = source.getColNames();
  8. source = source.select("*, case when RAND() > 0.5 then 1. else 0. end as label");
  9. String label = "label";
  10. Map <String, Object> userParams = new HashMap <>();
  11. userParams.put("featureCols", JsonConverter.toJson(colNames));
  12. userParams.put("labelCol", label);
  13. userParams.put("batch_size", 16);
  14. userParams.put("num_epochs", 1);
  15. TF2TableModelTrainer trainer = new TF2TableModelTrainer()
  16. .setUserFiles(new String[] {"res:///tf_dnn_train.py"})
  17. .setMainScriptFile("res:///tf_dnn_train.py")
  18. .setUserParams(JsonConverter.toJson(userParams))
  19. .setNumWorkers(2)
  20. .setNumPSs(1)
  21. .setOutputSchemaStr("logits double")
  22. .setOutputSignatureDefs(new String[]{"logits"})
  23. .setSignatureDefKey("predict")
  24. .setInferSelectedCols(colNames);
  25. TFTableModelPredictor model = trainer.fit(source);
  26. model.transform(source).print();
  27. }

}

```