Huggingface Transformer对接简介
本文将介绍在EasyNLP框架的基础上,将 HuggingFace transformer包的模型运行过程转换为EasyNLP框架的代码进行快速对接运行。下文以文本分类(text-classification)为例进行阐述。
Example代码地址
https://github.com/alibaba/EasyNLP/tree/master/examples/hf_adapter_easynlp
训练阶段
数据处理
数据处理过程需要按照EasyNLP框架ClassificationDataset接口进行定义
train_dataset = ClassificationDataset(pretrained_model_name_or_path=args.pretrained_model_name_or_path,data_file=args.tables.split(",")[0],max_seq_length=args.sequence_length,input_schema=args.input_schema,first_sequence=args.first_sequence,second_sequence=args.second_sequence,label_name=args.label_name,label_enumerate_values=args.label_enumerate_values,user_defined_parameters=user_defined_parameters,is_training=True)
- pretrained_model_name_or_path:这里需要填写目前EasyNLP框架支持的模型key值,可以从本地路径下载的modelzoo_alibaba.json文件中选取合适的模型
- 其他参数参考EasyNLP文档解释即可
模型处理
模型部分,EasyNLP框架兼容Hugginface Transformers库,可以直接迁移过来,但需要将计算loss和predict状态公共提取特征部分动态绑定到模型类中,可以按照自己的需求更改loss的定义方式
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained('/root/.easynlp/modelzoo/bert-base-uncased/')
model.compute_loss = MethodType(compute_loss, model)
model.forward_repre = MethodType(forward_repre, model)
预测阶段
直接将训练过程中得到的checkpoint加载模型进行测试,其他参数说明见EasyNLP文档即可
predictor = Predictor(
model_dir=args.checkpoint_dir,
user_defined_parameters=args.user_defined_parameters,
first_sequence=args.first_sequence,
second_sequence=args.second_sequence,
sequence_length=args.sequence_length,
input_file=args.tables.split(",")[-1],
input_schema=args.input_schema,
output_file=args.outputs,
output_schema=args.output_schema,
append_cols=args.append_cols,
batch_size=args.micro_batch_size,
args=args
)
