工业界

学术界

PLOME

paper ref: https://aclanthology.org/2021.acl-long.233.pdf

image.png
预训练语言模型 可以看做是macbert的进化

  • PLOME模型是专门针对中文文本纠错任务构建的预训练语言模型。
  • 在训练预训练语言模型时,采用基于语义混淆集的MASK策略
  • 拼音和笔画作为预训练语言模型以及模型微调的输入
  • 字符预测任务拼音预测任务作为预训练语言模型以及模型微调的训练目标
  • PLOME预训练语言模型的下游任务主要是文本纠错任务
  • 微调和训练目标 和 预训练任务一致,都设置了字符预测和拼音预测任务;但预训练仅需要对替换字符进行预测,实际微调使用过程中需要对所有的输入字符进行预测

    ErnieCSC(百度)

    基于ERNIE的中文拼写纠错模型,模型已经开源在PaddleNLP的 模型库中https://bj.bcebos.com/paddlenlp/taskflow/text_correction/csc-ernie-1.0/csc-ernie-1.0.pdparams
    image.png

  • 构建了一整套端到端中文文本纠错模型,包括构建预训练语言模型MLM-phonetics和微调下游纠错任务

  • MLM-base 遮盖了15%的词进行预测, MLM-phonetics 遮盖了20%的词进行预测。
  • MLM-base 的遮盖策略基于以下3种:[MASK]标记替换(和BERT一致)、随机字符替换(Random Hanzi)、原词不变(Same)。且3种遮盖策略占比分别为: 80% 、10%、10%。MLM-phonetics的Mask策略基于以下3种:[MASK]标记替换(和BERT一致)、字音混淆词替换(Confused-Hanzi)、混淆字符的拼音替换(Noisy-pinyin)。且这3种遮盖策略分别占比为: 40%、30%、30%。
  • 错误纠正任务错误检测任务作为预训练语言模型以及模型微调的训练目标

    未来方向与挑战

  • 如果有足够的对齐语料,可以继续沿预训练模型的角度进行,Transformer的编码解码器思路,并且引入基于预训练的Seq2Seq模型,例如GPT,BART等;

  • 模型需要支持热更新,支持时事热点中的新词,例如当下热点的”传闻中的成仙仙” -> “传闻中的陈芊芊”;
  • 对于不等长的文本纠错缺乏可用模型;
  • 对于复杂的句法错误以及语义中的知识性错误、逻辑性错误、表意不明还不能有效进行处理;

    pycorrector

    基于规则

    文本预处理

    初始化:

  • 包括分词jieba(需将 自定义混淆集,专名集加入用户自定义词典)

  • 加载各类词典文件
  • 语言模型kenlm

文本切分为句子,以标点符号切分

错误检测

1. 自定义混淆集加入疑似错误词典
2. 专名错误检测,专名词典,包括成语、俗语、专业领域词等

  • 分词
  • 获得1,2,3,4gram的结果
  • 词长度过滤,max_word_length: 专名词的最大长度为4; min_word_length:专名词的最小长度为2
  • 专名词典中的词计算相似度(计算两个词的拼音和字形相似度),策略是取大的作为相思分数
  • 如果相似度大于>阈值,则加入疑似错误词典

3. 词错误

  • 分词后,过滤(数字,标点符号,英文,非中文等)字符串

    1. if self.is_filter_token(token):
    2. continue
    1. def is_filter_token(token):
    2. """
    3. 是否为需过滤字词
    4. :param token: 字词
    5. :return: bool
    6. """
    7. result = False
    8. # pass blank
    9. if not token.strip():
    10. result = True
    11. # pass num
    12. if token.isdigit():
    13. result = True
    14. # pass alpha
    15. if is_alphabet_string(token.lower()):
    16. result = True
    17. # pass not chinese
    18. if not is_chinese_string(token):
    19. result = True
    20. return result
  • 将未登录词加入疑似错误词典

    1. # pass in dict
    2. if token in self.word_freq:
    3. continue
    4. maybe_err = [token, begin_idx + start_idx, end_idx + start_idx, ErrorType.word]

    4. 字错误,语言模型检测疑似错误字

  • 获得2-gram 和3-gram 的语言模型得分scores;

  • 滑动窗口补全得分,ngram,窗口移动,前后得补全n-1个元素
  • 取拼接后的n-gram平均得分— line32
  • 取疑似错字的位置,通过平均绝对离差(MAD)或者 通过平均值上下n倍标准差之间属于正常点

    1. if self.is_char_error_detect:
    2. try:
    3. ngram_avg_scores = []
    4. for n in [2, 3]: #选取2-gram,3-gram.
    5. scores = []
    6. # sentence 今天新情很好,句子长度为6
    7. for i in range(len(sentence) - n + 1):
    8. word = sentence[i:i + n]
    9. '''
    10. word: 今天
    11. word: 天新
    12. word: 新情
    13. word: 情很
    14. word: 很好
    15. '''
    16. score = self.ngram_score(list(word)) # kenlm加载已训练好的工具包
    17. scores.append(score)
    18. # scores=[-4.004828929901123, -5.91748571395874, -5.758666038513184, -5.612854957580566, -4.5769429206848145]
    19. if not scores:
    20. continue
    21. # 移动窗口补全得分,# ngram,窗口移动,前后得补全n-1个元素
    22. for _ in range(n - 1):
    23. scores.insert(0, scores[0])
    24. scores.append(scores[-1])
    25. # [-4.004828929901123, -4.004828929901123, -5.91748571395874, -5.758666038513184, -5.612854957580566, -4.5769429206848145, -4.5769429206848145]
    26. avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]
    27. ngram_avg_scores.append(avg_scores)#2-gram3-gram
    28. # [[-4.004828929901123, -4.961157321929932, -5.838075876235962, -5.685760498046875, -5.09489893913269, -4.5769429206848145],
    29. # [-7.758430480957031, -8.252998987833658, -8.834455808003744, -8.381499767303467, -7.4339752197265625, -6.399562358856201]]
    30. if ngram_avg_scores:
    31. # 取拼接后的n-gram平均得分,2-gram和3-gram
    32. sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))
    33. # 取疑似错字信息
    34. for i in self._get_maybe_error_index(sent_scores):
    35. token = sentence[i]
    36. # pass filter word
    37. if self.is_filter_token(token):
    38. continue
    39. # pass in stop word dict
    40. if token in self.stopwords:
    41. continue
    42. # token, begin_idx, end_idx, error_type
    43. maybe_err = [token, i + start_idx, i + start_idx + 1,
    44. ErrorType.char]
    45. self._add_maybe_error_item(maybe_err, maybe_errors)
    46. except IndexError as ie:
    47. logger.warn("index error, sentence:" + sentence + str(ie))
    48. except Exception as e:
    49. logger.warn("detect error, sentence:" + sentence + str(e))
    1. '''
    2. 平均绝对离差(mean absolute deviation)是用样本数据相对于
    3. 其平均值的绝对距离来度量数据的离散程度。平均绝对离差也称为平均离差(mean deviation)。
    4. 平均绝对离差定义为各数据与平均值的离差的绝对值的平均数
    5. '''
    6. def _get_maybe_error_index(scores, ratio=0.6745, threshold=2):
    7. """
    8. 取疑似错字的位置,通过平均绝对离差(MAD)
    9. :param scores: np.array
    10. :param ratio: 正态分布表参数
    11. :param threshold: 阈值越小,得到疑似错别字越多
    12. :return: 全部疑似错误字的index: list
    13. """
    14. result = []
    15. scores = np.array(scores)
    16. if len(scores.shape) == 1:
    17. scores = scores[:, None]
    18. median = np.median(scores, axis=0) # get median of all scores
    19. margin_median = np.abs(scores - median).flatten() # deviation from the median
    20. # 平均绝对离差值
    21. med_abs_deviation = np.median(margin_median)
    22. if med_abs_deviation == 0:
    23. return result
    24. y_score = ratio * margin_median / med_abs_deviation
    25. # 打平
    26. scores = scores.flatten()
    27. maybe_error_indices = np.where((y_score > threshold) & (scores < median))
    28. # 取全部疑似错误字的index
    29. result = [int(i) for i in maybe_error_indices[0]]
    30. return result
    31. def _get_maybe_error_index_by_stddev(scores, n=2):
    32. """
    33. 取疑似错字的位置,通过平均值上下n倍标准差之间属于正常点
    34. :param scores: list, float
    35. :param n: n倍
    36. :return: 全部疑似错误字的index: list
    37. """
    38. std = np.std(scores, ddof=1)
    39. mean = np.mean(scores)
    40. down_limit = mean - n * std
    41. upper_limit = mean + n * std
    42. maybe_error_indices = np.where((scores > upper_limit) | (scores < down_limit))
    43. # 取全部疑似错误字的index
    44. result = list(maybe_error_indices[0])
    45. return result

    错误修正

  • 自定义混淆集加入疑似错误词典

    • 根据词典定义直接用正确的替换
  • 专有名词
    • 根据专有名词词典定义直接用正确的替换
  • 字/词错误
    • 生成纠错候选集,根据相同的拼音获取,根据混淆词典中定义混淆词对获取
    • 通过语言模型纠正字词错误
      1. """
      2. 通过语言模型纠正字词错误
      3. :param cur_item: 当前词
      4. :param candidates: 候选词
      5. :param before_sent: 前半部分句子
      6. :param after_sent: 后半部分句子
      7. :param threshold: ppl阈值, 原始字词替换后大于该ppl值则认为是错误
      8. :param cut_type: 切词方式, 字粒度
      9. :return: str, correct item, 正确的字词
      10. """
      11. result = cur_item
      12. if cur_item not in candidates:
      13. candidates.append(cur_item)
      14. ppl_scores = {i: self.ppl_score(segment(before_sent + i + after_sent, cut_type=cut_type)) for i in candidates}
      15. sorted_ppl_scores = sorted(ppl_scores.items(), key=lambda d: d[1])

      基于深度学习网络

      主要使用了多种深度模型应用于文本纠错任务,分别是前面模型小节介绍的macbertseq2seqbertelectratransformerernie-csc,各模型方法内置于pycorrector文件夹下,有README.md详细指导,各模型可独立运行,相互之间无依赖。

      macbert

      模型简介

      MacBERT 全称为 MLM as correction BERT,其中 MLM 指的是 masked language model。
      paper:Revisiting Pre-trained Models for Chinese Natural Language Processing
      code:https://github.com/ymcui/MacBERT tensorflow版本
      以上来源:https://paperswithcode.com/paper/revisiting-pre-trained-models-for-chinese#code
      本项目是 MacBERT 改变网络结构的中文文本纠错模型:

      “MacBERT shares the same pre-training tasks as BERT with several modifications.” —— (Cui et al., Findings of the EMNLP 2020)

模型结构

  • 在通常 BERT 模型上进行了魔改,追加了一个全连接层作为错误检测即 detection, 与 SoftMaskedBERT 模型不同点在于,本项目中的 MacBERT 中,只是利用 detection 层和 correction 层的 loss 加权得到最终的 loss。不像 SoftmaskedBERT 中需要利用 detection 层的置信概率来作为 correction 的输入权重。

image.png
不过该模型只能处理输入输出等长的纠错任务,与后文的Soft-MASK BERT一样,具有一定局限性。

模型训练

在mask 策略上主要不同点在于:

  • 和 BERT 类模型相似地,对于每个训练样本,15%的输入的word进行mask,其中 80% 的词被替换成近义词(原为[MASK])、10%的词替换为随机词,10%的词不变。
  • BERT 类模型通常使用 [MASK] 来屏蔽原词,而 MacBERT 使用第三方的同义词工具来为目标词生成近义词用于屏蔽原词,特别地,当原词没有近义词时,使用随机 n-gram 来屏蔽原词;
  • 使用全词屏蔽 (wwm, whole-word masking) 以及 N-gram 屏蔽策略来选择 candidate tokens 进行屏蔽;

微调:
微调的基本模型,可以自动下载hfl/chinese-macbert-base:

  1. >>> from transformers import BertForMaskedLM
  2. >>> bert = BertForMaskedLM.from_pretrained("hfl/chinese-macbert-base")
  3. Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 659/659 [00:00<00:00, 468kB/s]
  4. Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 393M/393M [00:28<00:00, 14.4MB/s]
  5. Some weights of the model checkpoint at hfl/chinese-macbert-base were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
  6. - This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
  7. - This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  8. >>>

bert_outputs是bert 的输出,追加一个全连接(nn.Linear)得到检错概率,

  1. class MacBert4Csc(CscTrainingModel, ABC):
  2. def __init__(self, cfg, tokenizer):
  3. super().__init__(cfg)
  4. self.cfg = cfg
  5. # cfg.MODEL.BERT_CKPT="hfl/chinese-macbert-base"
  6. self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
  7. self.detection = nn.Linear(self.bert.config.hidden_size, 1)
  8. self.sigmoid = nn.Sigmoid()
  9. self.tokenizer = tokenizer
  10. def forward(self, texts, cor_labels=None, det_labels=None):
  11. if cor_labels:
  12. text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
  13. text_labels[text_labels == 0] = -100 # -100计算损失时会忽略
  14. text_labels = text_labels.to(self.device)
  15. else:
  16. text_labels = None
  17. encoded_text = self.tokenizer(texts, padding=True, return_tensors='pt')
  18. encoded_text.to(self.device)
  19. bert_outputs = self.bert(**encoded_text, labels=text_labels, return_dict=True, output_hidden_states=True)
  20. # 检错概率
  21. prob = self.detection(bert_outputs.hidden_states[-1])
  22. if text_labels is None:
  23. # 检错输出,纠错输出
  24. outputs = (prob, bert_outputs.logits)
  25. else:
  26. det_loss_fct = FocalLoss(num_labels=None, activation_type='sigmoid')
  27. # pad部分不计算损失
  28. active_loss = encoded_text['attention_mask'].view(-1, prob.shape[1]) == 1
  29. active_probs = prob.view(-1, prob.shape[1])[active_loss]
  30. active_labels = det_labels[active_loss]
  31. det_loss = det_loss_fct(active_probs, active_labels.float())
  32. # 检错loss,纠错loss,检错输出,纠错输出
  33. outputs = (det_loss,
  34. bert_outputs.loss,
  35. self.sigmoid(prob).squeeze(-1),
  36. bert_outputs.logits)
  37. return outputs

模型加载

  1. tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese"�)
  2. model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese"�)
  • 当前中文拼写纠错模型效果最好的是macbert,模型名称是shibing624/macbert4csc-base-chinese
  • 基础模型Models:https://huggingface.co/shibing624/macbert4csc-base-chinese 关于原macbert模型训练代码并未开源
  • 但是纠错的macbert微调训练模型已开源
  • 模型最后输出纬度:(batch_size,seq_length+2, vocab_size),最后输出最大词的索引,由于tokenizer.decode没有开源,不知道具体解码过程,也就是seq_length+2。

    使用说明

    ```python from pycorrector.macbert.macbert_corrector import MacBertCorrector

nlp = MacBertCorrector(“shibing624/macbert4csc-base-chinese”).macbert_correct

i = nlp(‘今天新情很好’) print(i) ```

纠错任务测试结果

安装了pycorrect环境,做了一个简单的对比测试。
2.8G:语言模型:zh_giga.no_cna_cmn.prune01244.klm
144M语言模型:people2014corpus_chars.klm(密码o5e9)
xmnlp用的是bert (pytorch版本)的模型。

  • 评估标准:纠错准召率,采用严格句子粒度(Sentence Level)计算方式,把模型纠正之后的与正确句子完成相同的视为正确,否则为错。 | | | 过纠率 | acc | precision | recall | f1 | 时间 | | —- | —- | —- | —- | —- | —- | —- | —- | | sighan_15 | rule_2.8G | 0.1257 | 0.5100 | 0.5139 | 0.1363 | 0.2154 | 790.75 s | | | rule_140M | 0.16 | 0.4827 | 0.4167 | 0.1197 | 0.1860 | 807.10 s | | | macbert | 0.1508 | 0.7900 | 0.8250 | 0.7293 | 0.7742 | 16.67s | | | xmnlp | 0.1454 | 0.5327 | 0.5759 | 0.2026 | 0.2997 | 43.77 s | | | FASPell | - | - | 0.666 | 0.591 | - | - | | corpus500 | rule_2.8G | 0.1095 | 0.4820 | 0.7381 | 0.2074 | 0.3238 | 483.01 s | | | rule_140M | 0.1194 | 0.4760 | 0.7176 | 0.2040 | 0.3177 | 501.77s | | | macbert | 0.0846 | 0.7260 | 0.9133 | 0.5987 | 0.7232 | 9.04 s | | | xmnlp | 0.1045 | 0.4560 | 0.6957 | 0.1605 | 0.2609 | 23.13 |

官方结果:
sighan_15
image.png