5.1 得分判别模型的整体训练流程


学习目标

  • 掌握得分判别模型的整体训练流程。

接下来,我们将训练一个BERT-Multilingual模型来判断学生的最终得分,这是一个文本多分类任务。我们将使用Huggingface提供的BERT-Multilingual脚本来进行微调训练。


得分判别模型的整体训练流程

  • 第一步: 训练集与验证集的划分
  • 第二步: 根据现有任务选择标准任务类型
  • 第三步: 使用微调脚本进行训练和验证
  • 第四步: 重要的超参数调优
  • 第五步: 使用标签平滑技术
  • 第六步: Badcase分析以便对数据进行调整和重新训练

第一步: 训练集与验证集的划分

我们之前已经定义了三个标签’Y’,’N’,’M’,在这里我们将直接划分数据集。

  1. # 使用train_test_split工具进行划分
  2. from sklearn.model_selection import train_test_split
  3. # 分别读取不同标签的数据
  4. zero_path = "./yxb_data/zero_target.tsv"
  5. df1 = pd.read_csv(zero_path, sep="\t")
  6. perfect_path = "./yxb_data/perfect_target.tsv"
  7. df2 = pd.read_csv(perfect_path, sep="\t")
  8. middle_path = "./yxb_data/middle_target.tsv"
  9. df3 = pd.read_csv(middle_path, sep="\t")
  10. # 将他们进行连接
  11. df = pd.concat([df1, df2, df3], axis=0, join="outer")
  12. # 使用20%作为验证集
  13. train, valid = train_test_split(df, shuffle=True, test_size=0.2)
  14. # 将训练集和验证集写入文件
  15. train.to_csv("./yxb_data/train.tsv", index=False, sep="\t")
  16. valid.to_csv("./yxb_data/dev.tsv", index=False, sep="\t")
  • 代码位置:
    • /home/YXB/data_processor.py

  • 输出效果:
    • 在./yxb_data/路径下出现train.tsv和dev.tsv。

第二步: 根据现有任务选择标准任务类型

  • 因为我们要使用微调脚本进行预训练模型的微调,而这些微调脚本都是在标准数据集GLUE下进行编写的,因此,我们需要从标准数据集中找到与我们任务类型相同的标准任务,即文本三分类任务,并使用ACC作为评估指标。前往GLUE标准数据集, 我们将使用MNLI任务,它正是文本多分类且评估指标为ACC。

    第三步: 使用微调脚本进行训练和验证

  • 接下来我们将使用配置微调脚本run_glue.sh中的参数: ```shell

    微调运行脚本:

    任务名代表任务类型:必须选择已有标准任务中的任务,这里我们选择MNLI-MM,这和我们刚刚选择MNLI类似,

    MNLI-MM早期也是GLUE之前的标准任务之一,后来融合到MNLI,但是我们使用的这个版本的transformers仍然支持MNLI-MM。

    定义DATA_DIR: 微调数据所在路径, 这里我们使用yxb_data中的数据作为微调数据

    export DATA_DIR=”./yxb_data”

    定义SAVE_DIR: 模型的保存路径, 我们将模型保存在当前目录的bert_finetuning_test文件中

    export SAVE_DIR=”./bert_multi_finetuning_test/“

使用python运行微调脚本

run_glue.py : 已为大家准备好

—model_type: 选择需要微调的模型类型, 这里可以选择BERT, XLNET, XLM, roBERTa, distilBERT, ALBERT

—model_name_or_path: 选择具体的模型或者变体, 这里是在英文语料上微调, 因此选择bert-base-uncased

—task_name: 它将代表对应的任务类型, 如MRPC代表句子对二分类任务

—do_train: 使用微调脚本进行训练

—do_eval: 使用微调脚本进行验证

—data_dir: 训练集及其验证集所在路径, 将自动寻找该路径下的train.tsv和dev.tsv作为训练集和验证集

—max_seq_length: 输入句子的最大长度, 超过则截断, 不足则补齐

—learning_rate: 学习率

—num_train_epochs: 训练轮数

—save_steps: 检测点保存步骤间隔

—logging_steps: 日志打印步骤间隔

—output_dir $SAVE_DIR: 训练后的模型保存路径

python run_glue.py \ —model_type BERT \ —model_name_or_path bert-base-multilingual-cased \ —task_name MNLI-MM \ —do_train \ —do_eval \ —data_dir $DATA_DIR/ \ —max_seq_length 100 \ —learning_rate 2e-5 \ —num_train_epochs 10 \ —save_steps 2000 \ —logging_steps 2000 \ —overwrite_output_dir \ —output_dir $SAVE_DIR ```


  • 运行以上微调脚本:

sh run_glue.sh


  • 直接运行将可能出现以下的问题:

Traceback (most recent call last):
File “run_glue.py”, line 536, in
main()
File “run_glue.py”, line 486, in main
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
File “run_glue.py”, line 301, in load_and_cache_examples
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
File “/root/miniconda3/lib/python3.6/site-packages/transformers/data/processors/glue.py”, line 207, in get_train_examples
self._read_tsv(os.path.join(data_dir, “train.tsv”)), “train”)
File “/root/miniconda3/lib/python3.6/site-packages/transformers/data/processors/glue.py”, line 226, in _create_examples
text_a = line[8]
IndexError: list index out of range

  • 列表索引越界问题是因为我们现在给模型输入的数据列和标准MNLI的数据列不同导致的。因此我们需要根据对MNLI训练数据和我们已有数据的列进行对比,以便修改源码中一部分数据读取和标签设定信息。
    • 修改1:对于读取的输入数据,我们知道原标准任务中共12列,而我们的数据中只有4列,修改文件”/root/miniconda3/lib/python3.6/site-packages/transformers/data/processors/glue.py”中的第226行。

219 def _create_examples(self, lines, set_type):
220 “””Creates examples for the training and dev sets.”””
221 examples = []
222 for (i, line) in enumerate(lines):
223 if i == 0:
224 continue
225 guid = “%s-%s” % (set_type, line[0])
226 text_a = line[8]
227 text_b = line[9]
228 label = line[-1]
229 examples.append(
230 InputExample(guid=guid, text_a=text_a, text_b=t ext_b, label=label))
231 return examples

———- 修改为 ————

把8,9列改成1,2列

219 def _create_examples(self, lines, set_type):
220 “””Creates examples for the training and dev sets.”””
221 examples = []
222 for (i, line) in enumerate(lines):
223 if i == 0:
224 continue
225 guid = “%s-%s” % (set_type, line[0])
226 ####
227 # text_a = line[8]
228 # text_b = line[9]
229 ####
230 text_a = line[1]
231 text_b = line[2]
232 label = line[-1]
233 examples.append(
234 InputExample(guid=guid, text_a=text_a, text_b=t ext_b, label=label))
235 return examples


  • 修改2: 原标准数据的标签也需要更改,在218行,由原来的[“contradiction”, “entailment”, “neutral”]改成[“Y”, “N”, “M”]。

215 def get_labels(self):
216 “””See base class.”””
217 return [“contradiction”, “entailment”, “neutral”]

———- 修改为 ————

215 def get_labels(self):
216 “””See base class.”””
217 # return [“contradiction”, “entailment”, “neutral”]
218 return [“Y”, “N”, “M”]


  • 修改3: 修改验证集的文件名,在212行,由原来的”dev_matched.tsv”改成”dev.tsv”。

209 def get_dev_examples(self, data_dir):
210 “””See base class.”””
211 return self._create_examples(
212 self._read_tsv(os.path.join(data_dir, “dev_matched.tsv”)),
213 “dev_matched”)

———- 修改为 ————

209 def get_dev_examples(self, data_dir):
210 “””See base class.”””
211 # return self._create_examples(
212 # self._read_tsv(os.path.join(data_dir, “dev_matched.tsv”)),
213 # “dev_matched”)
214 return self._create_examples(
215 self._read_tsv(os.path.join(data_dir, “dev.tsv”)), “dev”)


  • 修改4: 修改验证集文件名,在247行,由原来”dev_mismatched.tsv”修改为”dev.tsv”。

244 def get_dev_examples(self, data_dir):
245 “””See base class.”””
246 return self._create_examples(
247 self._read_tsv(os.path.join(data_dir, “dev_mismatched.tsv”)),
248 “dev_matched”)

———- 修改为 ————

244 def get_dev_examples(self, data_dir):
245 “””See base class.”””
246 # return self._create_examples(
247 # self._read_tsv(os.path.join(data_dir, “dev_mismatched.tsv”)),
248 # “dev_matched”)
249 return self._create_examples(
250 self._read_tsv(os.path.join(data_dir, “dev.tsv”)), “dev”)


  • 修改5: 因为我们存在多项填空,因此之前数据处理的时候添加了特殊标记[PAD],我们需要将该标记告诉模型,以便在数值映射的时映射成一个数字。找到run_glue.py中的464行,在其下面添加一行代码。

464 tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
465 do_lower_case=args.do_lower_case,
466 cache_dir=args.cache_dir if args.cache_dir else None)
467
468 ############### 这里是添加的代码 ###################
469 tokenizer.add_special_tokens({“pad_token”: “[PAD]”})
470 #####################################################
471 model = model_class.from_pretrained(args.model_name_or_path,
472 from_tf=bool(‘.ckpt’ in args.model_name_or_path),
473 config=config,
474 cache_dir=args.cache_dir if args.cache_dir else None)


  • 再次运行微调脚本:

sh run_glue.sh


  • 输出效果:

# 我们在GTX1080Ti和Tesla T4上分别做了GPU实验
# batchsize为16,maxseq_length为100, 其他参数默认不变的情况下
# 训练集大小为92172,GTX1080Ti每个epoch耗时约45min
# Tesla T4每个epoch耗时约36min,之后的所有实验在T4上进行
07/13/2020 22:52:06 - WARNING - __main
- Process rank: -1, device: cuda, n_gpu: 1, distributed training: False, 16-bits training: False
{‘cola’: , ‘mnli’: , ‘mnli-mm’: , ‘mrpc’: , ‘sst-2’: , ‘sts-b’: , ‘qqp’: , ‘qnli’: , ‘rte’: , ‘wnli’: }
07/13/2020 22:52:11 - INFO - transformers.configuration_utils - loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json from cache at /root/.cache/torch/transformers/45629519f3117b89d89fd9c740073d8e4c1f0a70f9842476185100a8afe715d1.65df3cef028a0c91a7b059e4c404a975ebe6843c71267b67019c0e9cfa8a88f0

07/13/2020 22:53:23 - INFO - transformers.tokenizationutils - loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt from cache at /root/.cache/torch/transformers/96435fa287fbf7e469185f1062386e05a075cadbf6838b74da22bf64b080bc32.99bcd55fc66f4f3360bc49ba472b940b8dcf223ea6a345deb969d607ca900729
07/13/2020 22:53:24 - INFO - transformers.tokenizationutils - Assigning [PAD] to the padtoken key of the tokenizer
07/13/2020 22:53:29 - INFO - transformers.modelingutils - loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin from cache at /root/.cache/torch/transformers/5b5b80054cd2c95a946a8e0ce0b93f56326dff9fbda6a6c3e02de3c91c918342.7131dcb754361639a7d5526985f880879c9bfd144b65a0bf50590bddb7de9059
07/13/2020 22:53:33 - INFO - transformers.modelingutils - Weights of BertForSequenceClassification not initialized from pretrained model: [‘classifier.weight’, ‘classifier.bias’]
07/13/2020 22:53:33 - INFO - transformers.modelingutils - Weights from pretrained model not used in BertForSequenceClassification: [‘cls.predictions.bias’, ‘cls.predictions.transform.dense.weight’, ‘cls.predictions.transform.dense.bias’, ‘cls.predictions.decoder.weight’, ‘cls.seqrelationship.weight’, ‘cls.seqrelationship.bias’, ‘cls.predictions.transform.LayerNorm.weight’, ‘cls.predictions.transform.LayerNorm.bias’]
07/13/2020 22:53:37 - INFO - main - Training/evaluation parameters Namespace(adamepsilon=1e-08, cachedir=’’, configname=’’, datadir=’./yxbdata/‘, device=device(type=’cuda’), doeval=True, dolowercase=False, dotrain=True, evalallcheckpoints=False, evaluateduring_training=False, fp16=False, fp16_opt_level=’O1’, gradient_accumulation_steps=1, learning_rate=2e-05, local_rank=-1, logging_steps=2000, max_grad_norm=1.0, max_seq_length=100, max_steps=-1, model_name_or_path=’bert-base-multilingual-cased’, model_type=’bert’, n_gpu=2, no_cuda=False, num_train_epochs=10.0, output_dir=’./bert_multi_finetuning_test5/‘, output_mode=’classification’, overwrite_cache=False, overwrite_output_dir=True, per_gpu_eval_batch_size=8, per_gpu_train_batch_size=8, save_steps=2000, seed=42, server_ip=’’, server_port=’’, task_name=’mnli-mm’, tokenizer_name=’’, warmup_steps=0, weight_decay=0.0)
07/13/2020 22:53:37 - INFO - __main
- Creating features from dataset file at ./yxb_data/
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - Writing example 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - Example
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - guid: train-4668
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - input_ids: 101 2650 2172 110 102 112 2650 2172 110 112 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - attention_mask: 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - token_type_ids: 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - label: M (id = 2)
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - Example
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - guid: train-18038
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - input_ids: 101 2775 6406 7735 5817 102 10662 12785 63158 24906 10162 30997 10060 2775 6406 7735 5817 10061 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - token_type_ids: 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - label: Y (id = 0)
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - Example
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - guid: train-36408
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - input_ids: 101 4015 7069 7349 2748 1881 7698 2999 2355 102 2726 4305 2542 7082 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - token_type_ids: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - label: N (id = 1)
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - Example
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - guid: train-44178
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - input_ids: 101 24137 102 2079 4561 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - attention_mask: 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - token_type_ids: 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - label: N (id = 1)
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - Example
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - guid: train-25649
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - input_ids: 101 42141 30434 7082 2774 92734 24203 3354 4823 2774 4305 4140 6141 3069 102 42141 30434 7082 2774 15453 4350 3354 4823 2774 4333 3354 4823 2774 4333 4305 4140 6141 3069 132 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/13/2020 22:53:37 - INFO - transformers.data.processors.glue - label: N (id = 1)
07/13/2020 22:53:39 - INFO - transformers.data.processors.glue - Writing example 10000
07/13/2020 22:53:41 - INFO - transformers.data.processors.glue - Writing example 20000
07/13/2020 22:53:43 - INFO - transformers.data.processors.glue - Writing example 30000
07/13/2020 22:53:44 - INFO - transformers.data.processors.glue - Writing example 40000
07/13/2020 22:53:46 - INFO - transformers.data.processors.glue - Writing example 50000
07/13/2020 22:53:48 - INFO - transformers.data.processors.glue - Writing example 60000
07/13/2020 22:53:50 - INFO - transformers.data.processors.glue - Writing example 70000
07/13/2020 22:53:52 - INFO - transformers.data.processors.glue - Writing example 80000
07/13/2020 22:53:54 - INFO - transformers.data.processors.glue - Writing example 90000
07/13/2020 22:53:54 - INFO - __main
- Saving features into cached file ./yxb_data/cached_train_bert-base-multilingual-cased_100_mnli-mm
07/13/2020 22:54:09 - INFO - __main
- * Running training *
07/13/2020 22:54:09 - INFO - __main
- Num examples = 92172
07/13/2020 22:54:09 - INFO - __main
- Num Epochs = 4
07/13/2020 22:54:09 - INFO - __main
- Instantaneous batch size per GPU = 8
07/13/2020 22:54:09 - INFO - __main
- Total train batch size (w. parallel, distributed & accumulation) = 16
07/13/2020 22:54:09 - INFO - __main
- Gradient Accumulation steps = 1
07/13/2020 22:54:09 - INFO - __main
- Total optimization steps = 57610
Epoch: 0%| | 0/10 [00:00<?, ?it/s/root/miniconda3/lib/python3.6/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
warnings.warn(‘Was asked to gather along dimension 0, but all ‘
/root/miniconda3/lib/python3.6/site-packages/torch/optim/lr_scheduler.py:224: UserWarning: To get the last learning rate computed by the scheduler, please use get_last_lr().5761 [15:49<29:39, 2.11it/s]
warnings.warn(“To get the last learning rate computed by the scheduler, “
07/13/2020 23:09:59 - INFO - transformers.configuration_utils - Configuration saved in ./bert_multi_finetuning_test5/checkpoint-2000/config.json
07/13/2020 23:10:00 - INFO - transformers.modeling_utils - Model weights saved in ./bert_multi_finetuning_test5/checkpoint-2000/pytorch_model.bin
07/13/2020 23:10:00 - INFO - __main
- Saving model checkpoint to ./bert_multi_finetuning_test5/checkpoint-2000

Iteration: 100%|##########| 5761/5761 [46:17<00:00, 2.07it/s]
Epoch: 100%|##########| 1/1 [46:17<00:00, 2777.53s/it].09it/s]
07/14/2020 09:49:28 - INFO - main - globalstep = 5761, average loss = 0.505601485116762
07/14/2020 09:49:28 - INFO - main - Saving model checkpoint to ./bertmultifinetuningtest5/
07/14/2020 09:49:28 - INFO - transformers.configuration_utils - Configuration saved in ./bert_multi_finetuning_test5/config.json
07/14/2020 09:49:31 - INFO - transformers.modeling_utils - Model weights saved in ./bert_multi_finetuning_test5/pytorch_model.bin
07/14/2020 09:49:32 - INFO - transformers.configuration_utils - loading configuration file ./bert_multi_finetuning_test5/config.json
Iteration: 100%|##########| 5761/5761 [46:17<00:00, 2.07it/s]
Epoch: 100%|##########| 1/1 [46:17<00:00, 2777.53s/it].09it/s]
07/14/2020 09:49:28 - INFO - __main
- global_step = 5761, average loss = 0.505601485116762
07/14/2020 09:49:28 - INFO - __main
- Saving model checkpoint to ./bert_multi_finetuning_test5/
07/14/2020 09:49:28 - INFO - transformers.configuration_utils - Configuration saved in ./bert_multi_finetuning_test5/config.json
07/14/2020 09:49:31 - INFO - transformers.modeling_utils - Model weights saved in ./bert_multi_finetuning_test5/pytorch_model.bin
07/14/2020 09:49:32 - INFO - transformers.configuration_utils - loading configuration file ./bert_multi_finetuning_test5/config.json

07/14/2020 09:49:49 - INFO - main - * Running evaluation *
07/14/2020 09:49:49 - INFO - main - Num examples = 23043
07/14/2020 09:49:49 - INFO - main - Batch size = 16
Evaluating: 100%|##########| 1441/1441 [03:54<00:00, 6.15it/s]
07/14/2020 09:53:43 - INFO - main - * Eval results *
07/14/2020 09:53:43 - INFO - main - acc = 0.8513536431888209
{‘cola’: , ‘mnli’: , ‘mnli-mm’: , ‘mrpc’: , ‘sst-2’: , ‘sts-b’: , ‘qqp’: , ‘qnli’: , ‘rte’: , ‘wnli’: }

  • 生成文件:

-rw-r--r--. 1 root root 2 Jul 6 18:08 added_tokens.json
drwxr-xr-x. 2 root root 72 Jul 6 09:01 checkpoint-10000
drwxr-xr-x. 2 root root 72 Jul 6 09:17 checkpoint-12000
drwxr-xr-x. 2 root root 72 Jul 6 09:33 checkpoint-14000
-rw-r—r—. 1 root root 972 Jul 6 18:08 config.json
-rw-r—r—. 1 root root 25 Jul 6 18:14 eval_results.txt
-rw-r—r—. 1 root root 711473265 Jul 6 18:08 pytorch_model.bin
-rw-r—r—. 1 root root 112 Jul 6 18:08 special_tokens_map.json
-rw-r—r—. 1 root root 59 Jul 6 18:08 tokenizer_config.json
-rw-r—r—. 1 root root 1228 Jul 6 18:08 training_args.bin
-rw-r—r—. 1 root root 995526 Jul 6 18:08 vocab.txt


  • 文件解释:
    • pytorch_model.bin代表模型参数,可以使用torch.load加载查看;
    • traning_args.bin代表模型训练时的超参,如batch_size,epoch等,仍可使用torch.load查看;
    • config.json是模型配置文件,如多头注意力的头数,编码器的层数等,代表典型的模型结构,如bert,xlnet,一般不更改;
    • added_token.json记录在训练时通过代码添加的自定义token对应的数值,即在代码中使用add_token方法添加的自定义词汇;
    • special_token_map.json当添加的token具有特殊含义时,如分隔符,该文件存储特殊字符的及其对应的含义,使文本中出现的特殊字符先映射成其含义,之后特殊字符的含义仍然使用add_token方法映射;
    • checkpoint: 若干步骤保存的模型参数文件(也叫检测点文件);
    • eval_results.txt:最终的评估结果。

  • 关于微调脚本中bert编码后进行微调的结构:
    • 通过查看/root/miniconda3/lib/python3.6/site-packages/transformers/modeling_bert.py中第979-1051行。

979 @add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
980 the pooled output) e.g. for GLUE tasks. “””,
981 BERTSTARTDOCSTRING,
982 BERTINPUTSDOCSTRING)
983 class BertForSequenceClassification(BertPreTrainedModel):
984 r”””
985 labels: (optional) torch.LongTensor of shape (batch_size,):
986 Labels for computing the sequence classification/regression loss.
987 Indices should be in [0, ..., config.num_labels - 1].
988 If config.num_labels == 1 a regression loss is computed (Mean-Square loss),
989 If config.num_labels > 1 a classification loss is computed (Cross-Entropy).
990
991 Outputs: Tuple comprising various elements depending on the configuration (config) and inputs:
992 loss: (optional, returned when labels is provided) torch.FloatTensor of shape (1,):
993 Classification (or regression if config.num_labels==1) loss.
994 logits: torch.FloatTensor of shape (batch_size, config.num_labels)
995 Classification (or regression if config.num_labels==1) scores (before SoftMax).
996 hidden_states: (optional, returned when config.output_hidden_states=True)
997 list of torch.FloatTensor (one for the output of each layer + the output of the embeddings)
998 of shape (batch_size, sequence_length, hidden_size):
999 Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1000 attentions: (optional, returned when config.output_attentions=True)
1001 list of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length):
1002 Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
1003
1004 Examples::
1005
1006 tokenizer = BertTokenizer.from_pretrained(‘bert-base-uncased’)
1007 model = BertForSequenceClassification.from_pretrained(‘bert-base-uncased’)
1008 input_ids = torch.tensor(tokenizer.encode(“Hello, my dog is cute”, add_special_tokens=True)).unsqueeze(0) # Batch size 1
1009 labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
1010 outputs = model(input_ids, labels=labels)
1011 loss, logits = outputs[:2]
1012
1013 “””
1014 def __init
(self, config):
1015 super(BertForSequenceClassification, self).__init
(config)
1016 self.num_labels = config.num_labels
1017
1018 self.bert = BertModel(config)
1019 self.dropout = nn.Dropout(config.hidden_dropout_prob)
1020 self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
1021
1022 self.init_weights()
1023
1024 def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
1025 position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
1026
1027 outputs = self.bert(input_ids,
1028 attention_mask=attention_mask,
1029 token_type_ids=token_type_ids,
1030 position_ids=position_ids,
1031 head_mask=head_mask,
1032 inputs_embeds=inputs_embeds)
1033
1034 pooled_output = outputs[1]
1035
1036 pooled_output = self.dropout(pooled_output)
1037 logits = self.classifier(pooled_output)
1038
1039 outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1040
1041 if labels is not None:
1042 if self.num_labels == 1:
1043 # We are doing regression
1044 loss_fct = MSELoss()
1045 loss = loss_fct(logits.view(-1), labels.view(-1))
1046 else:
1047 loss_fct = CrossEntropyLoss()
1048 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1049 outputs = (loss,) + outputs
1050
1051 return outputs # (loss), logits, (hidden_states), (attentions)

  • 分析:
    • 第1034行开始,从BERT的输出outputs获得pooled_output,它是在其输出索引1的位置上,这也是我们所有的[CLS]的位置对应的张量。
    • 为什么能够拿[CLS]的位置张量充当整个句子的表示呢?这是因为[CLS]标记并不是某个具体的有意义的词,因此它不会对某些词偏倚,也就是能够‘公平’地表示每一个词汇的含义,即整个句子的表示。
    • 得到pooled_output后,接着使用dropout层,用于防止过拟合,最后接了一层的classifier,也就是全连接输出层。
    • 最后将会选择损失函数,当设定标签数为1时,默认为回归问题,则使用MSE损失。否则,就像我们现在处理的多分类问题,将使用交叉熵损失。这就是我们使用的微调脚本的分类头具体结构和损失函数选择。

第四步:重要的超参数调优

  • 深度学习中大型模型的超参数调优不像机器学习一样简单,因为一次训练的时间成本太高,因此,这里主要给大家说明当前数据集中对结果最有影响里的超参数。(通过run_glue.py可以看到允许我们设置的全部超参数)
  • 通过大量实验,在当前数据集上,对结果影响最大超参数有两个,分别是:max_seq_lengthweight_decay,对于max_seq_length默认为512,通过之前的数据分析,我们已经将其设定为100,而weight_decay默认为0.0,通过大量实验我们将其设定为1e-2,以下不同weight_decay对结果的影响: | weight_decay | 准确率 | | —- | —- | | 0.0 | 85.1% | | 1e-1 | 86.3% | | 1e-2 | 87.5% | | 1e-3 | 87.3% | | 1e-4 | 86.0% |

  • 什么是weight_decay:
    • 说到weight_decay就不得不提到当前BERT模型使用的优化器AdamW,大家之前可能对Adam优化器比较熟悉,它是一种能够利用历史梯度来调节学习率的优化器(查看具体公式),AdamW本质是Adam的一种改进,为了防止大型模型普遍的过拟合现象而提出,我们知道,防止过拟合可以通过向损失函数中添加L1/L2正则项来完成(早期的AdamW就是这样实现的,但与论文作者的想法并不同),因为Adam + L2并不具备任何新的创意。现在AdamW实现并不修改损失函数,而是要求直接在梯度更新时就是引入衰减项:
  • 其中公式中的w就是weight_decay(权重衰减系数)。
  • 修改损失函数要比这种直接修改参数的方式消耗更多的算力和显存。

第五步:使用标签平滑技术

  • 标签平滑的作用:
    • 就是小幅度的改变原有标签值的值域,如[0, 0, 1] —> [0.1, 0.1, 0.8],它适用于人工的标注数据可能并非完全正确的情况, 可以使用标签平滑来弥补这种偏差, 减少模型对某一条规律的绝对认知, 以防止过拟合。

  • 标签平滑技术看似原理简单,但实现起来并不容易,因为像pytorch这样的工具中,当使用原生的交叉熵损失(CrossEntropyLoss())时,要求标签值必须为整型,在这里我们不去修改交叉熵损失的源码,而是重新定义一个类似计算规则的标签平滑交叉熵损失

## 以下代码直接使用即可
class LabelSmoothingCELoss(nn.Module):
def init(self, classes, smoothing=0.0, dim=-1):
super(LabelSmoothingCELoss, self).init()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.cls = classes
self.dim = dim

  1. def forward(self, pred, target):<br /> pred = pred.log_softmax(dim=self.dim)<br /> with torch.no_grad():<br /> true_dist = pred.data.clone()<br /> true_dist = torch.zeros_like(pred)<br /> true_dist.fill_(self.smoothing / (self.cls - 1))<br /> # .scatter_也是一种数据填充方法,目的仍然是将self.confidence填充到true_dist中<br /> # 第一个参数0/1代表填充的轴,大多数情况下使用scatter_都使用纵轴(1)填充<br /> # 第二个参数就是self.confidence的填充规则,即填充到第几列里面,如[[1], [2]]代表填充到第二列和第三列里面<br /> # 第三个参数就是填充的数值,int/float<br /> true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)<br /> return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

  • 代码位置:
    • /data/YXB/smoothing.py

  • 调用:

if __name__ == "__main__":
predict = torch.FloatTensor([[1, 1, 1, 1, 1]])
target = torch.LongTensor([2])
LSL = LabelSmoothingLoss(3, 0.03)
print(LSL(predict, target))


  • 输出效果:

tensor(1.6577)


  • 修改huggingface transformer中的源码:
    • 路径:/usr/local/lib/python3.7/site-packages/transformers/modeling_bert.py

···
1036 pooled_output = self.dropout(pooled_output)
1037 logits = self.classifier(pooled_output)
1038
1039 outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1040
1041 if labels is not None:
1042 if self.num_labels == 1:
1043 # We are doing regression
1044 loss_fct = MSELoss()
1045 loss = loss_fct(logits.view(-1), labels.view(-1))
1046 else:
# 注释掉之前的交叉熵损失函数
1047 # loss_fct = CrossEntropyLoss()
# 导入之前的LabelSmoothingCELoss,填入类别参数和平滑系数
1048 from smoothing import LabelSmoothingCELoss
1049 loss_fct = LabelSmoothingCELoss(3, smoothing=0.1)
1050 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1051 outputs = (loss,) + outputs
1052
1053 return outputs # (loss), logits, (hidden_states), (attentions)
···


  • 标签平滑技术的效果: | label smoothing | 准确率 | | —- | —- | | 0.05 | 87.5% | | 0.10 | 87.9% | | 0.15 | 87.1% | | 0.20 | 86.8% | | 0.25 | 86.2% |

  • 结论:
    • 在当前的数据上,标签平滑带来的效果并不是很明显,只提升大约0.4%,但在后期加大数据量级时,标签平滑将带来显著效果。

第六步: Badcase分析以便进行数据调整和重新训练

  • 什么是Badcase分析:
    • 我们都知道,模型在验证集上的准确率是我们的评估指标,因此对验证集上数据预测的正确与否十分重要,我们把在验证集上预测错误的数据称作badcase。对这些Badcase进行人为的统计分析,推测预测错误可能的原因,以及对指定的类型的数据进行扩充等,比如:我们的badcase中发现是数字识别不准确导致的类别错误,那么就补充大量数字有关的样本再进行训练。

  • Badcase分析的实现:
    • 首先需要实现模型的单条预测功能:

import torch
import os
# 从transformers中导入BERT模型的相关工具
from transformers import BertForSequenceClassification, BertTokenizer

模型文件所在路径,这里我们已经训练完模型
# 并在./bert_multi_finetuning_test/文件中写入该代码
model_path = “./“
model = BertForSequenceClassification.from_pretrained(model_path)
# 还原数值映射器
tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
# 标签列表
label_list = [“Y”, “N”, “M”]
# 输入句子对的[SEP],用于分割句子的标志映射的数值
MARK = 102

def model_prediction(sentence1, sentence2, model, tokenizer):
“””调用训练好的模型”””
# 恢复词汇映射器
indexed_tokens = tokenizer.encode(sentence1, sentence2)
print(“Indexed_tokens:”, indexed_tokens)
k = indexed_tokens.index(MARK)
segments_ids = [0](k+1) + [1](len(indexed_tokens)-(k+1))
# 转化张量形式
segments_tensors = torch.tensor([segments_ids])
tokens_tensor = torch.tensor([indexed_tokens])
# 使用带有问答模型头的预训练模型获得结果
with torch.no_grad():
logits = model(tokens_tensor, token_type_ids=segments_tensors)
# 获得预测的标签
predicted_labels = label_list[torch.argmax(logits[0]).item()]
return predicted_labels


  • 文件位置:
    • /home/YXB/bert_multi_finetuning_test/badcase.py

  • 调用:

sentence1 = "构造[PAD]函数"
sentence2 = “构造[PAD]函数”
res = model_prediction(sentence1, sentence2, model, tokenizer)
print(res)


  • 输出效果:

# 编码一致,间隔[PAD]映射成0
Indexed_tokens: [101, 4514, 7740, 0, 2529, 4305, 102, 4514, 7740, 0, 2529, 4305, 102]

Y


  • 对验证集中预测错误的例子进行提取并写入文件:

import pandas as pd

def badcase_analysis(model, tokenizer, valid_data_path, output_path):
“””使用模型对验证集中的数据进行预测,并提取所有预测错误的数据样本

  1. Args:<br /> model: 需要使用的模型<br /> valid_data_path: 验证集所在路径
  2. Return: <br /> badcase csv文件,包括预测的样本,真实的标签,预测的标签<br /> """<br /> # 读取验证集数据<br /> df = pd.read_csv(valid_data_path, sep="\t", iterator=True)<br /> # 下面筛选验证集中预测错误的数据<br /> badcase_list = []<br /> for item in next(df).values.tolist():<br /> pre_label = model_prediction(str(item[1]), str(item[2]), model, tokenizer)<br /> if not pre_label == str(item[-1]):<br /> badcase_list.append(item + [pre_label])<br /> # 将对应的索引,输入的文本,真实标签和预测标签写入新的文件<br /> df = pd.DataFrame(badcase_list, columns=["index", "s1", "s2", "true", "pre"])<br /> df.to_csv(output_path, index=False, sep="\t")

  • 代码位置:
    • /home/YXB/bert_multi_finetuning_test/badcase.py

  • 调用:

# 指明验证集路径
# model和tokenizer在上一个函数中已经给出
valid_data_path = “../yxb_data/dev.tsv”
output_path = “./badcase.csv”
badcase_analysis(model, tokenizer, valid_data_path, output_path)


  • 输出效果:
    • 在当前路径下生成badcase.csv文件。

  • badcase文件查看:

index s1 s2 true pre
18822 网页或网络,互联网 网页 N Y
50874 flowlayout setlayout() Y N
17522 cglib代理 jdk动态代理 Y N
8513 [][PAD][||] 真[PAD]假 N Y
79831 isnull isnotnull或(isnotnull) Y N
35897 虚拟机 java虚拟机(或jvm) Y N
75810 表现[PAD]形式 表现[PAD]行为 Y M
43710 选择器(属性1:属性1) 选择器{属性1:属性值1属性2:属性值2属性3:属性值3} N Y
11869 1024 1027 N Y
22438 参数类型和参数名 参数列表 N Y
2389 选择器{属性:属性值} 选择器{属性1:属性值1属性2:属性值2属性3:属性值3} M Y
1851 notnull isnotnull或(isnotnull) N M
15368 44 46 N Y
27890 主机根目录 当前web应用程序的根目录 Y N
86614 howcreateprocedure showcreateprocedure Y N
1223 length属性 frames.length M N
51829 5432 5431 N Y
15973 collectionmap comparator Y N
4382 一维 等差 M N
15202 行为 方法 Y N
49325 floor(3+rand()(8-3+1)) floor(3+rand()6) N Y
45307 键值 映射 Y N
87861 关系运算符 逻辑运算符 Y N
76619 大数据 dt Y M
8266 setcookie set_cookie() N Y
9846 h2{font-size:16pxcolor} h2{font-size:16pxcolor:red} M Y


对badcase进行分析:

  • 对于数字类型的答案,比如”1024”和”1027”的判断比较模糊,可能由于训练数据中该类型的答案对出现较少,且存在一些脏数据(错误标注的情况),但我们知道,对于填空题如果是只有一个数值作为答案,那么绝大多数情况应该是唯一的。因此,我们决定,当正确答案和学生答案都是数值类型时,将其作为初始判定的规则(相同则满分,不同则零分),而不再使用模型来判断。
    • 数值判断规则:

def is_num(student_answer: str, true_answer: str):
“””当两者的都是数值类型时,如果不同则为N, 如果相同则为Y“””
try:
# 如: 对”1.65”的表示使用eval变成1.65, 再进行对比
if eval(student_answer) == eval(true_answer):
return “Y”
else:
return “N”
# 无法判断说明不满足数值类型条件,返回answer进行下一个环节
except Exception as e:
pass
return student_answer, true_answer


  • 调用:

student_answer = "1024"
true_answer = “1028”
res = is_num(student_answer, true_answer)
print(res)


  • 输出效果:

N


  • 模型对于比较长的单词中出现一些拼写错误,比如:contentreserver和contentobserver,大多数情况都会判定为”满分”, 经过我们对语料的分析,主要原因是有”脏数据”引起的,也就是在提供给我们的满分语料中,确实有很多老师认为拼写错误无关紧要。通过与该部门沟通,这种情况是要避免的,因此我们又重新对语料中该类型的答案进行了标注,使得满分数据减少约8000条,对应零分数据增加8000条。最终,模型效果提升3.8%左右。 | | 拼写错误数据规范前 | 拼写错误数据规范后 | | —- | —- | —- | | 准确率 | 87.9% | 91.7% |

  • 对于回答不完整的现象,模型经常会出现预测失误,比如:学生答案为:VM,而正确答案为java虚拟机(或JVM),这应该是一个中间分的回答,但模型总是预测为满分。经过我们的分析,是由于该类型的数据过少导致的。因此,我们需要进行该类型数据的增加,从历史数据中我们又增加了大概10000条左右的数据(中间分且为只描述了答案的一部分导致的)进行重新训练。通过局部数据增强方法,我们的模型又提升了1.8%。 | | 局部数据增强前 | 局部数据增强后 | | —- | —- | —- | | 准确率 | 91.7% | 93.5% |

  • 我们还可以通过其他对badcase的分析来提升模型的准确率,但注意这些badcase必须具有代表性,否则,带来的调整可能会降低指标。而且每次调整后,我们不仅要观察模型的指标变化,还需要参考新的badcase文件。最后,我们模型的验证准确率已经达到了93.1%左右。


  • badcase分析总结:
    • 对于我们的系统,整个的解决方案是由规则+模型来解决,规则包括直接判断答案相同,经过数据清洗后判断答案相同,以及对数值型答案判断相同。这些规则大概会覆盖10%左右的输入,而规则的准确率是极高的,可认为是99%,剩下90%的问题就会使用模型进行预测,准确率为93.1%,因此系统的最终准确率约为: 10%_99% + 90%_93.5% = 94.05%。 | | 规则+模型优化前 | 规则+模型优化后 | | —- | —- | —- | | 准确率 | 87.5% | 94.0% |

小节总结

  • 学习了得分判别模型的整体训练流程:
    • 第一步: 训练集与验证集的划分
    • 第二步: 根据现有任务选择标准任务类型
    • 第三步: 使用微调脚本进行训练和验证
    • 第四步: 重要的超参数调优
    • 第五步: 使用标签平滑技术
    • 第六步: Badcase分析以便对数据进行调整和重新训练


5.2 模型服务的部署


学习目标

  • 了解什么是模型热更新以及如何做到热更新。
  • 了解Flask框架及其相关的服务组件。
  • 掌握使用Flask框架将模型封装成服务的流程。

什么是模型热更新

  • 因为训练AI模型往往是较大的文件,在每次IO时往往比较耗时,因此会选择在服务开启时读入内存,避免IO操作。而这样的话,就意味着当我们更新模型时需要暂停服务, 这对于在线任务是非常不可取的行为;因此我们需要一种既能避免IO又能使用户无感知的方式,这种的要求就是模型热更新要求。

如何做到热更新

  • 最常见的满足热更新要求的方法就是一同开启两个模型服务,一个作为正式使用,一个作为backup(备用),当我们有更新需求时,将正式服务暂停进行模型更换,而此时备用服务将继续为用户服务,直到正式服务重新上线。在正式服务运转正常后,再为备用服务更换模型。

Flask服务组件

  • web框架FLask:
    • Flask框架是当下最受欢迎的python轻量级框架, 也是pytorch官网指定的部署框架. Flask的基本模式为在程序里将一个视图函数分配给一个URL,每当用户访问这个URL时,系统就会执行给该URL分配好的视图函数,获取函数的返回值.

  • 作用:
    • 在项目中, Flask框架是主逻辑服务和句子相关模型服务使用的服务框架.

  • 安装:

# 使用pip安装Flask
pip install Flask==1.1.1


  • 基本使用方法:

# 导入Flask类
from flask import Flask
# 创建一个该类的实例app, 参数为name, 这个参数是必需的,
# 这样Flask才能知道在哪里可找到模板和静态文件等东西.
app = Flask(name)

使用route()装饰器来告诉Flask触发函数的URL
@app.route(‘/‘)
def hello_world():
“””请求指定的url后,执行的主要逻辑函数”””
# 在用户浏览器中显示信息:’Hello, World!’
return ‘Hello, World!’

if name == ‘main‘:
app.run(host=”0.0.0.0”, port=5003)


  • 代码位置:
    • /data/ItcastBrain/Yxb/bert_server/app.py

  • 启动服务:

python app.py


  • 查看效果:

  • web组件Waitress:
    • Waitress是Flask官方推荐的生产环境使用组件,与Gunicorn类似,它具有使用非常简单,轻量级的资源消耗,以及高性能等特点。

  • 作用:
    • 在项目中和Flask框架一同使用, 处理请求, 因其高性能的特点能够有效减少服务丢包率.

  • 安装:

# 使用pip安装
pip install waitress==1.4.4


  • 基本使用方法:

# 注意:kill掉之前的5003端口服务,不再使用原生的启动方式
# 而是使用waitress启动Flask服务:
waitress-serve —threads=1 —listen=*:5003 app:app
# —threads 代表开启的线程数, 我们只开启一个线程
# —listen 服务的IP地址和端口
# app:app 是指执行的主要对象位置, 在app.py中的app对象


使用Flask框架将模型封装成服务

我们可以将模型封装成服务的流程分为三步:

  • 第一步: 编写app.py文件
  • 第二步: 使用waitress启动服务
  • 第三步: 编写test.py进行接口测试
  • 第四步: 使用Nginx代理两个服务满足热更新

第一步: 编写app.py文件,代码实现如下:

# Flask框架固定工具
from flask import Flask
from flask import request

app = Flask(name)

import os
import torch

从transformers中导入BERT模型的相关工具
from transformers import BertForSequenceClassification, BertTokenizer

model_path = “/data/ItcastBrain/Yxb/bert_model/“
model = BertForSequenceClassification.from_pretrained(model_path)
# 还原数值映射器
tokenizer = BertTokenizer.from_pretrained(
model_path, do_lower_case=True
)
# 标签列表
label_list = [“Y”, “N”, “M”]
# 输入句子对的[SEP],用于分割句子的标志映射的数值
MARK = 102

定义服务请求路径和方式, 这里使用POST请求
@app.route(“/v1/model_prediction/“, methods=[“POST”])
def model_prediction():
“””调用训练好的模型进行预测
“””
request_json = request.get_json()
sentence1 = request_json[“student_answer”]
sentence2 = request_json[“true_answer”]
# 恢复词汇映射器
indexed_tokens = tokenizer.encode(sentence1, sentence2)
k = indexed_tokens.index(MARK)
segments_ids = [0] (k + 1) + [1] (
len(indexed_tokens) - (k + 1)
)
# 转化张量形式
segments_tensors = torch.tensor([segments_ids])
tokens_tensor = torch.tensor([indexed_tokens])
with torch.no_grad():
logits = model(tokens_tensor, token_type_ids=segments_tensors)
# logits = model(**patch)
# 获得预测的标签
predicted_labels = label_list[torch.argmax(logits[0]).item()]
return predicted_labels


  • 代码位置:
    • /data/ItcastBrain/Yxb/bert_server/app.py

第二步: 使用waitress来启动服务

waitress-serve --threads=1 --listen=*:5005 app:app


  • 输出效果:

Serving on [http://0.0.0.0:5003](http://0.0.0.0:5003)
Serving on http://[::]:5003


第三步: 编写test.py进行接口测试

import requests

url = “http://0.0.0.0:5003/v1/model_prediction/

data = {“student_answer”: “ALT”, “true_answer”: “TITLE”}
# 多层嵌套必须使用json
res = requests.post(url, json=data, timeout=200)
print(res.text)


  • 代码位置:
    • data/ItcastBrain/Yxb/bert_server/test.py

  • 输出效果:

N


第四步: 使用Nginx代理两个服务满足热更新

到这里说明我们模型服务能够正常工作,之后我们将启动两个同样的服务,分别使用5005和5006端口, 并将两个服务使用Nginx代理宜满足热更新。下面对nginx进行一些简单介绍,并对其中的配置进行说明。


  • Nginx:
    • Nginx是一个高性能的HTTP和反向代理web服务器,也是工业界web服务最常使用的外层代理。

  • Nginx热更新部分配置说明:
    • 这些配置已经为大家写好,可以在/data/ItcastBrain/conf/nginx/nginx.conf中进行查看。

...

以下是与热更新有关的配置
# 这里代理两个端口的服务
# 其中5004为backup,即当5003服务停止时被启用
# 这里的prod要与下面proxy_pass中http://后的名称相同
upstream prod {
server 0.0.0.0:5003;
server 0.0.0.0:5004 backup;
}

nginx的外层服务使用8086端口
server {
listen 8086;
server_name 0.0.0.0;
location /static/ {
alias /data/ItcastBrain/static/;
}

  1. # 这里注意prod要与上面upstream后的名称相同<br /> location / {<br /> proxy_pass http://prod;<br /> include /data/ItcastBrain/conf/nginx/uwsgi_params;<br /> proxy_set_header X-Real-IP $remote_addr;
  2. }<br /> }


  • Nginx的启动与关闭:

# 实际中我们并不会直接启动Nginx,而是在整体服务部署时使用supervisor进行启动和关闭
# 因此这里大家了解以下启动命令即可
# -c是指向配置文件
nginx -c /data/ItcastBrain/conf/nginx/nginx.conf

关闭nginx
nginx -s stop


小节总结

  • 学习了什么是热更新与如何做到热更新
  • 学习了Flask服务组件的使用
  • 学习了将模型封装成服务的流程
    • 第一步: 编写app.py文件
    • 第二步: 使用waitress启动服务
    • 第三步: 编写test.py进行接口测试
    • 第四步: 使用Nginx代理两个服务满足热更新