Demo code
package benchmark.online;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;
import com.alibaba.alink.common.AlinkGlobalConfiguration;
import com.alibaba.alink.common.AlinkTypes;
import com.alibaba.alink.common.MLEnvironmentFactory;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.classification.LogisticRegressionTrainBatchOp;
import com.alibaba.alink.operator.batch.sink.AkSinkBatchOp;
import com.alibaba.alink.operator.batch.sink.AppendModelStreamFileSinkBatchOp;
import com.alibaba.alink.operator.batch.source.CsvSourceBatchOp;
import com.alibaba.alink.operator.stream.StreamOperator;
import com.alibaba.alink.operator.stream.classification.LogisticRegressionPredictStreamOp;
import com.alibaba.alink.operator.stream.dataproc.JsonValueStreamOp;
import com.alibaba.alink.operator.stream.dataproc.SplitStreamOp;
import com.alibaba.alink.operator.stream.evaluation.EvalBinaryClassStreamOp;
import com.alibaba.alink.operator.stream.onlinelearning.FtrlTrainStreamOp;
import com.alibaba.alink.operator.stream.sink.ModelStreamFileSinkStreamOp;
import com.alibaba.alink.operator.stream.source.CsvSourceStreamOp;
import com.alibaba.alink.pipeline.LocalPredictor;
import com.alibaba.alink.pipeline.Pipeline;
import com.alibaba.alink.pipeline.PipelineModel;
import com.alibaba.alink.pipeline.classification.LogisticRegression;
import com.alibaba.alink.pipeline.dataproc.StandardScaler;
import com.alibaba.alink.pipeline.feature.FeatureHasher;
import org.apache.commons.lang3.ArrayUtils;
import org.junit.Test;
/**
* https://www.kaggle.com/c/avazu-ctr-prediction/data
*/
public class FtrlTest {
/** Column names of the raw CSV files, in the order used to build the source schema string. */
private static final String[] ORIGIN_COL_NAMES = {
    "id", "click", "dt", "C1", "banner_pos",
    "site_id", "site_domain", "site_category", "app_id", "app_domain",
    "app_category", "device_id", "device_ip", "device_model", "device_type",
    "device_conn_type", "C14", "C15", "C16", "C17",
    "C18", "C19", "C20", "C21"
};

/** Column types paired index-by-index with {@link #ORIGIN_COL_NAMES}. */
private static final String[] ORIGIN_COL_TYPES = {
    "string", "string", "string", "string", "int",
    "string", "string", "string", "string", "string",
    "string", "string", "string", "string", "string",
    "string", "int", "int", "int", "int",
    "int", "int", "int", "int"
};

/** Columns selected from the sources once "dt" has been expanded into its four dt_* parts. */
private static final String[] COL_NAMES = {
    "id", "click",
    "dt_year", "dt_month", "dt_day", "dt_hour",
    "C1", "banner_pos",
    "site_id", "site_domain", "site_category", "app_id", "app_domain",
    "app_category", "device_id", "device_ip", "device_model", "device_type",
    "device_conn_type", "C14", "C15", "C16", "C17",
    "C18", "C19", "C20", "C21"
};

/** Base location of the demo data files. */
private static final String DATA_DIR = "https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/";
private static final String SMALL_FILE = "avazu-small.csv";
private static final String LARGE_FILE = "avazu-ctr-train-8M.csv";

/** Where the fitted feature-engineering pipeline model is saved and later loaded from. */
private static final String FEATURE_PIPELINE_MODEL_FILE = "/tmp/feature_model.csv";

/** Name of the label column. */
private static final String labelColName = "click";
/** Name of the vector column produced by feature hashing and consumed by the trainers. */
private static final String vecColName = "vec";

/** {@link #COL_NAMES} without the label, the id, two date parts and the listed id-like columns. */
static final String[] FEATURE_COL_NAMES =
    ArrayUtils.removeElements(COL_NAMES, labelColName, "id", "dt_year", "dt_month",
        "site_id", "site_domain", "app_id", "device_id", "device_ip", "device_model");

/** Columns that are appended to the hashed and categorical column lists of the feature hasher. */
static final String[] HIGH_FREQ_FEATURE_COL_NAMES = {
    "site_id", "site_domain", "device_id", "device_model"
};

/** Feature columns that are passed to the feature hasher as categorical. */
static final String[] CATEGORY_FEATURE_COL_NAMES = {
    "C1", "banner_pos",
    "site_category", "app_domain",
    "app_category", "device_type",
    "device_conn_type"
};

/** Feature columns left after removing the categorical ones; these are standard-scaled. */
static final String[] NUMERICAL_FEATURE_COL_NAMES =
    ArrayUtils.removeElements(FEATURE_COL_NAMES, CATEGORY_FEATURE_COL_NAMES);
/**
 * Fits the feature-engineering pipeline (standard scaling followed by feature hashing) on the
 * small batch data set and saves the fitted model to {@code FEATURE_PIPELINE_MODEL_FILE}.
 */
@Test
public void trainFeatureModel() throws Exception {
    MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().setParallelism(4);
    MLEnvironmentFactory.getDefault().getExecutionEnvironment().setParallelism(1);

    final int hashedFeatureCount = 30000;
    // The high-frequency columns are hashed in addition to the regular feature columns and are
    // treated as categorical.
    final String[] hashedCols = ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES);
    final String[] categoricalCols =
        ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES);

    Pipeline featurePipeline = new Pipeline()
        .add(new StandardScaler().setSelectedCols(NUMERICAL_FEATURE_COL_NAMES))
        .add(new FeatureHasher()
            .setSelectedCols(hashedCols)
            .setCategoricalCols(categoricalCols)
            .setOutputCol(vecColName)
            .setNumFeatures(hashedFeatureCount)
            .setReservedCols("click"));

    featurePipeline
        .fit(getSmallBatchSet())
        .save(FEATURE_PIPELINE_MODEL_FILE, true);
    BatchOperator.execute();
}
/**
 * Trains an FTRL model on one output of the stream split, predicts on the other output using the
 * emitted model stream, and prints selected binary-classification metrics.
 *
 * <p>Loads the feature pipeline model from {@code FEATURE_PIPELINE_MODEL_FILE}, so
 * {@code trainFeatureModel()} has to be run beforehand.
 */
@Test
public void onlineTrainAndEval() throws Exception {
    PipelineModel featurePipelineModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // f0 is the main output of the splitter (training), f1 its side output (testing).
    Tuple2<StreamOperator, StreamOperator> sources = getStreamTrainTestData();
    StreamOperator<?> trainStream = featurePipelineModel.transform(sources.f0);
    StreamOperator<?> testStream = featurePipelineModel.transform(sources.f1);
    BatchOperator<?> trainBatch = featurePipelineModel.transform(getSmallBatchSet());
    StreamOperator.setParallelism(2);

    // Batch-trained logistic regression model used to initialize FTRL and the predictor.
    BatchOperator<?> model = new LogisticRegressionTrainBatchOp()
        .setVectorCol(vecColName)
        .setLabelCol(labelColName)
        .setWithIntercept(true)
        .linkFrom(trainBatch);

    StreamOperator<?> models = new FtrlTrainStreamOp(model)
        .setVectorCol(vecColName)
        .setLabelCol(labelColName)
        .setMiniBatchSize(1024)
        .setTimeInterval(10)
        .setWithIntercept(true)
        .setModelStreamFilePath("/tmp/avazu_fm_models")
        .linkFrom(trainStream);

    // The predictor receives the test data together with the stream of models from FTRL.
    StreamOperator<?> predictResults = new LogisticRegressionPredictStreamOp(model)
        .setPredictionCol("predict")
        .setReservedCols(labelColName)
        .setPredictionDetailCol("details")
        .linkFrom(testStream, models);

    new EvalBinaryClassStreamOp()
        .setPredictionDetailCol("details")
        .setLabelCol(labelColName)
        .setTimeInterval(10)
        .linkFrom(predictResults)
        .link(new JsonValueStreamOp()
            .setSelectedCol("Data")
            .setReservedCols(new String[] {"Statistics"})
            .setOutputCols(new String[] {"Accuracy", "AUC", "ConfusionMatrix"})
            // Every path is rooted at "$." so the three expressions are written consistently.
            .setJsonPath("$.Accuracy", "$.AUC", "$.ConfusionMatrix"))
        .print();
    StreamOperator.execute();
}
/**
 * Runs FTRL online training on the training stream and links the resulting model stream to a
 * file sink at "/tmp/ftrl_models".
 *
 * <p>Loads the feature pipeline model from {@code FEATURE_PIPELINE_MODEL_FILE}, so
 * {@code trainFeatureModel()} has to be run beforehand.
 */
@Test
public void onlineTrainAndSave() throws Exception {
    PipelineModel featureModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
    AlinkGlobalConfiguration.setPrintProcessInfo(true);

    // Only the main output of the splitter (f0) is needed here.
    Tuple2<StreamOperator, StreamOperator> splitStreams = getStreamTrainTestData();
    StreamOperator<?> encodedTrainStream = featureModel.transform(splitStreams.f0);
    BatchOperator<?> encodedBatch = featureModel.transform(getSmallBatchSet());
    StreamOperator.setParallelism(2);

    // Batch-trained logistic regression model handed to FTRL as its initial model.
    BatchOperator<?> initialModel = new LogisticRegressionTrainBatchOp()
        .setVectorCol(vecColName)
        .setLabelCol(labelColName)
        .setWithIntercept(true)
        .linkFrom(encodedBatch);

    StreamOperator<?> modelStream = new FtrlTrainStreamOp(initialModel)
        .setVectorCol(vecColName)
        .setLabelCol(labelColName)
        .setMiniBatchSize(1024)
        .setTimeInterval(10)
        .setWithIntercept(true)
        .setModelStreamFilePath("/tmp/rebase_ftrl_models")
        .linkFrom(encodedTrainStream);

    ModelStreamFileSinkStreamOp modelSink = new ModelStreamFileSinkStreamOp()
        .setFilePath("/tmp/ftrl_models");
    modelStream.link(modelSink);
    StreamOperator.execute();
}
/**
 * Trains a logistic regression model on the small batch data set and appends it to the model
 * stream directory "/tmp/rebase_ftrl_models".
 *
 * <p>Loads the feature pipeline model from {@code FEATURE_PIPELINE_MODEL_FILE}, so
 * {@code trainFeatureModel()} has to be run beforehand.
 */
@Test
public void BatchTrainAndSaveRebaseModel() throws Exception {
    PipelineModel featureModel = PipelineModel.load(FEATURE_PIPELINE_MODEL_FILE);
    BatchOperator<?> encodedBatch = featureModel.transform(getSmallBatchSet());
    // NOTE(review): this is a batch-only job, yet the stream parallelism is what gets set here;
    // kept as-is, confirm whether it is intentional.
    StreamOperator.setParallelism(2);

    BatchOperator<?> rebaseModel = new LogisticRegressionTrainBatchOp()
        .setVectorCol(vecColName)
        .setLabelCol(labelColName)
        .setWithIntercept(true)
        .linkFrom(encodedBatch);

    AppendModelStreamFileSinkBatchOp rebaseSink = new AppendModelStreamFileSinkBatchOp()
        .setFilePath("/tmp/rebase_ftrl_models");
    rebaseModel.link(rebaseSink);
    BatchOperator.execute();
}
/**
 * Fits a pipeline of standard scaling, feature hashing and logistic regression on the small batch
 * data set and writes the fitted pipeline model to "/tmp/lr_pipeline.ak".
 *
 * <p>The logistic regression stage is configured with the model stream path "/tmp/ftrl_models".
 */
@Test
public void savePipelineModel() throws Exception {
    BatchOperator<?> trainingData = getSmallBatchSet();

    final int hashedFeatureCount = 30000;
    final String[] hashedCols = ArrayUtils.addAll(FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES);
    final String[] categoricalCols =
        ArrayUtils.addAll(CATEGORY_FEATURE_COL_NAMES, HIGH_FREQ_FEATURE_COL_NAMES);

    Pipeline pipeline = new Pipeline()
        .add(new StandardScaler().setSelectedCols(NUMERICAL_FEATURE_COL_NAMES))
        .add(new FeatureHasher()
            .setSelectedCols(hashedCols)
            .setCategoricalCols(categoricalCols)
            .setOutputCol(vecColName)
            .setNumFeatures(hashedFeatureCount)
            .setReservedCols("click"))
        .add(new LogisticRegression()
            .setVectorCol("vec")
            .setLabelCol("click")
            .setPredictionCol("pred")
            .setModelStreamFilePath("/tmp/ftrl_models")
            .setPredictionDetailCol("detail")
            .setMaxIter(10));
    PipelineModel pipelineModel = pipeline.fit(trainingData);

    AkSinkBatchOp modelSink = new AkSinkBatchOp()
        .setOverwriteSink(true)
        .setFilePath("/tmp/lr_pipeline.ak");
    pipelineModel.save().link(modelSink);
    BatchOperator.execute();
}
/**
 * Loads the pipeline model from "/tmp/lr_pipeline.ak" into a {@link LocalPredictor} and predicts
 * the same hard-coded sample every five seconds, printing each result.
 *
 * <p>The loop is effectively endless (it runs up to {@code Integer.MAX_VALUE} iterations), so the
 * test has to be stopped manually.
 */
@Test
public void localPredictor() throws Exception {
    // Build the schema string once and reuse it for both the predictor and the log output.
    String inputSchemaStr = TableUtil.schema2SchemaStr(getSmallBatchSet().getSchema());
    LocalPredictor predictor = new LocalPredictor("/tmp/lr_pipeline.ak", inputSchemaStr);
    System.out.println(inputSchemaStr);

    for (int round = 0; round < Integer.MAX_VALUE; ++round) {
        // One value per entry of COL_NAMES, in the same order.
        System.out.println(predictor.map(
            Row.of("220869541682524752", "0", 14, 10, 21, 2, "1005", 0, "1fbe01fe", "f3845767",
                "28905ebd", "ecad2386", "7801e8d9", "07d7df22", "a99f214a", "af1c0727", "a0f5f879", "1", "0",
                15703, 320, 50, 1722, 0, 35, -1, 79)));
        Thread.sleep(5000);
    }
}
/**
 * Table function that splits the first eight characters of a string into four consecutive
 * two-character integers and emits them as a single four-field row.
 */
public static class SplitDataTime extends TableFunction <Row> {
    /**
     * Emits one row holding the integers parsed from characters 0-1, 2-3, 4-5 and 6-7 of
     * {@code str}.
     *
     * @param str the value to split; it must have at least eight characters, and each
     *            two-character part must be a valid decimal integer
     */
    public void eval(String str) {
        // Integer.parseInt accepts leading zeros ("07" -> 7), so no manual stripping is needed.
        collect(Row.of(
            Integer.parseInt(str.substring(0, 2)),
            Integer.parseInt(str.substring(2, 4)),
            Integer.parseInt(str.substring(4, 6)),
            Integer.parseInt(str.substring(6, 8))
        ));
    }

    /** Declares the emitted row as four INT fields. */
    @Override
    public TypeInformation <Row> getResultType() {
        return new RowTypeInfo(AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT, AlinkTypes.INT);
    }
}
/**
 * Builds the streaming source over the large data file and splits it with a fraction of 0.5.
 *
 * @return a pair whose first element is the splitter's main output and whose second element is
 *         its side output 0
 */
private static Tuple2 <StreamOperator, StreamOperator> getStreamTrainTestData() {
    // Schema string of the form "name type,name type,...".
    StringBuilder sbd = new StringBuilder();
    for (int i = 0; i < ORIGIN_COL_NAMES.length; i++) {
        if (i > 0) {
            sbd.append(",");
        }
        sbd.append(ORIGIN_COL_NAMES[i]).append(" ").append(ORIGIN_COL_TYPES[i]);
    }
    StreamOperator <?> source = new CsvSourceStreamOp()
        // Use this class's own constant rather than one from an unrelated class.
        .setFilePath(DATA_DIR + LARGE_FILE)
        .setSchemaStr(sbd.toString())
        .setIgnoreFirstLine(true)
        // Expand "dt" into its four parts, then keep the columns listed in COL_NAMES.
        .udtf("dt", new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"}, new SplitDataTime())
        .select(COL_NAMES);
    SplitStreamOp splitter = new SplitStreamOp().setFraction(0.5);
    source.link(splitter);
    return new Tuple2 <>(splitter, splitter.getSideOutput(0));
}
/**
 * Builds the batch source over the small data file, with "dt" expanded into its four parts and
 * the columns restricted to {@code COL_NAMES}.
 *
 * @return the batch operator producing the prepared data
 */
private static BatchOperator <?> getSmallBatchSet() {
    // Schema string of the form "name type,name type,...".
    StringBuilder sbd = new StringBuilder();
    for (int i = 0; i < ORIGIN_COL_NAMES.length; i++) {
        if (i > 0) {
            sbd.append(",");
        }
        sbd.append(ORIGIN_COL_NAMES[i]).append(" ").append(ORIGIN_COL_TYPES[i]);
    }
    return new CsvSourceBatchOp()
        // Use this class's own constant rather than one from an unrelated class.
        .setFilePath(DATA_DIR + SMALL_FILE)
        .setSchemaStr(sbd.toString())
        .setIgnoreFirstLine(true)
        .udtf("dt", new String[] {"dt_year", "dt_month", "dt_day", "dt_hour"}, new SplitDataTime())
        .select(COL_NAMES);
}
}
Demo 功能介绍
该demo使用Ftrl 算法对Avazu 数据(https://www.kaggle.com/c/avazu-ctr-prediction/data)进行实时训练并生成模型流,并将模型流实时加载到推理服务中。另外我们还增加了模型rebase 的示例代码,能够很容易的完成用一个批模型定时重新拉回模型,防止模型跑偏。最后还提供了一个模型训练+预测+评估的示例代码。
函数说明
函数 | 任务类型 | 说明 |
---|---|---|
trainFeatureModel() | 批任务 | 训练特征工程模型,这个模型将对训练、预测、推理数据进行特征编码 |
savePipelineModel() | 批任务 | 训练PipelineModel,该模型是部署到线上服务的模型 |
onlineTrainAndSave() | 流任务 | 使用Ftrl实时训练模型,并定时将模型写出到指定目录 |
BatchTrainAndSaveRebaseModel() | 批任务 | 训练用来重新初始化的模型,用来拉回模型,防止模型跑偏 |
localPredictor() | 本地推理 | 本地搭建一个服务,对同一条样本预测,用来验证模型更新 |
onlineTrainAndEval() | 流任务 | Ftrl 训练模型,并对模型进行预测评估 |
执行步骤
- 首先我们需要执行 trainFeatureModel() 函数,生成特征工程模型,并存储到目录“/tmp/feature_model.csv”,后面的函数都需要该模型。
- 第二步,执行savePipelineModel() 函数,生成部署到线上服务的模型,该模型会通过 setModelStreamFilePath(“/tmp/ftrl_models”) 设置模型流的目录,设置完再部署这个模型时,在推理的同时会实时监控这个目录,当有新模型产生时会自动加载新模型,用最新模型进行推理。
- 第三步,执行localPredictor() 函数,将第二步产生的模型部署成本地服务。
- 第四步,执行onlineTrainAndSave(),用Ftrl算法实时训练在线模型,并以固定频率输出到目录“/tmp/ftrl_models”,这个目录与第二步的目录是同一个目录。另外这一步还要设置一个目录setModelStreamFilePath(“/tmp/rebase_ftrl_models”),这个目录是用来做模型rebase的,在线学习过程中会实时监控这个目录,当有新的模型出现在这个目录中,会重新加载这个模型作为base模型继续进行训练。
- 第五步,隔段时间(1小时 or 1天)执行函数BatchTrainAndSaveRebaseModel(),这个函数将训练一个新的批模型,并将其写入到rebase目录“/tmp/rebase_ftrl_models”,这样,第四步中的函数就会监测到这个模型,并进行模型 rebase。
onlineTrainAndEval() 函数是一个独立的函数,用来评估在线学习算法生成的模型怎么样,并打印评估结果。
备注
- 特征工程这里我们对数据做标准化和FeatureHash,其实这里可以使用任何其他Alink 的特征工程算法,例如OneHot编码、GBDT编码、多热编码、归一化、分桶算法等。
- 实时加载在线训练的模型是通过一个给定的文件系统的目录(“/tmp/ftrl_models”)完成的,这个目录可以是本地目录,也可以是网络文件系统的目录,例如OSS 和HDFS等。
- 模型重新初始化和在线模型的加载是类似的,也是通过一个给定的文件系统的目录(“/tmp/rebase_ftrl_models”)完成的,同样,这个目录可以是本地目录,也可以是网络文件系统的目录,例如OSS 和HDFS等。