算法简介

BERT 文本分类DLC是基于PAI-DLC的通用的基于BERT的分类模型,输入为文本,输出分类标签,如下所示:
image.png
尽管名字称为 BERT,但事实上支持ALBERT,RoBerta 等已定义的的ModelZoo

可视化配置参数

image.png

image.pngimage.png

【输入桩配置】

输入桩(从左到右) 限制数据类型 建议上游组件 是否必选
训练数据 oss 读数据表oss
测试数据 oss 读数据表oss

【输出桩配置】

输出桩 限制数据类型 建议下游组件 是否必选
结果数据 oss 写oss数据

【右侧参数表单】

数据参数:

参数名称 参数描述 取值类型 必选,默认值
输入数据(训练集) 输入训练集的OSS路径 OSS文件路径 必选
输入数据(验证集) 输入验证集的OSS路径 OSS文件路径 必选
输入Schema数据 输入CSV文件的Schema string类型 必选,比方说有三个字段,分别是str,float,int类型,名字为col1,col2,col3,长度分别为1,10,1,则配置如下:col1:str:1,col2:float:10,col3:int:1
文本列选择 文本序列在输入格式中对应的列名 string类型 必选,选择输入文本列,对应schema中的某一列,比方说schematic为col1:str:1,col2:float:10,col3:int:1,那么候选为col1,col2,col3中的一个
标签列选择 分类标签对应的列名 string类型 必选,选择输入标签列,对应schema中的某一列,比方说schematic为col1:str:1,col2:float:10,col3:int:1,那么候选为col1,col2,col3中的一个
标签枚举值 需要枚举出所有标签 string类型 必选,枚举所有的标签值,如果为二分类,则填入0,1
模型存储路径 模型checkpoint的存储路径 string类型 必选

模型参数:

参数名称 参数描述 取值类型 必选,默认值
模型选择 预训练模型名 string 可选,默认为’text_classify_bert’,此外还支持非bert模型: text_classify_cnn, text_classify_dgcnn
优化器类型 优化器选择 string 可选,默认为’adam’
batchSize 特征提取批大小 int 可选,默认为256
sequenceLength 序列整体最大长度 int 可选,默认为128,范围为1~512
numEpochs 训练的轮次 int 可选,默认为2
学习率 优化器的学习率 float 可选,默认为1e-5
pretrain_model_name_or_path 预训练模型的选择 string 常用的为:pai-bert-base-zh,
其他模型详见:https://yuque.antfin-inc.com/pai/transfer-learning/uugdk2

执行调优:

参数名称 参数描述 取值类型 必选,默认值
GPU机器类型 选择ECS机型用于模型训练 机型选择 必选

支持计算资源


【DLC】

输入桩(从左到右) 限制数据类型 建议上游组件 是否必选
训练数据 odps 读数据表odps
测试数据 odps 读数据表odps

【输出桩配置】

输出桩 限制数据类型 建议下游组件 是否必选
结果数据 oss 写oss数据

【右侧参数表单】

数据参数:

参数名称 参数描述 取值类型 必选,默认值
输入数据(训练集) 输入训练集的OSS路径 OSS文件路径 必选
输入数据(验证集) 输入验证集的OSS路径 OSS文件路径 必选
输入Schema数据 输入CSV文件的Schema string类型 必选
文本列选择 文本序列在输入格式中对应的列名 string类型 必选
标签列选择 分类标签对应的列名 string类型 必选
标签枚举值 需要枚举出所有标签 string类型 必选
模型存储路径 模型checkpoint的存储路径 string类型 必选

模型参数:

参数名称 参数描述 取值类型 必选,默认值
模型选择 预训练模型名 string 可选,默认为’text_classify_bert’,此外还支持非bert模型: text_classify_cnn, text_classify_dgcnn
优化器类型 优化器选择 string 可选,默认为’adam’
batchSize 特征提取批大小 int 可选,默认为256
sequenceLength 序列整体最大长度 int 可选,默认为128,范围为1~512
numEpochs 训练的轮次 int 可选,默认为2
学习率 优化器的学习率 float 可选,默认为1e-5
模型额外参数 额外的参数,比方说修改预训练模型等 string 可选,可以修改预训练模型,比方说:pretrain_model_name_or_path=pai-bert-base-zh,
其他模型详见:https://yuque.antfin-inc.com/pai/transfer-learning/uugdk2

执行调优:

参数名称 参数描述 取值类型 必选,默认值
GPU机器类型 选择ECS机型用于模型训练 机型选择 必选

支持计算资源


【DLC】

具体示例

首先可以下载 训练集评估集,其中 train.csv , dev.csv 是用\t 分隔的 .csv 文件:

  1. 53360 美少女甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物 102 news_entertainment 自拍,美少女,经纪人,甜甜圈
  2. 53361 重庆美食打卡,带你领略舌尖上的重庆 102 news_food 重庆,美食,美味

我们定义这五个字段为 example_id,content,label,label_str,keywords
我们将相应的数据上传到 oss 上:
image.png

注意:本教程所用数据来自 TNEWS’ 今日头条中文新闻(短文本)分类,为了演示教程,训练集取了1000个样本,评估集取了100个样本。这里共有四个字段:

  • example_id: 样本id信息
  • content: 文本信息,对应组件里的“文并列选择”
  • label: label信息,对应组件里的“标签列选择”
  • label_str: 额外信息
  • keywords: 额外信息

参考以上可视化配置参数。创建工作流,新建两个输入组件(读数据表组件),对应训练数据和测试数据。将两个输入组件和模型组件连接,运行即可获得结果。工作流示例如下:
截屏2022-03-25 下午2.58.22.png

  • 输入Schema数据:example_id:int:1,content:str:1,label:str:1,label_str:str:1,keywords:str:1
  • 文本列选择:content
  • 标签列选择:label
  • 标签枚举值:可选,如 100,101,102,103,104,105,106,107,108,109,110,112,113,114,115,116

image.pngimage.png