Demo code

  1. package benchmark.online;
  2. import org.apache.flink.api.common.typeinfo.TypeInformation;
  3. import org.apache.flink.api.java.tuple.Tuple2;
  4. import org.apache.flink.api.java.typeutils.RowTypeInfo;
  5. import org.apache.flink.table.functions.TableFunction;
  6. import org.apache.flink.types.Row;
  7. import com.alibaba.alink.common.AlinkGlobalConfiguration;
  8. import com.alibaba.alink.common.AlinkTypes;
  9. import com.alibaba.alink.common.MLEnvironmentFactory;
  10. import com.alibaba.alink.common.utils.TableUtil;
  11. import com.alibaba.alink.operator.batch.BatchOperator;
  12. import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
  13. import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
  14. import com.alibaba.alink.operator.batch.sink.AppendModelStreamFileSinkBatchOp;
  15. import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
  16. import com.alibaba.alink.operator.stream.StreamOperator;
  17. import com.alibaba.alink.operator.stream.classification.LogisticRegressionPredictStreamOp;
  18. import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp;
  19. import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp;
  20. import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp;
  21. import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;
  22. import com.alibaba.alink.operator.stream.sink.ModelStreamFileSinkStreamOp;
  23. import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
  24. import com.alibaba.alink.pipeline.LocalPredictor;
  25. import com.alibaba.alink.pipeline.Pipeline;
  26. import com.alibaba.alink.pipeline.PipelineModel;
  27. import com.alibaba.alink.pipeline.classification.LogisticRegression;
  28. import com.alibaba.alink.pipeline.dataproc.StandardScaler;
  29. import com.alibaba.alink.pipeline.feature.FeatureHasher;
  30. import org.apache.commons.lang3.ArrayUtils;
  31. import org.junit.Test;
  32. /**
  33. * https://www.kaggle.com/c/avazu-ctr-prediction/data
  34. */
  35. public class FtrlTest {
  36. private static final String[] ORIGIN_COL_NAMES = new String[] {
  37. "id", "click", "dt", "C1", "banner_pos",
  38. "site_id", "site_domain", "site_category", "app_id", "app_domain",
  39. "app_category", "device_id", "device_ip", "device_model", "device_type",
  40. "device_conn_type", "C14", "C15", "C16", "C17",
  41. "C18", "C19", "C20", "C21"
  42. };
  43. private static final String[] ORIGIN_COL_TYPES = new String[] {
  44. "string", "string", "string", "string", "int",
  45. "string", "string", "string", "string", "string",
  46. "string", "string", "string", "string", "string",
  47. "string", "int", "int", "int", "int",
  48. "int", "int", "int", "int"
  49. };
  50. private static final String[] COL_NAMES = new String[] {
  51. "id", "click",
  52. "dt_year", "dt_month", "dt_day", "dt_hour",
  53. "C1", "banner_pos",
  54. "site_id", "site_domain", "site_category", "app_id", "app_domain",
  55. "app_category", "device_id", "device_ip", "device_model", "device_type",
  56. "device_conn_type", "C14", "C15", "C16", "C17",
  57. "C18", "C19", "C20", "C21"
  58. };
  59. private static final String DATA_DIR = "https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/";
  60. private static final String SMALL_FILE = "avazu-small.csv";
  61. private static final String LARGE_FILE = "avazu-ctr-train-8M.csv";
  62. private static final String FEATURE_PIPELINE_MODEL_FILE = "/tmp/feature_model.csv";
  63. private static final String labelColName = "click";
  64. private static final String vecColName = "vec";
  65. static final String[] FEATURE_COL_NAMES =
  66. ArrayUtils.removeElements(COL_NAMES, labelColName, "id", "dt_year", "dt_month",
  67. "site_id", "site_domain", "app_id", "device_id", "device_ip", "device_model");
  68. static final String[] HIGH_FREQ_FEATURE_COL_NAMES = new String[] {"site_id", "site_domain", "device_id",
  69. "device_model"};
  70. static final String[] CATEGORY_FEATURE_COL_NAMES = new String[] {
  71. "C1", "banner_pos",
  72. "site_category", "app_domain",
  73. "app_category", "device_type",
  74. "device_conn_type"
  75. };
  76. static final String[] NUMERICAL_FEATURE_COL_NAMES =
  77. ArrayUtils.removeElements(FEATURE_COL_NAMES, CATEGORY_FEATURE_COL_NAMES);
  78. @Test
  79. public void trainFeatureModel() throws Exception {
  80. MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().setParallelism(4);
  81. MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(1);
  82. int numHashFeatures = 30000;
  83. Pipeline feature_pipeline = new Pipeline()
  84. .add(
  85. new StandardScaler()
  86. .setSelectedCols(NUMERICAL_FEATURE_COL_NAMES)
  87. )
  88. .add(
  89. new FeatureHasher()
  90. .setSelectedCols(ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES))
  91. .setCategoricalCols(ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES))
  92. .setOutputCol(vecColName)
  93. .setNumFeatures(numHashFeatures).setReservedCols("click")
  94. );
  95. feature_pipeline.fit(getSmallBatchSet()).save(FEATURE_PIPELINE_MODEL_FILE,
  96. true);
  97. BatchOperator.execute();
  98. }
  99. @Test
  100. public void onlineTrainAndEval() throws Exception {
  101. PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
  102. AlinkGlobalConfiguration.setPrintProcessInfo(true);
  103. Tuple2 <StreamOperator, StreamOperator> sources = getStreamTrainTestData();
  104. StreamOperator <?> trainStream = sources.f0;
  105. StreamOperator <?> testStream = sources.f1;
  106. trainStream = featurePipelineModel.transform(trainStream);
  107. testStream = featurePipelineModel.transform(testStream);
  108. BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());
  109. StreamOperator.setParallelism(2);
  110. BatchOperator <?> model = new LogisticRegressionTrainBatchOp()
  111. .setVectorCol(vecColName)
  112. .setLabelCol(labelColName)
  113. .setWithIntercept(true)
  114. .linkFrom(trainBatch);
  115. StreamOperator <?> models = new FtrlTrainStreamOp(model)
  116. .setVectorCol(vecColName)
  117. .setLabelCol(labelColName)
  118. .setMiniBatchSize(1024)
  119. .setTimeInterval(10)
  120. .setWithIntercept(true)
  121. .setModelStreamFilePath("/tmp/avazu_fm_models")
  122. .linkFrom(trainStream);
  123. StreamOperator <?> predictResults = new LogisticRegressionPredictStreamOp(model)
  124. .setPredictionCol("predict")
  125. .setReservedCols(labelColName)
  126. .setPredictionDetailCol("details")
  127. .linkFrom(testStream, models);
  128. new EvalBinaryClassStreamOp()
  129. .setPredictionDetailCol("details").setLabelCol(labelColName).setTimeInterval(10).linkFrom(predictResults)
  130. .link(new JsonValueStreamOp().setSelectedCol("Data")
  131. .setReservedCols(new String[] {"Statistics"})
  132. .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
  133. .setJsonPath("$.Accuracy", "$.AUC", "ConfusionMatrix")).print();
  134. StreamOperator.execute();
  135. }
  136. @Test
  137. public void onlineTrainAndSave() throws Exception {
  138. PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
  139. AlinkGlobalConfiguration.setPrintProcessInfo(true);
  140. Tuple2 <StreamOperator, StreamOperator> sources = getStreamTrainTestData();
  141. StreamOperator <?> trainStream = sources.f0;
  142. trainStream = featurePipelineModel.transform(trainStream);
  143. BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());
  144. StreamOperator.setParallelism(2);
  145. BatchOperator <?> model = new LogisticRegressionTrainBatchOp()
  146. .setVectorCol(vecColName)
  147. .setLabelCol(labelColName)
  148. .setWithIntercept(true)
  149. .linkFrom(trainBatch);
  150. StreamOperator <?> models = new FtrlTrainStreamOp(model)
  151. .setVectorCol(vecColName)
  152. .setLabelCol(labelColName)
  153. .setMiniBatchSize(1024)
  154. .setTimeInterval(10)
  155. .setWithIntercept(true)
  156. .setModelStreamFilePath("/tmp/rebase_ftrl_models")
  157. .linkFrom(trainStream);
  158. models.link(new ModelStreamFileSinkStreamOp().setFilePath("/tmp/ftrl_models"));
  159. StreamOperator.execute();
  160. }
  161. @Test
  162. public void BatchTrainAndSaveRebaseModel() throws Exception {
  163. PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
  164. BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());
  165. StreamOperator.setParallelism(2);
  166. BatchOperator <?> model1 = new LogisticRegressionTrainBatchOp()
  167. .setVectorCol(vecColName)
  168. .setLabelCol(labelColName)
  169. .setWithIntercept(true)
  170. .linkFrom(trainBatch);
  171. model1.link(new AppendModelStreamFileSinkBatchOp().setFilePath("/tmp/rebase_ftrl_models"));
  172. BatchOperator.execute();
  173. }
  174. @Test
  175. public void savePipelineModel() throws Exception {
  176. BatchOperator <?> trainBatch = getSmallBatchSet();
  177. int numHashFeatures = 30000;
  178. PipelineModel pipelineModel = new Pipeline()
  179. .add(
  180. new StandardScaler()
  181. .setSelectedCols(NUMERICAL_FEATURE_COL_NAMES)
  182. )
  183. .add(
  184. new FeatureHasher()
  185. .setSelectedCols(ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES))
  186. .setCategoricalCols(ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES))
  187. .setOutputCol(vecColName)
  188. .setNumFeatures(numHashFeatures).setReservedCols("click")
  189. ).add(
  190. new LogisticRegression()
  191. .setVectorCol("vec")
  192. .setLabelCol("click")
  193. .setPredictionCol("pred")
  194. .setModelStreamFilePath("/tmp/ftrl_models")
  195. .setPredictionDetailCol("detail")
  196. .setMaxIter(10))
  197. .fit(trainBatch);
  198. pipelineModel.save().link(new AkSinkBatchOp().setOverwriteSink(true).setFilePath("/tmp/lr_pipeline.ak"));
  199. BatchOperator.execute();
  200. }
  201. @Test
  202. public void localPredictor() throws Exception {
  203. LocalPredictor predictor = new LocalPredictor("/tmp/lr_pipeline.ak",
  204. TableUtil.schema2SchemaStr(getSmallBatchSet().getSchema()));
  205. System.out.println(TableUtil.schema2SchemaStr(getSmallBatchSet().getSchema()));
  206. for (int i = 0; i < Integer.MAX_VALUE; ++i) {
  207. System.out.println(predictor.map(
  208. Row.of("220869541682524752", "0", 14, 10, 21, 2, "1005", 0, "1fbe01fe", "f3845767",
  209. "28905ebd", "ecad2386", "7801e8d9", "07d7df22", "a99f214a", "af1c0727", "a0f5f879", "1", "0",
  210. 15703, 320, 50, 1722, 0, 35, -1, 79)));
  211. Thread.sleep(5000);
  212. }
  213. }
  214. public static class SplitDataTime extends TableFunction <Row> {
  215. private Integer parseInt(String s) {
  216. if ('0' == s.charAt(0)) {
  217. return Integer.parseInt(s.substring(1));
  218. } else {
  219. return Integer.parseInt(s);
  220. }
  221. }
  222. public void eval(String str) {
  223. collect(Row.of(
  224. parseInt(str.substring(0, 2)),
  225. parseInt(str.substring(2, 4)),
  226. parseInt(str.substring(4, 6)),
  227. parseInt(str.substring(6, 8))
  228. ));
  229. }
  230. @Override
  231. public TypeInformation <Row> getResultType() {
  232. return new RowTypeInfo(AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT);
  233. }
  234. }
  235. private static Tuple2 <StreamOperator, StreamOperator> getStreamTrainTestData() {
  236. StringBuilder sbd = new StringBuilder();
  237. for (int i = 0; i < ORIGIN_COL_NAMES.length; i++) {
  238. if (i > 0) {
  239. sbd.append(",");
  240. }
  241. sbd.append(ORIGIN_COL_NAMES[i]).append(" ").append(ORIGIN_COL_TYPES[i]);
  242. }
  243. StreamOperator <?> source = new CsvSourceStreamOp()
  244. .setFilePath(DATA_DIR + FtrlRebaseTest.LARGE_FILE)
  245. .setSchemaStr(sbd.toString())
  246. .setIgnoreFirstLine(true)
  247. .udtf("dt", new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"}, new SplitDataTime())
  248. .select(COL_NAMES);
  249. SplitStreamOp splitter = new SplitStreamOp().setFraction(0.5);
  250. source.link(splitter);
  251. return new Tuple2 <>(splitter, splitter.getSideOutput(0));
  252. }
  253. private static BatchOperator <?> getSmallBatchSet() {
  254. StringBuilder sbd = new StringBuilder();
  255. for (int i = 0; i < ORIGIN_COL_NAMES.length; i++) {
  256. if (i > 0) {
  257. sbd.append(",");
  258. }
  259. sbd.append(ORIGIN_COL_NAMES[i]).append(" ").append(ORIGIN_COL_TYPES[i]);
  260. }
  261. return new CsvSourceBatchOp()
  262. .setFilePath(DATA_DIR + FtrlRebaseTest.SMALL_FILE)
  263. .setSchemaStr(sbd.toString())
  264. .setIgnoreFirstLine(true)
  265. .udtf("dt", new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"}, new SplitDataTime())
  266. .select(COL_NAMES);
  267. }
  268. }

Demo 功能介绍

该 demo 使用 Ftrl 算法对 Avazu 数据(https://www.kaggle.com/c/avazu-ctr-prediction/data)进行实时训练,生成模型流,并将模型流实时加载到推理服务中。另外我们还提供了模型 rebase 的示例代码,可以很方便地定时用一个批训练的模型重新初始化在线模型,防止模型跑偏。最后还提供了一个模型训练 + 预测 + 评估的示例代码。

函数说明

函数 任务类型 说明
trainFeatureModel() 批任务 训练特征工程模型,这个模型将对训练、预测、推理数据进行特征编码
savePipelineModel() 批任务 训练PipelineModel,该模型是部署到线上服务的模型
onlineTrainAndSave() 流任务 使用Ftrl实时训练模型,并定时将模型写出到指定目录
BatchTrainAndSaveRebaseModel() 批任务 训练用来重新初始化的模型,用来拉回模型,防止模型跑偏
localPredictor() 本地推理 本地搭建一个服务,反复对同一条样本进行预测,用来验证模型是否已更新
onlineTrainAndEval() 流任务 Ftrl 训练模型,并对模型进行预测评估

执行步骤

  • 首先我们需要执行 trainFeatureModel() 函数,生成特征工程模型,并存储到文件“/tmp/feature_model.csv”,后面的函数都需要该模型
  • 第二步,执行savePipelineModel() 函数,生成部署到线上服务的模型,该模型会通过 setModelStreamFilePath("/tmp/ftrl_models") 设置模型流的目录,设置完再部署这个模型时,在推理的同时会实时监控这个目录,当有新模型产生时会自动加载新模型,用最新模型进行推理。
  • 第三步,执行localPredictor() 函数,将第二步产生的模型部署成本地服务。
  • 第四步,执行onlineTrainAndSave(),用Ftrl算法实时训练在线模型,并以固定频率输出到目录“/tmp/ftrl_models”,这个目录与第二步的目录是同一个目录。另外这一步还要设置一个目录setModelStreamFilePath("/tmp/rebase_ftrl_models"),这个目录是用来做模型rebase的,在线学习过程中会实时监控这个目录,当有新的模型出现在这个目录中,会重新加载这个模型作为base模型继续进行训练。
  • 第五步,隔段时间(1小时 or 1天)执行函数BatchTrainAndSaveRebaseModel(),这个函数将训练一个新的批模型,并将其写入到rebase目录“/tmp/rebase_ftrl_models”,这样,第四步中的函数就会监测到这个模型,并进行模型 rebase。
  • onlineTrainAndEval() 函数是一个独立的函数,用来评估在线学习算法生成的模型怎么样,并打印评估结果。

    备注

  • 特征工程这里我们对数据做标准化和FeatureHash,其实这里可以使用任何其他Alink 的特征工程算法,类似OneHot编码,GBDT编码,多热编码,归一化、分桶算法等

  • 实时加载在线训练的模型是通过一个给定的文件系统的目录(“/tmp/ftrl_models”)完成的,这个目录可以是本地目录,也可以是网络文件系统的目录,类似OSS 和HDFS等。
  • 模型重新初始化和在线模型的加载是类似的,也是通过一个给定的文件系统的目录(“/tmp/rebase_ftrl_models”)完成的,同样,这个目录可以是本地目录,也可以是网络文件系统的目录,类似OSS 和HDFS等。