• 模型单条文本预测代码实现:
      1. import os
      2. import torch
      3. import json
      4. from bilstm_crf import BiLSTM_CRF
      5. def singel_predict(model_path, content, char_to_id_json_path, batch_size, embedding_dim,
      6. hidden_dim, num_layers, sentence_length, offset, target_type_list, tag2id):
      7. char_to_id = json.load(open(char_to_id_json_path, mode="r", encoding="utf-8"))
      8. # 将字符串转为码表id列表
      9. char_ids = content_to_id(content, char_to_id)
      10. # 处理成 batch_size * sentence_length 的 tensor 数据
      11. # 定义模型输入列表
      12. model_inputs_list, model_input_map_list = build_model_input_list(content,
      13. char_ids,
      14. batch_size,
      15. sentence_length,
      16. offset)
      17. # 加载模型
      18. model = BiLSTM_CRF(vocab_size=len(char_to_id),
      19. tag_to_ix=tag2id,
      20. embedding_dim=embedding_dim,
      21. hidden_dim=hidden_dim,
      22. batch_size=batch_size,
      23. num_layers=num_layers,
      24. sequence_length=sentence_length)
      25. # 加载模型字典
      26. model.load_state_dict(torch.load(model_path))
      27. tag_id_dict = {v: k for k, v in tag_to_id.items() if k[2:] in target_type_list}
      28. # 定义返回实体列表
      29. entities = []
      30. with torch.no_grad():
      31. for step, model_inputs in enumerate(model_inputs_list):
      32. prediction_value = model(model_inputs)
      33. # 获取每一行预测结果
      34. for line_no, line_value in enumerate(prediction_value):
      35. # 定义将要识别的实体
      36. entity = None
      37. # 获取当前行每个字的预测结果
      38. for char_idx, tag_id in enumerate(line_value):
      39. # 若预测结果 tag_id 属于目标字典数据 key 中
      40. if tag_id in tag_id_dict:
      41. # 取符合匹配字典id的第一个字符,即B, I
      42. tag_index = tag_id_dict[tag_id][0]
      43. # 计算当前字符确切的下标位置
      44. current_char = model_input_map_list[step][line_no][char_idx]
      45. # 若当前字标签起始为 B, 则设置为实体开始
      46. if tag_index == "B":
      47. entity = current_char
      48. # 若当前字标签起始为 I, 则进行字符串追加
      49. elif tag_index == "I" and entity:
      50. entity += current_char
      51. # 当实体不为空且当前标签类型为 O 时,加入实体列表
      52. if tag_id == tag_to_id["O"] and entity:
      53. # 满足当前字符为O,上一个字符为目标提取实体结尾时,将其加入实体列表
      54. entities.append(entity)
      55. # 重置实体
      56. entity = None
      57. return entities
      58. def content_to_id(content, char_to_id):
      59. # 定义字符串对应的码表 id 列表
      60. char_ids = []
      61. for char in list(content):
      62. # 判断若字符不在码表对应字典中,则取 NUK 的编码(即 unknown),否则取对应的字符编码
      63. if char_to_id.get(char):
      64. char_ids.append(char_to_id[char])
      65. else:
      66. char_ids.append(char_to_id["UNK"])
      67. return char_ids
      68. def build_model_input_list(content, char_ids, batch_size, sentence_length, offset):
      69. # 定义模型输入数据列表
      70. model_input_list = []
      71. # 定义每个批次句子 id 数据
      72. batch_sentence_list = []
      73. # 将文本内容转为列表
      74. content_list = list(content)
      75. # 定义与模型 char_id 对照的文字
      76. model_input_map_list = []
      77. # 定义每个批次句子字符数据
      78. batch_sentence_char_list = []
      79. # 判断是否需要 padding
      80. if len(char_ids) % sentence_length > 0:
      81. # 将不足 batch_size * sentence_length 的部分填充0
      82. padding_length = (batch_size * sentence_length
      83. - len(char_ids) % batch_size * sentence_length
      84. - len(char_ids) % sentence_length)
      85. char_ids.extend([0] * padding_length)
      86. content_list.extend(["#"] * padding_length)
      87. # 迭代字符 id 列表
      88. # 数据满足 batch_size * sentence_length 将加入 model_input_list
      89. for step, idx in enumerate(range(0, len(char_ids) + 1, sentence_length)):
      90. # 起始下标,从第一句开始增加 offset 个字的偏移
      91. start_idx = 0 if idx == 0 else idx - step * offset
      92. # 获取长度为 sentence_length 的字符 id 数据集
      93. sub_list = char_ids[start_idx:start_idx + sentence_length]
      94. # 获取长度为 sentence_length 的字符数据集
      95. sub_char_list = content_list[start_idx:start_idx + sentence_length]
      96. # 加入批次数据集中
      97. batch_sentence_list.append(sub_list)
      98. # 批量句子包含字符列表
      99. batch_sentence_char_list.append(sub_char_list)
      100. # 每当批次长度达到 batch_size 时候,将其加入 model_input_list
      101. if len(batch_sentence_list) == batch_size:
      102. # 将数据格式转为 tensor 格式,大小为 batch_size * sentence_length
      103. model_input_list.append(torch.tensor(batch_sentence_list))
      104. # 重置 batch_sentence_list
      105. batch_sentence_list = []
      106. # 将 char_id 对应的字符加入映射表中
      107. model_input_map_list.append(batch_sentence_char_list)
      108. # 重置批字符串内容
      109. batch_sentence_char_list = []
      110. # 返回模型输入列表
      111. return model_input_list, model_input_map_list

    • 代码实现位置: /data/doctor_offline/ner_model/predict.py

    • 输入参数:
      1. # 参数1:待识别文本
      2. content = "本病是由DNA病毒的单纯疱疹病毒所致。人类单纯疱疹病毒分为两型," \
      3. "即单纯疱疹病毒Ⅰ型(HSV-Ⅰ)和单纯疱疹病毒Ⅱ型(HSV-Ⅱ)。" \
      4. "Ⅰ型主要引起生殖器以外的皮肤黏膜(口腔黏膜)和器官(脑)的感染。" \
      5. "Ⅱ型主要引起生殖器部位皮肤黏膜感染。" \
      6. "病毒经呼吸道、口腔、生殖器黏膜以及破损皮肤进入体内," \
      7. "潜居于人体正常黏膜、血液、唾液及感觉神经节细胞内。" \
      8. "当机体抵抗力下降时,如发热胃肠功能紊乱、月经、疲劳等时," \
      9. "体内潜伏的HSV被激活而发病。"
      10. # 参数2:模型保存文件路径
      11. model_path = "model/bilstm_crf_state_dict_20200129_210417.pt"
      12. # 参数3:批次大小
      13. BATCH_SIZE = 8
      14. # 参数4:字向量维度
      15. EMBEDDING_DIM = 300
      16. # 参数5:隐层维度
      17. HIDDEN_DIM = 128
      18. # 参数6:句子长度
      19. SENTENCE_LENGTH = 100
      20. # 参数7:偏移量
      21. OFFSET = 10
      22. # 参数8:标签码表对照字典
      23. tag_to_id = {"O": 0, "B-dis": 1, "I-dis": 2, "B-sym": 3, "I-sym": 4, "<START>": 5, "<STOP>": 6}
      24. # 参数9:字符码表文件路径
      25. char_to_id_json_path = "./data/char_to_id.json"
      26. # 参数10:预测结果存储路径
      27. prediction_result_path = "prediction_result"
      28. # 参数11:待匹配标签类型
      29. target_type_list = ["sym"]

    • 调用:
      1. # 单独文本预测, 获得实体结果
      2. entities = singel_predict(model_path,
      3. content,
      4. char_to_id_json_path,
      5. BATCH_SIZE,
      6. EMBEDDING_DIM,
      7. HIDDEN_DIM,
      8. SENTENCE_LENGTH,
      9. OFFSET,
      10. target_type_list,
      11. tag_to_id)
      12. # 打印实体结果
      13. print("entities:\n", entities)

    • 输出效果:
      1. entities:
      2. ['感染', '发热', '##']

    • 批量文件夹文件预测代码实现:
      1. def batch_predict(data_path, model_path, char_to_id_json_path, batch_size, embedding_dim,
      2. hidden_dim, sentence_length, offset, target_type_list,
      3. prediction_result_path, tag_to_id):
      4. """
      5. description: 批量预测,查询文件目录下数据,
      6. 从中提取符合条件的实体并存储至新的目录下prediction_result_path
      7. :param data_path: 数据文件路径
      8. :param model_path: 模型文件路径
      9. :param char_to_id_json_path: 字符码表文件路径
      10. :param batch_size: 训练批次大小
      11. :param embedding_dim: 字向量维度
      12. :param hidden_dim: BiLSTM 隐藏层向量维度
      13. :param sentence_length: 句子长度(句子做了padding)
      14. :param offset: 设定偏移量,
      15. 当字符串超出sentence_length时, 换行时增加偏移量
      16. :param target_type_list: 待匹配类型,符合条件的实体将会被提取出来
      17. :param prediction_result_path: 预测结果保存路径
      18. :param tag_to_id: 标签码表对照字典, 标签对应 id
      19. :return: 无返回
      20. """
      21. # 迭代路径, 读取文件名
      22. for fn in os.listdir(data_path):
      23. # 拼装全路径
      24. fullpath = os.path.join(data_path, fn)
      25. # 定义输出结果文件
      26. entities_file = open(os.path.join(prediction_result_path, fn),
      27. mode="w",
      28. encoding="utf-8")
      29. with open(fullpath, mode="r", encoding="utf-8") as f:
      30. # 读取文件内容
      31. content = f.readline()
      32. # 调用单个预测模型,输出为目标类型实体文本列表
      33. entities = singel_predict(model_path, content, char_to_id_json_path,
      34. batch_size, embedding_dim, hidden_dim, sentence_length,
      35. offset, target_type_list, tag_to_id)
      36. # 写入识别结果文件
      37. entities_file.write("\n".join(entities))
      38. print("batch_predict Finished".center(100, "-"))

    • 代码实现位置: /data/doctor_offline/ner_model/predict.py

    • 输入参数:
      1. # 参数1:模型保存路径
      2. model_path = "model/bilstm_crf_state_dict_20191219_220256.pt"
      3. # 参数2:批次大小
      4. BATCH_SIZE = 8
      5. # 参数3:字向量维度
      6. EMBEDDING_DIM = 200
      7. # 参数4:隐层维度
      8. HIDDEN_DIM = 100
      9. # 参数5:句子长度
      10. SENTENCE_LENGTH = 20
      11. # 参数6:偏移量
      12. OFFSET = 10
      13. # 参数7:标签码表对照字典
      14. tag_to_id = {"O": 0, "B-dis": 1, "I-dis": 2, "B-sym": 3, "I-sym": 4, "<START>": 5, "<STOP>": 6}
      15. # 参数8:字符码表文件路径
      16. char_to_id_json_path = "./data/char_to_id.json"
      17. # 参数9:预测结果存储路径
      18. prediction_result_path = "prediction_result"
      19. # 参数10:待匹配标签类型
      20. target_type_list = ["sym"]
      21. # 参数11:待预测文本文件所在目录
      22. data_path = "origin_data"

    • 调用:
      1. # 批量文本预测, 并将结果写入文件中
      2. batch_predict(data_path,
      3. model_path,
      4. char_to_id_json_path,
      5. BATCH_SIZE,
      6. EMBEDDING_DIM,
      7. HIDDEN_DIM,
      8. SENTENCE_LENGTH,
      9. OFFSET,
      10. target_type_list,
      11. prediction_result_path,
      12. tag_to_id)

    • 输出效果: 将识别结果保存至prediction_result_path指定的目录下, 名称与源文件一致, 内容为每行存储识别实体名称

    • 小节总结:
      • 学习了模型单条文本预测代码实现
      • 学习了批量文件夹文件预测代码实现