Error Correction Algorithm

Two bert4keras implementations follow: first, a BERT model with an MLM head that rewrites the sentence in place through appended [MASK] slots; second, a UniLM-style seq2seq model that generates the corrected sentence.
# @Author: sunshine
# @Time: 2020/5/12 9:17 AM
import json
from bert4keras.tokenizers import load_vocab, Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.snippets import DataGenerator, sequence_padding
from bert4keras.optimizers import AdaFactor
from keras.layers import Lambda
from keras.models import Model
from keras.callbacks import Callback
import keras.backend as K
import numpy as np
from tqdm import tqdm

max_len = 64
config_path = '/home/chenbing/pretrain_models/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/home/chenbing/pretrain_models/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
vocab_path = '/home/chenbing/pretrain_models/bert/chinese_L-12_H-768_A-12/vocab.txt'

train_data = json.load(open('data/train_data.json', 'r', encoding='utf-8'))
valid_data = json.load(open('data/valid_data.json', 'r', encoding='utf-8'))
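
The two JSON files are assumed to hold lists of (wrong, right) sentence pairs, since the generator below unpacks each record as wrong, right = D. A hypothetical data/train_data.json (the pairs here are made up for illustration):

[
    ["追风少俊年王俊凯", "追风少年王俊凯"],
    ["我们去天按门广场", "我们去天安门广场"]
]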
# Load a simplified (pruned) vocabulary
token_dict, keep_words = load_vocab(
    dict_path=vocab_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)
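
With simplified=True and that startswith list, the pruned vocabulary places the special tokens at ids 0-4 ([PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4); the decoder below relies on id 3 being [SEP]. A quick sanity check:

for tok in ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']:
    print(tok, token_dict[tok])  # expected: 0, 1, 2, 3, 4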


class MyDataGenerator(DataGenerator):
    def __iter__(self, random=True):
        """
        Single-sample format: [CLS] wrong text [SEP] [MASK][MASK]...[SEP]
        """
        batch_token_ids, batch_segment_ids, batch_right_token_ids = [], [], []
        for is_end, D in self.sample(random):
            wrong, right = D
            right_token_ids, _ = tokenizer.encode(first_text=right)
            wrong_token_ids, _ = tokenizer.encode(first_text=wrong)
            # input: the wrong text followed by max_len [MASK] slots and a final [SEP]
            token_ids = wrong_token_ids
            token_ids += [tokenizer._token_mask_id] * max_len
            token_ids += [tokenizer._token_end_id]
            segment_ids = [0] * len(token_ids)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            # target: the right text without its leading [CLS]
            batch_right_token_ids.append(right_token_ids[1:])
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_right_token_ids = sequence_padding(batch_right_token_ids, max_len)
                yield [batch_token_ids, batch_segment_ids], batch_right_token_ids
                batch_token_ids, batch_segment_ids, batch_right_token_ids = [], [], []
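
A minimal smoke test of the generator, checking only tensor shapes (assumes train_data has at least two pairs):

g = MyDataGenerator(train_data[:2], batch_size=2)
[x_tokens, x_segments], y = next(g.__iter__(random=False))
print(x_tokens.shape, x_segments.shape, y.shape)  # y is (2, max_len)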

# Build the model: BERT with its MLM head, keeping only the pruned vocabulary
bert_model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    with_mlm=True,
    keep_tokens=keep_words
)
# Keep the predictions at positions 1..max_len, which line up with the
# max_len-padded target sequence
output = Lambda(lambda x: x[:, 1:max_len + 1])(bert_model.output)
model = Model(bert_model.input, output)


def masked_cross_entropy(y_true, y_pred):
    """Cross-entropy loss, masking out predictions at padding positions."""
    y_true = K.reshape(y_true, [K.shape(y_true)[0], -1])
    y_mask = K.cast(K.not_equal(y_true, 0), K.floatx())
    cross_entropy = K.sparse_categorical_crossentropy(y_true, y_pred)
    cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask)
    return cross_entropy


model.compile(loss=masked_cross_entropy, optimizer=AdaFactor(learning_rate=1e-3))
model.summary()


def gen_answer(wrong):
    """Decode: feed the wrong text plus max_len [MASK] slots, then read off
    the argmax token at each output position."""
    wrong_token_ids, _ = tokenizer.encode(wrong)
    token_ids = wrong_token_ids + [tokenizer._token_mask_id] * max_len + [tokenizer._token_end_id]
    segment_ids = [0] * len(token_ids)
    probas = model.predict([np.array([token_ids]), np.array([segment_ids])])[0]
    proba_ids = probas.argmax(axis=1)
    # id 3 is [SEP] under the pruned vocabulary; drop it from the prediction
    useful_ids = proba_ids[proba_ids != 3]
    if any(useful_ids):
        answer = tokenizer.decode(useful_ids)
    else:
        answer = tokenizer.decode(proba_ids[:len(wrong)])
    return answer


def evaluate(valid_data):
    """Sentence-level exact-match accuracy on (wrong, right) pairs."""
    X, Y = 1e-10, 1e-10
    for item in tqdm(valid_data):
        wrong, right = item
        pred = gen_answer(wrong)
        X += pred == right
        Y += 1
    return X / Y


class Evaluator(Callback):
    """Saves the model whenever the training loss reaches a new low."""

    def __init__(self):
        super(Evaluator, self).__init__()
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save('models/best_mlm_model.h5')


if __name__ == '__main__':
    # Train the model
    # evaluator = Evaluator()
    # train_generator = MyDataGenerator(train_data, batch_size=8)
    #
    # model.fit_generator(
    #     train_generator.forfit(),
    #     steps_per_epoch=len(train_generator),
    #     epochs=10,
    #     callbacks=[evaluator]
    # )

    # Predict
    model.load_weights('models/best_mlm_model.h5')
    wrong = '追风少俊年王俊凯'
    result = gen_answer(wrong)
    print(result)
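
To score the MLM model on the validation set, evaluate can be called directly (a sketch; assumes the saved weights above exist):

model.load_weights('models/best_mlm_model.h5')
print('exact-match accuracy: %.4f' % evaluate(valid_data))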

The second implementation treats correction as sequence-to-sequence generation, running BERT in UniLM mode:

# @Author: sunshine
# @Time: 2020/5/12 1:27 PM
import json
from bert4keras.tokenizers import load_vocab, Tokenizer
from bert4keras.models import build_transformer_model
from bert4keras.snippets import DataGenerator, sequence_padding
from bert4keras.optimizers import AdaFactor
from keras.callbacks import Callback
import keras.backend as K
import numpy as np  # needed by gen_answer below
from tqdm import tqdm

max_len = 64
config_path = '/home/chenbing/pretrain_models/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/home/chenbing/pretrain_models/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
vocab_path = '/home/chenbing/pretrain_models/bert/chinese_L-12_H-768_A-12/vocab.txt'

train_data = json.load(open('data/train_data.json', 'r', encoding='utf-8'))
valid_data = json.load(open('data/valid_data.json', 'r', encoding='utf-8'))
# Load a simplified (pruned) vocabulary
token_dict, keep_words = load_vocab(
    dict_path=vocab_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)


class MyDataGenerator(DataGenerator):
    def __iter__(self, random=True):
        """
        Single-sample format: [CLS] wrong text [SEP] right text [SEP]
        """
        batch_token_ids, batch_segment_ids = [], []
        for is_end, D in self.sample(random):
            wrong, right = D
            # segment_ids doubles as the loss mask: 1 marks the target segment
            token_ids, segment_ids = tokenizer.encode(first_text=wrong, second_text=right, max_length=max_len * 2)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                yield [batch_token_ids, batch_segment_ids], None
                batch_token_ids, batch_segment_ids = [], []
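
For intuition, a pair such as ('地求', '地球') encodes roughly as follows (a hypothetical pair; exact ids depend on the pruned vocabulary):

token_ids, segment_ids = tokenizer.encode(first_text='地求', second_text='地球')
# token_ids   -> [CLS] 地 求 [SEP] 地 球 [SEP]
# segment_ids -> [0,   0,  0,  0,   1,  1,  1]   (ones mark where the loss applies)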

# Build the model in UniLM mode: the attention mask makes the source segment
# bidirectional and the target segment left-to-right
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    application='unilm',
    keep_tokens=keep_words
)
# Seq2seq loss: predict each next token, with the (shifted) segment_ids
# masking the loss down to the target segment
y_true = model.input[0][:, 1:]
y_mask = model.input[1][:, 1:]
y_pred = model.output[:, :-1]
cross_entropy = K.sparse_categorical_crossentropy(y_true, y_pred)
cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask)
model.add_loss(cross_entropy)
model.compile(optimizer=AdaFactor(learning_rate=1e-3))
model.summary()
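
What application='unilm' changes is the attention mask: source tokens (segment 0) attend bidirectionally, while target tokens (segment 1) attend to the full source but only to earlier target positions. A minimal numpy sketch of that mask, for intuition only (not the bert4keras internals):

import numpy as np

def unilm_mask(segment_ids):
    # position i may attend to position j iff the cumulative count of
    # target tokens at j does not exceed the count at i
    idx = np.cumsum(segment_ids)
    return (idx[None, :] <= idx[:, None]).astype(int)

print(unilm_mask([0, 0, 0, 1, 1]))
# [[1 1 1 0 0]
#  [1 1 1 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]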


def gen_answer(wrong):
    """Decode: append max_len [MASK] slots as the target segment and read off
    the argmax tokens in a single forward pass."""
    wrong_token_ids, _ = tokenizer.encode(wrong)
    token_ids = wrong_token_ids + [tokenizer._token_mask_id] * max_len + [tokenizer._token_end_id]
    # the [MASK] slots form the target segment, so they get segment id 1,
    # matching how the (wrong, right) pairs were encoded during training
    segment_ids = [0] * len(wrong_token_ids) + [1] * (max_len + 1)
    probas = model.predict([np.array([token_ids]), np.array([segment_ids])])[0]
    proba_ids = probas.argmax(axis=1)
    # id 3 is [SEP] under the pruned vocabulary; drop it from the prediction
    useful_ids = proba_ids[proba_ids != 3]
    if any(useful_ids):
        answer = tokenizer.decode(useful_ids)
    else:
        answer = tokenizer.decode(proba_ids[:len(wrong)])
    return answer


def evaluate(valid_data):
    """Sentence-level exact-match accuracy on (wrong, right) pairs."""
    X, Y = 1e-10, 1e-10
    for item in tqdm(valid_data):
        wrong, right = item
        pred = gen_answer(wrong)
        X += pred == right
        Y += 1
    return X / Y


class Evaluator(Callback):
    """Saves the model whenever the training loss reaches a new low."""

    def __init__(self):
        super(Evaluator, self).__init__()
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save('models/best_seq2seq_model.h5')


if __name__ == '__main__':
    evaluator = Evaluator()
    train_generator = MyDataGenerator(train_data, batch_size=8)
    model.fit_generator(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=10,
        callbacks=[evaluator]
    )
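
After training, the seq2seq model can be loaded and queried the same way as the MLM one (a sketch, reusing the example sentence from the first script):

model.load_weights('models/best_seq2seq_model.h5')
print(gen_answer('追风少俊年王俊凯'))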