由于参加NLP比赛经常需要使用预训练模型进行finetune,故而写下此贴,记录finetune必要的流程。

增大词汇表

比赛数据集的词汇量往往多于预训练模型,所以需要添加一些新词汇,例子如下:

  1. tokenizer = BertTokenizer.from_pretrained("./chinese-roberta-wwm-ext")
  2. model = BertModel.from_pretrained("./chinese-roberta-wwm-ext")
  3. # 添加新token,如果token已存在返回0,否则返回1
  4. tokenizer.add_tokens(tokens)
  5. # 修改模型embedding大小,即添加新的embedding
  6. model.resize_token_embeddings(len(tokenizer))
  7. # 将tokenizer与模型进行更新并存到本地,新增词语存到add_tokens.json文件内
  8. model.save_pretrained(folder)
  9. tokenizer.save_pretrained(folder)

不等长序列如何拼接为一个batch

训练时需要将一组不等长序列整成一个tensor,可以利用下面的方法进行转换

  1. tokens = ["北医三", "你好"]
  2. print(tokenizer(tokens, padding=True, truncation=True, return_tensors="pt", verbose=False))

输出为:

  1. {'input_ids': tensor([[ 101, 1266, 1278, 676, 102],
  2. [ 101, 872, 1962, 102, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0],
  3. [0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1],
  4. [1, 1, 1, 1, 0]])}

其中input_ids即为token对应的id序列,返回pytorch long tensor类型。attention_mask表示哪些值有意义,1代表正常值,0代表padding值,计算attention时需要将其忽略。
token_type_ids用于问答的场景下,0指代第一句,1指代第二句,由于我不做问答相关的内容,故而可以添加选项return_token_type_ids=False

模型获取embedding

google的BERT有两个主要输出,分别是last_hidden_statepooler_output,意义分别为:

  • last_hidden_state:经过模型处理得到的序列中每个token的embedding序列
  • pooler_output:这个输出是第一个token[CLS]经过线性层+Tanh层进一步处理得到的,google官方的建议是如果要得到更好的句子表征时,尽量对last_hidden_state进行平均或者其它池化即可。

在一个batch中获取各个句子的embedding的操作为:

  1. tokens = [
  2. "比赛好难",
  3. "写代码好tm慢啊",
  4. "什么时候才能拿好offer"
  5. ]
  6. inputs = tokenizer(tokens, padding=True, truncation=True, return_tensors="pt", verbose=False, return_token_type_ids=False)
  7. outputs = model(**inputs) # 返回值为一个BaseModelOutputWithPoolingAndCrossAttentions
  8. # 获取序列中各个token的表征
  9. last_hidden_state = outputs.last_hidden_state
  10. attention_mask = inputs["attention_mask"]
  11. # 进行平均池化,padding的值不包括在内
  12. s_emb = last_hidden_state * attention_mask.unsqueeze(-1) # [batch_size, seq_len, emb_dim]
  13. s_emb = s_emb.sum(dim=1) / attention_mask.sum(-1, keepdim=True) # [batch_size, emb_dim]