Introduction to Integrating Hugging Face Transformers

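This document describes how to take a model workflow written with the HuggingFace transformers package and convert it into EasyNLP framework code, so that the model can be connected to EasyNLP and run quickly. Text classification (text-classification) is used as the running example.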

Example code location

https://github.com/alibaba/EasyNLP/tree/master/examples/hf_adapter_easynlp

Training stage

Data processing

The data processing step must be defined through the ClassificationDataset interface of the EasyNLP framework:

train_dataset = ClassificationDataset(
    pretrained_model_name_or_path=args.pretrained_model_name_or_path,
    data_file=args.tables.split(",")[0],
    max_seq_length=args.sequence_length,
    input_schema=args.input_schema,
    first_sequence=args.first_sequence,
    second_sequence=args.second_sequence,
    label_name=args.label_name,
    label_enumerate_values=args.label_enumerate_values,
    user_defined_parameters=user_defined_parameters,
    is_training=True)
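  1. pretrained_model_name_or_path: set this to a model key that the EasyNLP framework currently supports. A suitable model can be chosen from the modelzoo_alibaba.json file downloaded to the local path.
  2. For the other parameters, refer to the explanations in the EasyNLP documentation.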
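
The data_file argument above takes the first entry of args.tables, while the prediction code later in this document takes the last entry. The following is a minimal sketch of that convention, assuming args.tables is a comma-separated string of file paths; the helper name split_tables and the file names are hypothetical, and the whitespace stripping is an addition that the original code does not perform.

```python
def split_tables(tables):
    """Split a comma-separated list of data files into (first, last).

    Hypothetical helper: mirrors args.tables.split(",")[0] and
    args.tables.split(",")[-1], with surrounding whitespace stripped.
    """
    paths = [p.strip() for p in tables.split(",")]
    return paths[0], paths[-1]


# Two files: the first is used for training, the last for prediction.
train_file, predict_file = split_tables("train.tsv,dev.tsv")

# A single file: both indices resolve to the same path.
only_first, only_last = split_tables("dev.tsv")
```

With a single path, index 0 and index -1 point to the same file, so the same string can be passed at prediction time without a training file in front of it.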

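Model processing

For the model, the EasyNLP framework is compatible with the Hugging Face Transformers library, so a model can be migrated directly. However, the loss computation and the shared feature-extraction logic used in the predict state must be dynamically bound to the model class. The loss definition can be changed to suit your own needs.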

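from types import MethodType
from transformers import BertForSequenceClassification

# Load the BERT checkpoint from the local EasyNLP model zoo directory.
model = BertForSequenceClassification.from_pretrained('/root/.easynlp/modelzoo/bert-base-uncased/')
# Bind the user-defined compute_loss and forward_repre functions to this model instance.
model.compute_loss = MethodType(compute_loss, model)
model.forward_repre = MethodType(forward_repre, model)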

For the full binding code, see: https://github.com/alibaba/EasyNLP/blob/35b7fbf73a9b9871dd17dd66ab8416c9ae81afd1/examples/hf_adapter_easynlp/hf_ez_nlp_user_defined.py#L5
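
To illustrate how the dynamic binding works, here is a minimal, self-contained sketch that uses only the standard library. ToyModel stands in for a Hugging Face model, and the bodies and signatures of compute_loss and forward_repre here are hypothetical placeholders chosen for illustration; they are not the signatures EasyNLP expects, which are defined in the linked example script.

```python
from types import MethodType


class ToyModel:
    """Hypothetical stand-in for a Hugging Face model."""

    def __init__(self, scale):
        self.scale = scale

    def forward(self, inputs):
        return [self.scale * value for value in inputs]


def compute_loss(self, forward_outputs, label_ids):
    """Hypothetical loss: mean squared error between outputs and labels."""
    squared = [(out - label) ** 2 for out, label in zip(forward_outputs, label_ids)]
    return {"loss": sum(squared) / len(squared)}


def forward_repre(self, inputs):
    """Hypothetical feature extractor: reuses the model's forward pass."""
    return self.forward(inputs)


model = ToyModel(scale=2.0)
# MethodType binds a plain function to one instance, so `self` is `model`.
model.compute_loss = MethodType(compute_loss, model)
model.forward_repre = MethodType(forward_repre, model)

outputs = model.forward_repre([1.0, 2.0])
result = model.compute_loss(outputs, [2.0, 3.0])
```

Because the functions are attached to the instance and not to the class, other instances of the same class are unaffected, and the loss can be swapped by binding a different function under the same attribute name.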

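Prediction stage

Load the checkpoint produced during training directly into the model for testing. For the other parameters, see the EasyNLP documentation.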

predictor = Predictor(
    model_dir=args.checkpoint_dir,
    user_defined_parameters=args.user_defined_parameters,
    first_sequence=args.first_sequence,
    second_sequence=args.second_sequence,
    sequence_length=args.sequence_length,
    input_file=args.tables.split(",")[-1],
    input_schema=args.input_schema,
    output_file=args.outputs,
    output_schema=args.output_schema,
    append_cols=args.append_cols,
    batch_size=args.micro_batch_size,
    args=args
)