编写用户端程序

导入依赖项

  1. import torch
  2. from datasets import load_dataset, load_metric
  3. from easynlp.modelzoo import AutoTokenizer
  4. from easynlp.appzoo.api import get_application_model
  5. from easynlp.utils.global_vars import parse_user_defined_parameters
  6. from easynlp.utils import losses
  7. from rapidformer import RapidformerEngine, get_args, Finetuner

创建EasyNLP微调加速器

class EasyNLPFintuner(Finetuner):
    """Finetuner that wires EasyNLP models and HF datasets into the Rapidformer engine.

    Provides the three hooks the Rapidformer ``Finetuner`` base class expects:
    dataset construction, model construction, and the per-step forward pass,
    plus a metric computation pass over the eval dataloader.
    """

    def __init__(self,
                 engine,
                 ):
        super().__init__(engine=engine)

    def train_valid_test_datasets_provider(self):
        """Build tokenized train/validation/test datasets and a collate_fn.

        Reads ``data_dir``/``data_name`` from the parsed CLI args (e.g. the
        GLUE/MRPC pair used in the launch command).
        """
        args = get_args()
        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

        def tokenize_function(examples):
            # max_length=None => use the model max length (it's actually the default)
            outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
            return outputs

        datasets = load_dataset(args.data_dir, args.data_name)

        # Apply the method we just defined to all the examples in all the splits of the dataset
        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )
        # FIX: the in-place `rename_column_` was deprecated and then removed in
        # `datasets`>=2.x; the functional `rename_column` (returns a new dataset)
        # exists in both old and new versions, so use it instead.
        tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

        train_dataset = tokenized_datasets["train"]
        valid_dataset = tokenized_datasets['validation']
        test_dataset = tokenized_datasets['test']

        def collate_fn(examples):
            # Dynamic padding: pad each batch only to its longest sequence.
            return tokenizer.pad(examples, padding="longest", return_tensors="pt")

        return train_dataset, valid_dataset, test_dataset, collate_fn

    def model_optimizer_lr_scheduler_provider(self):
        """Build the EasyNLP application model.

        Returns ``(model, None, None)`` — optimizer and LR scheduler are left
        to the engine's defaults.
        """
        args = get_args()
        user_defined_parameters = parse_user_defined_parameters(args.user_defined_parameters)
        model = get_application_model(app_name=args.app_name,
                                      pretrained_model_name_or_path=args.pretrained_model_name_or_path,
                                      user_defined_parameters=user_defined_parameters)

        return model, None, None

    def run_forward_step(self, batch, model):
        """Run one forward pass and return the cross-entropy loss for the batch."""
        label_ids = batch['labels']
        # The model's forward takes only input features; labels are held aside
        # for the loss computation below.
        del batch['labels']
        forward_outputs = model(batch)
        logits = forward_outputs['logits']
        loss = losses.cross_entropy(logits, label_ids)
        return loss

    # after each epoch run metric on eval dataset
    def run_compute_metrics(self, model, eval_dataloader):
        """Accumulate the dataset metric over the whole eval dataloader and return it."""
        args = get_args()
        # `model` arrives as a sequence (presumably one entry per model chunk);
        # only the first element is evaluated here — TODO confirm with engine docs.
        model = model[0]
        metric = load_metric(args.data_dir, args.data_name)
        for step, batch in enumerate(eval_dataloader):
            label_ids = batch['labels']
            del batch['labels']
            with torch.no_grad():
                forward_outputs = model(batch)
            predictions = forward_outputs['predictions']
            # Gather across distributed ranks before feeding the metric.
            metric.add_batch(
                predictions=self.gather(predictions),
                references=self.gather(label_ids),
            )

        eval_metric = metric.compute()
        return eval_metric

定义Main函数

if __name__ == "__main__":
    engine = RapidformerEngine()
    trainer = EasyNLPFintuner(engine=engine)
    trainer.train()

启动训练

# Launch distributed finetuning via torch.distributed.launch.
#   --mixed-precision              enable mixed-precision training
#   --onnx-runtime-training       enable ONNX Runtime computation-graph optimization
#   --zero-2-memory-optimization  enable ZeRO stage-2 GPU-memory optimization
#   --num-workers                 parallel data-loading workers
# NOTE: inline "#" comments (and trailing spaces) after a line-continuation
# backslash break the command, so all comments live above it.
python -m torch.distributed.launch $DISTRIBUTED_ARGS finetune_easynlp_bert.py \
       --task sequence_classification \
       --app-name=text_classify \
       --user-defined-parameters='pretrain_model_name_or_path=bert-base-uncased' \
       --data-dir glue \
       --data-name mrpc \
       --micro-batch-size 16 \
       --global-batch-size 16 \
       --epochs 3 \
       --num-layers 12 \
       --hidden-size 768 \
       --num-attention-heads 12 \
       --max-position-embeddings 512 \
       --seq-length 512 \
       --lr 2e-5 \
       --lr-decay-style linear \
       --lr-warmup-iters 100 \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --seed 42 \
       --log-interval 100 \
       --eval-interval 1000 \
       --mixed-precision \
       --onnx-runtime-training \
       --zero-2-memory-optimization \
       --num-workers 2