算法简介
BERT 文本分类DLC是基于PAI-DLC的通用的基于BERT的分类模型,输入为文本,输出分类标签,如下所示:
尽管名字称为 BERT,但事实上支持ALBERT,RoBerta 等已定义的的ModelZoo
可视化配置参数

【输入桩配置】
| 输入桩(从左到右) | 限制数据类型 | 建议上游组件 | 是否必选 |
|---|---|---|---|
| 训练数据 | oss | 读数据表oss | 是 |
| 测试数据 | oss | 读数据表oss | 是 |
【输出桩配置】
| 输出桩 | 限制数据类型 | 建议下游组件 | 是否必选 |
|---|---|---|---|
| 结果数据 | oss | 写oss数据 | 否 |
【右侧参数表单】
数据参数:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| 输入数据(训练集) | 输入训练集的OSS路径 | OSS文件路径 | 必选 |
| 输入数据(验证集) | 输入验证集的OSS路径 | OSS文件路径 | 必选 |
| 输入Schema数据 | 输入CSV文件的Schema | string类型 | 必选,比方说有三个字段,分别是str,float,int类型,名字为col1,col2,col3,长度分别为1,10,1,则配置如下:col1:str:1,col2:float:10,col3:int:1 |
| 文本列选择 | 文本序列在输入格式中对应的列名 | string类型 | 必选,选择输入文本列,对应schema中的某一列,比方说schematic为col1:str:1,col2:float:10,col3:int:1,那么候选为col1,col2,col3中的一个 |
| 标签列选择 | 分类标签对应的列名 | string类型 | 必选,选择输入标签列,对应schema中的某一列,比方说schematic为col1:str:1,col2:float:10,col3:int:1,那么候选为col1,col2,col3中的一个 |
| 标签枚举值 | 需要枚举出所有标签 | string类型 | 必选,枚举所有的标签值,如果为二分类,则填入0,1 |
| 模型存储路径 | 模型checkpoint的存储路径 | string类型 | 必选 |
模型参数:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| 模型选择 | 预训练模型名 | string | 可选,默认为’text_classify_bert’,此外还支持非bert模型: text_classify_cnn, text_classify_dgcnn |
| 优化器类型 | 优化器选择 | string | 可选,默认为’adam’ |
| batchSize | 特征提取批大小 | int | 可选,默认为256 |
| sequenceLength | 序列整体最大长度 | int | 可选,默认为128,范围为1~512 |
| numEpochs | 训练的轮次 | int | 可选,默认为2 |
| 学习率 | 优化器的学习率 | float | 可选,默认为1e-5 |
| pretrain_model_name_or_path | 预训练模型的选择 | string | 常用的为:pai-bert-base-zh, 其他模型详见:https://yuque.antfin-inc.com/pai/transfer-learning/uugdk2 |
执行调优:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| GPU机器类型 | 选择ECS机型用于模型训练 | 机型选择 | 必选 |
支持计算资源
| 输入桩(从左到右) | 限制数据类型 | 建议上游组件 | 是否必选 |
|---|---|---|---|
| 训练数据 | odps | 读数据表odps | 是 |
| 测试数据 | odps | 读数据表odps | 是 |
【输出桩配置】
| 输出桩 | 限制数据类型 | 建议下游组件 | 是否必选 |
|---|---|---|---|
| 结果数据 | oss | 写oss数据 | 否 |
【右侧参数表单】
数据参数:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| 输入数据(训练集) | 输入训练集的OSS路径 | OSS文件路径 | 必选 |
| 输入数据(验证集) | 输入验证集的OSS路径 | OSS文件路径 | 必选 |
| 输入Schema数据 | 输入CSV文件的Schema | string类型 | 必选 |
| 文本列选择 | 文本序列在输入格式中对应的列名 | string类型 | 必选 |
| 标签列选择 | 分类标签对应的列名 | string类型 | 必选 |
| 标签枚举值 | 需要枚举出所有标签 | string类型 | 必选 |
| 模型存储路径 | 模型checkpoint的存储路径 | string类型 | 必选 |
模型参数:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| 模型选择 | 预训练模型名 | string | 可选,默认为’text_classify_bert’,此外还支持非bert模型: text_classify_cnn, text_classify_dgcnn |
| 优化器类型 | 优化器选择 | string | 可选,默认为’adam’ |
| batchSize | 特征提取批大小 | int | 可选,默认为256 |
| sequenceLength | 序列整体最大长度 | int | 可选,默认为128,范围为1~512 |
| numEpochs | 训练的轮次 | int | 可选,默认为2 |
| 学习率 | 优化器的学习率 | float | 可选,默认为1e-5 |
| 模型额外参数 | 额外的参数,比方说修改预训练模型等 | string | 可选,可以修改预训练模型,比方说:pretrain_model_name_or_path=pai-bert-base-zh, 其他模型详见:https://yuque.antfin-inc.com/pai/transfer-learning/uugdk2 |
执行调优:
| 参数名称 | 参数描述 | 取值类型 | 必选,默认值 |
|---|---|---|---|
| GPU机器类型 | 选择ECS机型用于模型训练 | 机型选择 | 必选 |
支持计算资源
具体示例
首先可以下载 训练集 和 评估集,其中 train.csv , dev.csv 是用\t 分隔的 .csv 文件:
53360 美少女甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物 102 news_entertainment 自拍,美少女,经纪人,甜甜圈53361 重庆美食打卡,带你领略舌尖上的重庆 102 news_food 重庆,美食,美味
我们定义这五个字段为 example_id,content,label,label_str,keywords
我们将相应的数据上传到 oss 上:
注意:本教程所用数据来自 TNEWS’ 今日头条中文新闻(短文本)分类,为了演示教程,训练集取了1000个样本,评估集取了100个样本。这里共有四个字段:
- example_id: 样本id信息
- content: 文本信息,对应组件里的“文并列选择”
- label: label信息,对应组件里的“标签列选择”
- label_str: 额外信息
- keywords: 额外信息
参考以上可视化配置参数。创建工作流,新建两个输入组件(读数据表组件),对应训练数据和测试数据。将两个输入组件和模型组件连接,运行即可获得结果。工作流示例如下:
- 输入Schema数据:example_id:int:1,content:str:1,label:str:1,label_str:str:1,keywords:str:1
- 文本列选择:content
- 标签列选择:label
- 标签枚举值:可选,如 100,101,102,103,104,105,106,107,108,109,110,112,113,114,115,116



