Demo code
package benchmark.online;import org.apache.flink.api.common.typeinfo.TypeInformation;import org.apache.flink.api.java.tuple.Tuple2;import org.apache.flink.api.java.typeutils.RowTypeInfo;import org.apache.flink.table.functions.TableFunction;import org.apache.flink.types.Row;import com.alibaba.alink.common.AlinkGlobalConfiguration;import com.alibaba.alink.common.AlinkTypes;import com.alibaba.alink.common.MLEnvironmentFactory;import com.alibaba.alink.common.utils.TableUtil;import com.alibaba.alink.operator.batch.BatchOperator;import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;import com.alibaba.alink.operator.batch.sink.AppendModelStreamFileSinkBatchOp;import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;import com.alibaba.alink.operator.stream.StreamOperator;import com.alibaba.alink.operator.stream.classification.LogisticRegressionPredictStreamOp;import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp;import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp;import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp;import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;import com.alibaba.alink.operator.stream.sink.ModelStreamFileSinkStreamOp;import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;import com.alibaba.alink.pipeline.LocalPredictor;import com.alibaba.alink.pipeline.Pipeline;import com.alibaba.alink.pipeline.PipelineModel;import com.alibaba.alink.pipeline.classification.LogisticRegression;import com.alibaba.alink.pipeline.dataproc.StandardScaler;import com.alibaba.alink.pipeline.feature.FeatureHasher;import org.apache.commons.lang3.ArrayUtils;import org.junit.Test;/*** https://www.kaggle.com/c/avazu-ctr-prediction/data*/public class FtrlTest {private static final String[] ORIGIN_COL_NAMES = new String[] {"id", "click", "dt", "C1", "banner_pos","site_id", "site_domain", "site_category", 
"app_id", "app_domain","app_category", "device_id", "device_ip", "device_model", "device_type","device_conn_type", "C14", "C15", "C16", "C17","C18", "C19", "C20", "C21"};private static final String[] ORIGIN_COL_TYPES = new String[] {"string", "string", "string", "string", "int","string", "string", "string", "string", "string","string", "string", "string", "string", "string","string", "int", "int", "int", "int","int", "int", "int", "int"};private static final String[] COL_NAMES = new String[] {"id", "click","dt_year", "dt_month", "dt_day", "dt_hour","C1", "banner_pos","site_id", "site_domain", "site_category", "app_id", "app_domain","app_category", "device_id", "device_ip", "device_model", "device_type","device_conn_type", "C14", "C15", "C16", "C17","C18", "C19", "C20", "C21"};private static final String DATA_DIR = "https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/";private static final String SMALL_FILE = "avazu-small.csv";private static final String LARGE_FILE = "avazu-ctr-train-8M.csv";private static final String FEATURE_PIPELINE_MODEL_FILE = "/tmp/feature_model.csv";private static final String labelColName = "click";private static final String vecColName = "vec";static final String[] FEATURE_COL_NAMES =ArrayUtils.removeElements(COL_NAMES, labelColName, "id", "dt_year", "dt_month","site_id", "site_domain", "app_id", "device_id", "device_ip", "device_model");static final String[] HIGH_FREQ_FEATURE_COL_NAMES = new String[] {"site_id", "site_domain", "device_id","device_model"};static final String[] CATEGORY_FEATURE_COL_NAMES = new String[] {"C1", "banner_pos","site_category", "app_domain","app_category", "device_type","device_conn_type"};static final String[] NUMERICAL_FEATURE_COL_NAMES =ArrayUtils.removeElements(FEATURE_COL_NAMES, CATEGORY_FEATURE_COL_NAMES);@Testpublic void trainFeatureModel() throws Exception 
{MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().setParallelism(4);MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(1);int numHashFeatures = 30000;Pipeline feature_pipeline = new Pipeline().add(new StandardScaler().setSelectedCols(NUMERICAL_FEATURE_COL_NAMES)).add(new FeatureHasher().setSelectedCols(ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES)).setCategoricalCols(ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES)).setOutputCol(vecColName).setNumFeatures(numHashFeatures).setReservedCols("click"));feature_pipeline.fit(getSmallBatchSet()).save(FEATURE_PIPELINE_MODEL_FILE,true);BatchOperator.execute();}@Testpublic void onlineTrainAndEval() throws Exception {PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);AlinkGlobalConfiguration.setPrintProcessInfo(true);Tuple2 <StreamOperator, StreamOperator> sources = getStreamTrainTestData();StreamOperator <?> trainStream = sources.f0;StreamOperator <?> testStream = sources.f1;trainStream = featurePipelineModel.transform(trainStream);testStream = featurePipelineModel.transform(testStream);BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());StreamOperator.setParallelism(2);BatchOperator <?> model = new LogisticRegressionTrainBatchOp().setVectorCol(vecColName).setLabelCol(labelColName).setWithIntercept(true).linkFrom(trainBatch);StreamOperator <?> models = new FtrlTrainStreamOp(model).setVectorCol(vecColName).setLabelCol(labelColName).setMiniBatchSize(1024).setTimeInterval(10).setWithIntercept(true).setModelStreamFilePath("/tmp/avazu_fm_models").linkFrom(trainStream);StreamOperator <?> predictResults = new LogisticRegressionPredictStreamOp(model).setPredictionCol("predict").setReservedCols(labelColName).setPredictionDetailCol("details").linkFrom(testStream, models);new 
EvalBinaryClassStreamOp().setPredictionDetailCol("details").setLabelCol(labelColName).setTimeInterval(10).linkFrom(predictResults).link(new JsonValueStreamOp().setSelectedCol("Data").setReservedCols(new String[] {"Statistics"}).setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"}).setJsonPath("$.Accuracy", "$.AUC", "ConfusionMatrix")).print();StreamOperator.execute();}@Testpublic void onlineTrainAndSave() throws Exception {PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);AlinkGlobalConfiguration.setPrintProcessInfo(true);Tuple2 <StreamOperator, StreamOperator> sources = getStreamTrainTestData();StreamOperator <?> trainStream = sources.f0;trainStream = featurePipelineModel.transform(trainStream);BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());StreamOperator.setParallelism(2);BatchOperator <?> model = new LogisticRegressionTrainBatchOp().setVectorCol(vecColName).setLabelCol(labelColName).setWithIntercept(true).linkFrom(trainBatch);StreamOperator <?> models = new FtrlTrainStreamOp(model).setVectorCol(vecColName).setLabelCol(labelColName).setMiniBatchSize(1024).setTimeInterval(10).setWithIntercept(true).setModelStreamFilePath("/tmp/rebase_ftrl_models").linkFrom(trainStream);models.link(new ModelStreamFileSinkStreamOp().setFilePath("/tmp/ftrl_models"));StreamOperator.execute();}@Testpublic void BatchTrainAndSaveRebaseModel() throws Exception {PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);BatchOperator <?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());StreamOperator.setParallelism(2);BatchOperator <?> model1 = new LogisticRegressionTrainBatchOp().setVectorCol(vecColName).setLabelCol(labelColName).setWithIntercept(true).linkFrom(trainBatch);model1.link(new AppendModelStreamFileSinkBatchOp().setFilePath("/tmp/rebase_ftrl_models"));BatchOperator.execute();}@Testpublic void savePipelineModel() throws Exception {BatchOperator <?> 
trainBatch = getSmallBatchSet();int numHashFeatures = 30000;PipelineModel pipelineModel = new Pipeline().add(new StandardScaler().setSelectedCols(NUMERICAL_FEATURE_COL_NAMES)).add(new FeatureHasher().setSelectedCols(ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES)).setCategoricalCols(ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES)).setOutputCol(vecColName).setNumFeatures(numHashFeatures).setReservedCols("click")).add(new LogisticRegression().setVectorCol("vec").setLabelCol("click").setPredictionCol("pred").setModelStreamFilePath("/tmp/ftrl_models").setPredictionDetailCol("detail").setMaxIter(10)).fit(trainBatch);pipelineModel.save().link(new AkSinkBatchOp().setOverwriteSink(true).setFilePath("/tmp/lr_pipeline.ak"));BatchOperator.execute();}@Testpublic void localPredictor() throws Exception {LocalPredictor predictor = new LocalPredictor("/tmp/lr_pipeline.ak",TableUtil.schema2SchemaStr(getSmallBatchSet().getSchema()));System.out.println(TableUtil.schema2SchemaStr(getSmallBatchSet().getSchema()));for (int i = 0; i < Integer.MAX_VALUE; ++i) {System.out.println(predictor.map(Row.of("220869541682524752", "0", 14, 10, 21, 2, "1005", 0, "1fbe01fe", "f3845767","28905ebd", "ecad2386", "7801e8d9", "07d7df22", "a99f214a", "af1c0727", "a0f5f879", "1", "0",15703, 320, 50, 1722, 0, 35, -1, 79)));Thread.sleep(5000);}}public static class SplitDataTime extends TableFunction <Row> {private Integer parseInt(String s) {if ('0' == s.charAt(0)) {return Integer.parseInt(s.substring(1));} else {return Integer.parseInt(s);}}public void eval(String str) {collect(Row.of(parseInt(str.substring(0, 2)),parseInt(str.substring(2, 4)),parseInt(str.substring(4, 6)),parseInt(str.substring(6, 8))));}@Overridepublic TypeInformation <Row> getResultType() {return new RowTypeInfo(AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT);}}private static Tuple2 <StreamOperator, StreamOperator> getStreamTrainTestData() {StringBuilder sbd = new 
StringBuilder();for (int i = 0; i < ORIGIN_COL_NAMES.length; i++) {if (i > 0) {sbd.append(",");}sbd.append(ORIGIN_COL_NAMES[i]).append(" ").append(ORIGIN_COL_TYPES[i]);}StreamOperator <?> source = new CsvSourceStreamOp().setFilePath(DATA_DIR + FtrlRebaseTest.LARGE_FILE).setSchemaStr(sbd.toString()).setIgnoreFirstLine(true).udtf("dt", new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"}, new SplitDataTime()).select(COL_NAMES);SplitStreamOp splitter = new SplitStreamOp().setFraction(0.5);source.link(splitter);return new Tuple2 <>(splitter, splitter.getSideOutput(0));}private static BatchOperator <?> getSmallBatchSet() {StringBuilder sbd = new StringBuilder();for (int i = 0; i < ORIGIN_COL_NAMES.length; i++) {if (i > 0) {sbd.append(",");}sbd.append(ORIGIN_COL_NAMES[i]).append(" ").append(ORIGIN_COL_TYPES[i]);}return new CsvSourceBatchOp().setFilePath(DATA_DIR + FtrlRebaseTest.SMALL_FILE).setSchemaStr(sbd.toString()).setIgnoreFirstLine(true).udtf("dt", new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"}, new SplitDataTime()).select(COL_NAMES);}}
Demo 功能介绍
该demo使用Ftrl 算法对Avazu 数据(https://www.kaggle.com/c/avazu-ctr-prediction/data)进行实时训练并生成模型流,并将模型流实时加载到推理服务中。另外我们还增加了模型rebase 的示例代码,能够很容易的完成用一个批模型定时重新拉回模型,防止模型跑偏。最后还提供了一个模型训练+预测+评估的示例代码。
函数说明
| 函数 | 任务类型 | 说明 |
|---|---|---|
| trainFeatureModel() | 批任务 | 训练特征工程模型,这个模型将对训练、预测、推理数据进行特征编码 |
| savePipelineModel() | 批任务 | 训练PipelineModel,该模型是部署到线上服务的模型 |
| onlineTrainAndSave() | 流任务 | 使用Ftrl实时训练模型,并定时将模型写出到指定目录 |
| BatchTrainAndSaveRebaseModel() | 批任务 | 训练用来重新初始化的模型,用来拉回模型,防止模型跑偏 |
| localPredictor() | 本地推理 | 本地搭建一个服务,对同一条样本预测,用来验证模型更新 |
| onlineTrainAndEval() | 流任务 | Ftrl 训练模型,并对模型进行预测评估 |
执行步骤
- 首先我们需要执行 trainFeatureModel() 函数,生成特征工程模型,并存储到文件“/tmp/feature_model.csv”,后面的函数都需要该模型。
- 第二步,执行savePipelineModel() 函数,生成部署到线上服务的模型,该模型会通过 setModelStreamFilePath(“/tmp/ftrl_models”) 设置模型流的目录,设置完再部署这个模型时,在推理的同时会实时监控这个目录,当有新模型产生时会自动加载新模型,用最新模型进行推理。
- 第三步,执行localPredictor() 函数,将第二步产生的模型部署成本地服务。
- 第四步,执行onlineTrainAndSave(),用Ftrl算法实时训练在线模型,并以固定频率输出到目录“/tmp/ftrl_models”,这个目录与第二步的目录是同一个目录。另外这一步还要设置一个目录setModelStreamFilePath(“/tmp/rebase_ftrl_models”),这个目录是用来做模型rebase的,在线学习过程中会实时监控这个目录,当有新的模型出现在这个目录中,会重新加载这个模型作为base模型继续进行训练。
- 第五步,隔段时间(1小时 or 1天)执行函数BatchTrainAndSaveRebaseModel(),这个函数将训练一个新的批模型,并将其写入到rebase目录“/tmp/rebase_ftrl_models”,这样,第四步中的函数就会监测到这个模型,并进行模型 rebase。
onlineTrainAndEval() 函数是一个独立的函数,用来评估在线学习算法生成的模型怎么样,并打印评估结果。
备注
特征工程这里我们对数据做标准化和FeatureHash,其实这里可以使用任何其他Alink 的特征工程算法,类似OneHot编码,GBDT编码,多热编码,归一化、分桶算法等
- 实时加载在线训练的模型是通过一个给定的文件系统的目录(“/tmp/ftrl_models”)完成的,这个目录可以是本地目录,也可以是网络文件系统的目录,类似OSS 和HDFS等。
- 模型重新初始化和在线模型的加载是类似的,也是通过一个给定的文件系统的目录(“/tmp/rebase_ftrl_models”)完成的,同样,这个目录可以是本地目录,也可以是网络文件系统的目录,类似OSS 和HDFS等。
