• Model training workflow
      • Step 1: Get familiar with the character-to-id code table
      • Step 2: Get familiar with the training data format and its meaning
      • Step 3: Generate batched training data
      • Step 4: Write the evaluation code for accuracy and recall
      • Step 5: Write the model training code
      • Step 6: Plot the loss and evaluation curves

    • Step 1: Get familiar with the character-to-id code table.
      # A dict mapping every character in the dataset to a numeric id
      # The code table can contain simplified and traditional Chinese characters, upper- and lower-case English letters, digits, Chinese and English punctuation, and so on
      # <PAD> is the padding token: training turns sentences into matrices, and since sentences vary in length they need to be padded
      {
          "<PAD>": 0,
          "厑": 1,
          "吖": 2,
          "呵": 3,
          "啊": 4,
          "嗄": 5,
          "嬶": 6,
          ...
      }

    • Code table location: /data/doctor_offline/ner_model/data/char_to_id.json
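
    • A minimal sketch of loading and using the code table (the relative path matches the later scripts; treating <PAD>'s id 0 as the lookup fallback is purely illustrative):
      import json
      # Load the character-to-id mapping into a dict
      char2id = json.load(open('./data/char_to_id.json', mode='r', encoding='utf-8'))
      # Encode a short string character by character
      print([char2id.get(ch, 0) for ch in "反复咳嗽"])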

    • Step 2: Get familiar with the training data format and its meaning. Below is the tag column of one sample sentence (in the actual file each tag is preceded by its character and a \t):
      B-dis
      I-dis
      I-dis
      O
      O
      O
      O
      O
      O
      O
      O
      B-sym
      I-sym
      B-sym
      I-sym
      O

    • Meaning of the training data:
      • Each line holds one character and its corresponding tag, separated by \t
      • Sentences are separated by blank lines
      • Tag definitions:
        • B-dis: start of a disease entity
        • I-dis: middle/end of a disease entity
        • B-sym: start of a symptom entity
        • I-sym: middle/end of a symptom entity
        • O: non-entity characters

    • Dataset location: /data/doctor_offline/ner_model/data/train.txt
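
    • To make the format concrete, a small sketch that reads train.txt and groups the char/tag pairs into sentences at blank lines (the path is the one above; purely illustrative):
      # Collect (char, tag) pairs; a blank line closes the current sentence
      sentences, current = [], []
      with open('./data/train.txt', mode='r', encoding='utf-8') as f:
          for line in f:
              line = line.strip()
              if line:
                  ch, tag = line.split('\t')
                  current.append((ch, tag))
              elif current:
                  sentences.append(current)
                  current = []
      print(sentences[0])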

    • Convert the training dataset into a numerically encoded set:
      # Import packages
      import json
      import numpy as np
      # Build the training dataset: numerically encode the Chinese characters
      # of the raw training file, and encode the tags as well
      def create_train_data(train_data_file, result_file, json_file, tag2id, max_length=20):
          # Load the JSON mapping from Chinese characters to ids
          char2id = json.load(open(json_file, mode='r', encoding='utf-8'))
          char_data, tag_data = [], []
          # Open the raw training file
          with open(train_data_file, mode='r', encoding='utf-8') as f:
              # Initialize the id-encoded lists for one sentence
              char_ids = [0] * max_length
              tag_ids = [0] * max_length
              idx = 0
              for line in f.readlines():
                  line = line.strip('\n').strip()
                  # If the line is not empty and the current sentence has not
                  # exceeded max_length, map the character to its id
                  if len(line) > 0 and line and idx < max_length:
                      ch, tag = line.split('\t')
                      # If the character exists in the mapping table, map it
                      # directly to its id
                      if char2id.get(ch):
                          char_ids[idx] = char2id[ch]
                      # Otherwise substitute the id of "UNK" for the unknown character
                      else:
                          char_ids[idx] = char2id['UNK']
                      # Convert the tag as well
                      tag_ids[idx] = tag2id[tag]
                      idx += 1
                  # Empty line, or the current sentence exceeds max_length
                  else:
                      # If the sentence exceeds max_length, just keep the
                      # [0: max_length] slice as the result
                      if idx <= max_length:
                          char_data.append(char_ids)
                          tag_data.append(tag_ids)
                      # An empty line means the current sentence has ended; reset
                      # the buffers to prepare for mapping the next sentence
                      char_ids = [0] * max_length
                      tag_ids = [0] * max_length
                      idx = 0
          # Wrap the encoded data as numpy arrays, using np.int32 for the ids
          x_data = np.array(char_data, dtype=np.int32)
          y_data = np.array(tag_data, dtype=np.int32)
          # Store the data as a .npz file via np.savez()
          np.savez(result_file, x_data=x_data, y_data=y_data)
          print("create_train_data Finished!".center(100, "-"))

    • Code location: /data/doctor_offline/ner_model/preprocess_data.py

    • Input parameters:
      # Parameter 1: path of the character code table file
      json_file = './data/char_to_id.json'
      # Parameter 2: tag-to-id mapping dict
      tag2id = {"O": 0, "B-dis": 1, "I-dis": 2, "B-sym": 3, "I-sym": 4, "<START>": 5, "<STOP>": 6}
      # Parameter 3: path of the training data file
      train_data_file = './data/train.txt'
      # Parameter 4: save path of the generated npz file (training data)
      result_file = './data/train.npz'

    • Invocation:
      if __name__ == '__main__':
          create_train_data(train_data_file, result_file, json_file, tag2id)

    • Output:

      ------------------------------------create_train_data Finished!-------------------------------------
    • A new dataset file has been generated: /data/doctor_offline/ner_model/data/train.npz
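
    • A quick sanity check on the generated file (a sketch; both arrays should come back with shape (num_sentences, max_length), i.e. (*, 20) under the defaults above):
      import numpy as np
      data = np.load('./data/train.npz')
      print(data['x_data'].shape, data['y_data'].shape)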


    • Step 3: Generate batched training data.
      # Import the relevant packages
      import numpy as np
      import torch
      import torch.utils.data as Data
      # Generate batched training data
      def load_dataset(data_file, batch_size):
          # Load the train.npz file produced in step 2 into memory
          data = np.load(data_file)
          # Extract the features and labels separately
          x_data = data['x_data']
          y_data = data['y_data']
          # Wrap the data as tensors
          x = torch.tensor(x_data, dtype=torch.long)
          y = torch.tensor(y_data, dtype=torch.long)
          # Wrap the tensors into a TensorDataset
          dataset = Data.TensorDataset(x, y)
          total_length = len(dataset)
          # Use 80% of the data as the training set and 20% as the validation set
          train_length = int(total_length * 0.8)
          validation_length = total_length - train_length
          # Split the dataset 80/20 directly with Data.random_split()
          train_dataset, validation_dataset = Data.random_split(dataset=dataset,
                                                                lengths=[train_length, validation_length])
          # Wrap the training set in a DataLoader
          # Parameters:
          # dataset: the training dataset
          # batch_size: the batch size; if the total number of samples is not
          #             divisible by batch_size, the last batch holds the remainder;
          #             with drop_last=True that final partial batch is discarded
          # shuffle: whether batches are drawn randomly; if True the data is
          #          reshuffled at every iteration
          # num_workers: number of subprocesses used for data loading; the
          #              default 0 loads data in the main process
          # drop_last: whether to drop the final partial batch; if True the
          #            non-divisible remainder is not produced
          # Example: with a dataset of length 1028 and batch_size 8,
          # drop_last=True discards the remaining 4 samples (1028 / 8 = 128 remainder 4)
          train_loader = Data.DataLoader(dataset=train_dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=4, drop_last=True)
          validation_loader = Data.DataLoader(dataset=validation_dataset, batch_size=batch_size,
                                              shuffle=True, num_workers=4, drop_last=True)
          # Pack the two data loaders into one dict
          data_loaders = {'train': train_loader, 'validation': validation_loader}
          # Pack the two dataset lengths into a dict as well
          data_size = {'train': train_length, 'validation': validation_length}
          return data_loaders, data_size

    • Code location: /data/doctor_offline/ner_model/loader_data.py

    • Input parameters:
      # Batch size
      BATCH_SIZE = 8
      # Path of the encoded training data file
      DATA_FILE = './data/train.npz'

    • Invocation:
      if __name__ == '__main__':
          data_loader, data_size = load_dataset(DATA_FILE, BATCH_SIZE)
          print('data_loader:', data_loader, '\ndata_size:', data_size)

    • Output:
      data_loader: {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f29eaafb3d0>, 'validation': <torch.utils.data.dataloader.DataLoader object at 0x7f29eaafb5d0>}
      data_size: {'train': 5368, 'validation': 1343}
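
    • Pulling one batch from the loader shows the tensor shapes the model will consume (a sketch; each batch is a pair of LongTensors of shape (BATCH_SIZE, SENTENCE_LENGTH), i.e. (8, 20) here):
      for inputs, labels in data_loader['train']:
          print(inputs.shape, labels.shape)
          break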

    • Step 4: Write the evaluation code for accuracy and recall.
      # Evaluate the model's accuracy, recall, F1 and related metrics
      def evaluate(sentence_list, true_tag, predict_tag, id2char, id2tag):
          '''
          sentence_list: list of id-encoded sentence vectors
          true_tag: the gold tags
          predict_tag: the tags predicted by the model
          id2char: mapping from id to Chinese character
          id2tag: mapping from id to tag
          '''
          # Initialize the gold and the predicted named entities; the metrics
          # come from comparing the two
          true_entities, true_entity = [], []
          predict_entities, predict_entity = [], []
          # Iterate over every sentence in the batch
          for line_num, sentence in enumerate(sentence_list):
              # Iterate over every character code of the sentence (numeric ids)
              for char_num in range(len(sentence)):
                  # Code 0 means everything after it is padding, so end the loop
                  if sentence[char_num] == 0:
                      break
                  # Fetch the character, the gold tag and the predicted tag
                  char_text = id2char[sentence[char_num]]
                  true_tag_type = id2tag[true_tag[line_num][char_num]]
                  predict_tag_type = id2tag[predict_tag[line_num][char_num]]
                  # Match named entities in the gold tags
                  # A leading "B" marks the start of an entity: store it as "char/tag"
                  if true_tag_type[0] == "B":
                      true_entity = [char_text + "/" + true_tag_type]
                  # A leading "I" means the middle of an entity
                  # If the gold entity list is non-empty and the last appended tag type
                  # matches the current one, keep appending
                  # e.g. with true_entity = ["中/B-Person", "国/I-Person"], the pair
                  # "人/I-Person" can be appended because it belongs to the same entity
                  elif true_tag_type[0] == "I" and len(true_entity) != 0 and true_entity[-1].split("/")[1][1:] == true_tag_type[1:]:
                      true_entity.append(char_text + "/" + true_tag_type)
                  # A leading "O" with a non-empty true_entity means an entity just ended
                  elif true_tag_type[0] == "O" and len(true_entity) != 0:
                      # Append "lineNum_colNum" as a marker that distinguishes entities
                      true_entity.append(str(line_num) + "_" + str(char_num))
                      # Add the matched entity to the result list
                      true_entities.append(true_entity)
                      # Clear true_entity to prepare for the next entity
                      true_entity = []
                  # In every other case no entity is being matched: clear true_entity and continue
                  else:
                      true_entity = []
                  # Match named entities in the predicted tags
                  # A leading "B" marks the start of an entity: store it as "char/predicted tag"
                  if predict_tag_type[0] == "B":
                      predict_entity = [char_text + "/" + predict_tag_type]
                  # A leading "I" means the middle of an entity
                  # If the predicted entity list is non-empty and the last appended tag type
                  # matches the current one, keep appending, because the characters
                  # belong to the same entity
                  elif predict_tag_type[0] == "I" and len(predict_entity) != 0 and predict_entity[-1].split("/")[1][1:] == predict_tag_type[1:]:
                      predict_entity.append(char_text + "/" + predict_tag_type)
                  # A leading "O" with a non-empty predict_entity means an entity just ended
                  elif predict_tag_type[0] == "O" and len(predict_entity) != 0:
                      # Append "lineNum_colNum" as a marker that distinguishes entities
                      predict_entity.append(str(line_num) + "_" + str(char_num))
                      # Add the matched entity to the result list
                      predict_entities.append(predict_entity)
                      # Clear predict_entity to prepare for the next entity
                      predict_entity = []
                  # In every other case no entity is being matched: clear predict_entity and continue
                  else:
                      predict_entity = []
          # Only predicted entities that also appear among the gold entities are correct
          acc_entities = [entity for entity in predict_entities if entity in true_entities]
          # Count the correct entities, all predicted entities and all gold entities
          acc_entities_length = len(acc_entities)
          predict_entities_length = len(predict_entities)
          true_entities_length = len(true_entities)
          # Compute the three metrics only if at least one prediction is correct
          if acc_entities_length > 0:
              accuracy = float(acc_entities_length / predict_entities_length)
              recall = float(acc_entities_length / true_entities_length)
              f1_score = 2 * accuracy * recall / (accuracy + recall)
              return accuracy, recall, f1_score, acc_entities_length, predict_entities_length, true_entities_length
          else:
              return 0, 0, 0, acc_entities_length, predict_entities_length, true_entities_length

    • Code location: /data/doctor_offline/ner_model/evaluate_model.py

    • Input parameters:
      # Gold tag data
      tag_list = [
          [0, 0, 3, 4, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0],
          [0, 0, 3, 4, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0],
          [0, 0, 3, 4, 0, 3, 4, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [3, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0],
          [0, 0, 1, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [3, 4, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 3, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
      ]
      # Predicted tag data
      predict_tag_list = [
          [0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0],
          [0, 0, 3, 4, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0],
          [0, 0, 3, 4, 0, 3, 4, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [3, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 0, 0, 0, 0],
          [0, 0, 1, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0],
          [3, 4, 0, 3, 4, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 3, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
      ]
      # id-to-character mapping dict
      id2char = {0: '<PAD>', 1: '确', 2: '诊', 3: '弥', 4: '漫', 5: '大', 6: 'b', 7: '细', 8: '胞', 9: '淋', 10: '巴', 11: '瘤', 12: '1', 13: '年', 14: '反', 15: '复', 16: '咳', 17: '嗽', 18: '、', 19: '痰', 20: '4', 21: '0', 22: ',', 23: '再', 24: '发', 25: '伴', 26: '气', 27: '促', 28: '5', 29: '天', 30: '。', 31: '生', 32: '长', 33: '育', 34: '迟', 35: '缓', 36: '9', 37: '右', 38: '侧', 39: '小', 40: '肺', 41: '癌', 42: '第', 43: '三', 44: '次', 45: '化', 46: '疗', 47: '入', 48: '院', 49: '心', 50: '悸', 51: '加', 52: '重', 53: '胸', 54: '痛', 55: '3', 56: '闷', 57: '2', 58: '多', 59: '月', 60: '余', 61: ' ', 62: '周', 63: '上', 64: '肢', 65: '无', 66: '力', 67: '肌', 68: '肉', 69: '萎', 70: '缩', 71: '半'}
      # id-to-tag mapping dict
      id2tag = {0: 'O', 1: 'B-dis', 2: 'I-dis', 3: 'B-sym', 4: 'I-sym'}
      # The numeric input sentences_sequence is obtained by mapping the
      # sentence_list below through the sentence_map() function
      sentence_list = [
          "确诊弥漫大b细胞淋巴瘤1年",
          "反复咳嗽、咳痰40年,再发伴气促5天。",
          "生长发育迟缓9年。",
          "右侧小细胞肺癌第三次化疗入院",
          "反复气促、心悸10年,加重伴胸痛3天。",
          "反复胸闷、心悸、气促2多月,加重3天",
          "咳嗽、胸闷1月余, 加重1周",
          "右上肢无力3年, 加重伴肌肉萎缩半年"
      ]

    • Invocation:
      # torch is needed for torch.tensor below
      import torch

      def sentence_map(sentence_list, char_to_id, max_length):
          sentence_list.sort(key=lambda c: len(c), reverse=True)
          sentence_map_list = []
          for sentence in sentence_list:
              sentence_id_list = [char_to_id[c] for c in sentence]
              padding_list = [0] * (max_length - len(sentence))
              sentence_id_list.extend(padding_list)
              sentence_map_list.append(sentence_id_list)
          return torch.tensor(sentence_map_list, dtype=torch.long)

      char_to_id = {"<PAD>": 0}
      SENTENCE_LENGTH = 20
      for sentence in sentence_list:
          for _char in sentence:
              if _char not in char_to_id:
                  char_to_id[_char] = len(char_to_id)

      sentences_sequence = sentence_map(sentence_list, char_to_id, SENTENCE_LENGTH)

      if __name__ == '__main__':
          accuracy, recall, f1_score, acc_entities_length, predict_entities_length, true_entities_length = evaluate(sentences_sequence.tolist(), tag_list, predict_tag_list, id2char, id2tag)
          print("accuracy:", accuracy,
                "\nrecall:", recall,
                "\nf1_score:", f1_score,
                "\nacc_entities_length:", acc_entities_length,
                "\npredict_entities_length:", predict_entities_length,
                "\ntrue_entities_length:", true_entities_length)

    • Output:
      accuracy: 0.8823529411764706
      recall: 0.9375
      f1_score: 0.9090909090909091
      acc_entities_length: 15
      predict_entities_length: 17
      true_entities_length: 16
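
    • The three scores follow directly from the entity counts above; a quick check (the "accuracy" computed here is effectively precision):
      acc_len, pred_len, true_len = 15, 17, 16
      accuracy = acc_len / pred_len                     # 0.8823529411764706
      recall = acc_len / true_len                       # 0.9375
      f1 = 2 * accuracy * recall / (accuracy + recall)  # 0.9090909090909091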

    • Step 5: Write the model training code.
      # Import packages
      import json
      import time
      from tqdm import tqdm
      import matplotlib.pyplot as plt
      import torch
      import torch.optim as optim
      from torch.autograd import Variable
      # Import the modules written earlier: the model class, dataset loading and the evaluation function
      from bilstm_crf import BiLSTM_CRF
      from loader_data import load_dataset
      from evaluate_model import evaluate
      # Model training function
      def train(data_loader, data_size, batch_size, embedding_dim, hidden_dim,
              sentence_length, num_layers, epochs, learning_rate, tag2id,
              model_saved_path, train_log_path,
              validate_log_path, train_history_image_path):
        '''
        data_loader: dataset loader, already built via load_dataset
        data_size:   sample counts of the training and validation sets
        batch_size:  number of samples per batch
        embedding_dim:  dimension of the character embeddings
        hidden_dim:     dimension of the hidden layer
        sentence_length:  maximum text length
        num_layers:       number of stacked LSTM layers
        epochs:           number of training epochs
        learning_rate:    learning rate
        tag2id:           tag-to-id mapping dict
        model_saved_path: path where the model is saved
        train_log_path:   path of the training log
        validate_log_path:  path of the validation log
        train_history_image_path:  path where the training history plots are saved
        '''
        # Load the character-to-id code table into memory
        char2id = json.load(open("./data/char_to_id.json", mode="r", encoding="utf-8"))
        # Initialize the BiLSTM_CRF model
        model = BiLSTM_CRF(vocab_size=len(char2id), tag_to_ix=tag2id,
                       embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                       batch_size=batch_size, num_layers=num_layers,
                       sequence_length=sentence_length)
        # Define the optimizer; SGD is used here (the optimizers that support
        # Embedding's sparse gradients in PyTorch include SGD and SparseAdam)
        # Parameters:
        # lr:          optimizer learning rate
        # momentum:    momentum factor that accelerates gradient descent
        optimizer = optim.SGD(params=model.parameters(), lr=learning_rate, momentum=0.85)
        # Set the learning-rate update policy for the optimizer
        # Parameters:
        # optimizer:    the optimizer
        # step_size:    update frequency, i.e. decay the learning rate every step_size epochs
        # gamma:        decay factor: the ratio by which the learning rate is
        #               reduced (relative to the previous stage), default 0.1
        #   Example:
        #   initial learning rate lr = 0.5,    step_size = 20,    gamma = 0.1
        #              lr = 0.5     if epoch < 20
        #              lr = 0.05    if 20 <= epoch < 40
        #              lr = 0.005   if 40 <= epoch < 60
        scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.2)
        # Initialize lists holding the training loss, accuracy, recall and F1 values
        train_loss_list = []
        train_acc_list = []
        train_recall_list = []
        train_f1_list = []
        train_log_file = open(train_log_path, mode="w", encoding="utf-8")
        # Initialize lists holding the validation loss, accuracy, recall and F1 values
        validate_loss_list = []
        validate_acc_list = []
        validate_recall_list = []
        validate_f1_list = []
        validate_log_file = open(validate_log_path, mode="w", encoding="utf-8")
        # Build the id-to-tag mapping from tag2id
        id2tag = {v:k for k, v in tag2id.items()}
        # Build the id-to-character mapping from char2id
        id2char = {v:k for k, v in char2id.items()}
        # Loop for the configured number of epochs
        for epoch in range(epochs):
            # Print the current epoch before the progress bar
            tqdm.write("Epoch {}/{}".format(epoch + 1, epochs))
            # Counters for correct entities, predicted entities and gold entities
            total_acc_entities_length, \
            total_predict_entities_length, \
            total_gold_entities_length = 0, 0, 0
            # Per-batch step counter, accumulated loss, accuracy and f1
            step, total_loss, correct, f1 = 1, 0.0, 0, 0
            # Training part of the current epoch
            for inputs, labels in tqdm(data_loader["train"]):
                # Wrap the data in Variable
                inputs, labels = Variable(inputs), Variable(labels)
                # During training, zero the gradients before computing them for
                # each batch, otherwise the gradients accumulate
                optimizer.zero_grad()
                # This calls neg_log_likelihood() of the BiLSTM_CRF class
                loss = model.neg_log_likelihood(inputs, labels)
                # Get the loss of the current step, converting the tensor to a number
                step_loss = loss.data
                # Accumulate the per-step loss
                total_loss += step_loss
                # Get the best decoded path; this calls forward() of the BiLSTM_CRF class
                best_path_list = model(inputs)
                # Fetch the evaluation metrics of the current batch: accuracy,
                # recall, F1 and the corresponding entity counts
                step_acc, step_recall, f1_score, acc_entities_length, \
                predict_entities_length, gold_entities_length = evaluate(inputs.tolist(),
                                                                         labels.tolist(),
                                                                         best_path_list,
                                                                         id2char,
                                                                         id2tag)
                # Training log line
                log_text = "Epoch: %s | Step: %s " \
                           "| loss: %.5f " \
                           "| acc: %.5f " \
                           "| recall: %.5f " \
                           "| f1 score: %.5f" % \
                           (epoch, step, step_loss, step_acc, step_recall, f1_score)
                # Accumulate the correct, predicted and gold entity counts
                total_acc_entities_length += acc_entities_length
                total_predict_entities_length += predict_entities_length
                total_gold_entities_length += gold_entities_length
                # Backpropagate the loss
                loss.backward()
                # Update the parameters via optimizer.step()
                optimizer.step()
                # Write the training log
                train_log_file.write(log_text + "\n")
                step += 1
            # Mean loss of the current epoch (total loss divided by the number of samples)
            epoch_loss = total_loss / data_size["train"]
            # Accuracy of the current epoch
            total_acc = total_acc_entities_length / total_predict_entities_length
            # Recall of the current epoch
            total_recall = total_acc_entities_length / total_gold_entities_length
            # F1 of the current epoch
            total_f1 = 0
            if total_acc + total_recall != 0:
                total_f1 = 2 * total_acc * total_recall / (total_acc + total_recall)
            log_text = "Epoch: %s " \
                       "| mean loss: %.5f " \
                       "| total acc: %.5f " \
                       "| total recall: %.5f " \
                       "| total f1 score: %.5f" % (epoch, epoch_loss,
                                                   total_acc,
                                                   total_recall,
                                                   total_f1)
            # Update the learning rate after the epoch's training; this must
            # happen after the optimizer step
            scheduler.step()
            # Record the epoch's training loss (for plotting), accuracy, recall and f1
            train_loss_list.append(epoch_loss)
            train_acc_list.append(total_acc)
            train_recall_list.append(total_recall)
            train_f1_list.append(total_f1)
            train_log_file.write(log_text + "\n")
            # Reset the counters for correct entities, predicted entities and gold entities
            total_acc_entities_length, \
            total_predict_entities_length, \
            total_gold_entities_length = 0, 0, 0
            # Per-batch step counter, accumulated loss, accuracy and f1
            step, total_loss, correct, f1 = 1, 0.0, 0, 0
            # Validation part of the current epoch
            for inputs, labels in tqdm(data_loader["validation"]):
                # Wrap the data in Variable
                inputs, labels = Variable(inputs), Variable(labels)
                # This calls neg_log_likelihood() of the BiLSTM_CRF class,
                # returning the final CRF log-likelihood result
                loss = model.neg_log_likelihood(inputs, labels)
                # Get the loss of the current step, converting the tensor to a number
                step_loss = loss.data
                # Accumulate the per-step loss
                total_loss += step_loss
                # Get the best decoded path; this calls forward() of the BiLSTM_CRF class
                best_path_list = model(inputs)
                # Fetch the evaluation metrics of the current batch: accuracy,
                # recall, F1 and the corresponding entity counts
                step_acc, step_recall, f1_score, acc_entities_length, \
                predict_entities_length, gold_entities_length = evaluate(inputs.tolist(),
                                                                         labels.tolist(),
                                                                         best_path_list,
                                                                         id2char,
                                                                         id2tag)
                # Validation log line
                log_text = "Epoch: %s | Step: %s " \
                           "| loss: %.5f " \
                           "| acc: %.5f " \
                           "| recall: %.5f " \
                           "| f1 score: %.5f" % \
                           (epoch, step, step_loss, step_acc, step_recall, f1_score)
                # Accumulate the correct, predicted and gold entity counts
                total_acc_entities_length += acc_entities_length
                total_predict_entities_length += predict_entities_length
                total_gold_entities_length += gold_entities_length
                # Write the validation log
                validate_log_file.write(log_text + "\n")
                step += 1
            # Mean validation loss of the current epoch (total loss divided by the number of samples)
            epoch_loss = total_loss / data_size["validation"]
            # Overall validation accuracy
            total_acc = total_acc_entities_length / total_predict_entities_length
            # Overall validation recall
            total_recall = total_acc_entities_length / total_gold_entities_length
            # Overall validation F1
            total_f1 = 0
            if total_acc + total_recall != 0:
                total_f1 = 2 * total_acc * total_recall / (total_acc + total_recall)
            log_text = "Epoch: %s " \
                       "| mean loss: %.5f " \
                       "| total acc: %.5f " \
                       "| total recall: %.5f " \
                       "| total f1 score: %.5f" % (epoch, epoch_loss,
                                                   total_acc,
                                                   total_recall,
                                                   total_f1)
            # Record the epoch's validation loss (for plotting), accuracy, recall and f1
            validate_loss_list.append(epoch_loss)
            validate_acc_list.append(total_acc)
            validate_recall_list.append(total_recall)
            validate_f1_list.append(total_f1)
            validate_log_file.write(log_text + "\n")
        # Save the model
        torch.save(model.state_dict(), model_saved_path)
        # Save the loss history as an image
        save_train_history_image(train_loss_list,
                                 validate_loss_list,
                                 train_history_image_path,
                                 "Loss")
        # Save the accuracy history as an image
        save_train_history_image(train_acc_list,
                                 validate_acc_list,
                                 train_history_image_path,
                                 "Acc")
        # Save the recall history as an image
        save_train_history_image(train_recall_list,
                                 validate_recall_list,
                                 train_history_image_path,
                                 "Recall")
        # Save the F1 history as an image
        save_train_history_image(train_f1_list,
                                 validate_f1_list,
                                 train_history_image_path,
                                 "F1")
        print("train Finished".center(100, "-"))

      # Plot the different training curves according to the given path and data type
      def save_train_history_image(train_history_list,
                                 validate_history_list,
                                 history_image_path,
                                 data_type):
        # Plot the training metric history as a line chart
        plt.plot(train_history_list, label="Train %s History" % (data_type))
        # Plot the validation metric history as a line chart
        plt.plot(validate_history_list, label="Validate %s History" % (data_type))
        # Place the legend at the best position
        plt.legend(loc="best")
        # Label the x axis "Epochs"
        plt.xlabel("Epochs")
        # Label the y axis with data_type
        plt.ylabel(data_type)
        # Save the figure under the given path, replacing "plot" in the file
        # name with the corresponding data_type
        plt.savefig(history_image_path.replace("plot", data_type))
        plt.close()

    • Code location: /data/doctor_offline/ner_model/train.py

    • 输入参数:
      # 参数1:批次大小
      BATCH_SIZE = 8
      # 参数2:训练数据文件路径
      train_data_file_path = "data/train.npz"
      # 参数3:加载 DataLoader 数据
      data_loader, data_size = load_dataset(train_data_file_path, BATCH_SIZE)
      # 参数4:记录当前训练时间(拼成字符串用)
      time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime(time.time()))
      # 参数5:标签码表对照
      tag_to_id = {"O": 0, "B-dis": 1, "I-dis": 2, "B-sym": 3, "I-sym": 4, "<START>": 5, "<STOP>": 6}
      # 参数6:训练文件存放路径
      model_saved_path = "model/bilstm_crf_state_dict_%s.pt" % (time_str)
      # 参数7:训练日志文件存放路径
      train_log_path = "log/train_%s.log" % (time_str)
      # 参数8:验证打印日志存放路径
      validate_log_path = "log/validate_%s.log" % (time_str)
      # 参数9:训练历史记录图存放路径
      train_history_image_path = "log/bilstm_crf_train_plot_%s.png" % (time_str)
      # 参数10:字向量维度
      EMBEDDING_DIM = 200
      # 参数11:隐层维度
      HIDDEN_DIM = 100
      # 参数12:句子长度
      SENTENCE_LENGTH = 20
      # 参数13:堆叠 LSTM 层数
      NUM_LAYERS = 1
      # 参数14:训练批次
      EPOCHS = 100
      # 参数15:初始化学习率
      LEARNING_RATE = 0.5
      

    • Invocation:
      if __name__ == '__main__':
          train(data_loader, data_size, BATCH_SIZE, EMBEDDING_DIM, HIDDEN_DIM, SENTENCE_LENGTH,
                NUM_LAYERS, EPOCHS, LEARNING_RATE, tag_to_id,
                model_saved_path, train_log_path, validate_log_path, train_history_image_path)

    • Output:
      • Trained model file: model/bilstm_crf_state_dict_[timestamp].pt
      • Training log file: log/train_[timestamp].log
      • Validation log file: log/validate_[timestamp].log
      • Training loss history plot: log/bilstm_crf_train_Loss_[timestamp].png
      • Training accuracy history plot: log/bilstm_crf_train_Acc_[timestamp].png
      • Training recall history plot: log/bilstm_crf_train_Recall_[timestamp].png
      • Training F1 history plot: log/bilstm_crf_train_F1_[timestamp].png
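
    • A minimal sketch of restoring the saved weights later for prediction (the constructor arguments must match the ones used during training; the timestamped file name below is illustrative):
      char2id = json.load(open("./data/char_to_id.json", mode="r", encoding="utf-8"))
      model = BiLSTM_CRF(vocab_size=len(char2id), tag_to_ix=tag_to_id,
                         embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM,
                         batch_size=BATCH_SIZE, num_layers=NUM_LAYERS,
                         sequence_length=SENTENCE_LENGTH)
      model.load_state_dict(torch.load("model/bilstm_crf_state_dict_20200620_210209.pt"))
      model.eval()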

    • Training log:
      Epoch: 0 | train loss: 366.58832 |acc: 0.632 |recall: 0.503 |f1 score: 0.56 | validate loss: 666.032 |acc: 0.591 |recall: 0.457 |f1 score: 0.515
      Epoch: 1 | train loss: 123.87159 |acc: 0.743 |recall: 0.687 |f1 score: 0.714 | validate loss: 185.021 |acc: 0.669 |recall: 0.606 |f1 score: 0.636
      Epoch: 2 | train loss: 113.04003 |acc: 0.738 |recall: 0.706 |f1 score: 0.722 | validate loss: 107.393 |acc: 0.711 |recall: 0.663 |f1 score: 0.686
      Epoch: 3 | train loss: 119.14317 |acc: 0.751 |recall: 0.692 |f1 score: 0.721 | validate loss: 158.381 |acc: 0.713 |recall: 0.64 |f1 score: 0.674
      Epoch: 4 | train loss: 105.81506 |acc: 0.741 |recall: 0.699 |f1 score: 0.72 | validate loss: 118.99 |acc: 0.669 |recall: 0.624 |f1 score: 0.646
      Epoch: 5 | train loss: 86.67545 |acc: 0.773 |recall: 0.751 |f1 score: 0.762 | validate loss: 123.636 |acc: 0.64 |recall: 0.718 |f1 score: 0.676
      Epoch: 6 | train loss: 79.66924 |acc: 0.808 |recall: 0.772 |f1 score: 0.789 | validate loss: 89.771 |acc: 0.735 |recall: 0.714 |f1 score: 0.724
      Epoch: 7 | train loss: 85.35771 |acc: 0.766 |recall: 0.752 |f1 score: 0.759 | validate loss: 141.233 |acc: 0.675 |recall: 0.7 |f1 score: 0.687
      Epoch: 8 | train loss: 82.38535 |acc: 0.787 |recall: 0.748 |f1 score: 0.767 | validate loss: 108.429 |acc: 0.717 |recall: 0.673 |f1 score: 0.694
      Epoch: 9 | train loss: 82.46296 |acc: 0.783 |recall: 0.751 |f1 score: 0.767 | validate loss: 74.716 |acc: 0.692 |recall: 0.702 |f1 score: 0.697
      Epoch: 10 | train loss: 75.12292 |acc: 0.814 |recall: 0.779 |f1 score: 0.796 | validate loss: 90.693 |acc: 0.672 |recall: 0.7 |f1 score: 0.686
      Epoch: 11 | train loss: 74.89426 |acc: 0.813 |recall: 0.77 |f1 score: 0.791 | validate loss: 77.161 |acc: 0.729 |recall: 0.718 |f1 score: 0.724
      Epoch: 12 | train loss: 76.39055 |acc: 0.814 |recall: 0.785 |f1 score: 0.799 | validate loss: 132.545 |acc: 0.806 |recall: 0.685 |f1 score: 0.74
      Epoch: 13 | train loss: 75.01093 |acc: 0.814 |recall: 0.787 |f1 score: 0.8 | validate loss: 101.596 |acc: 0.765 |recall: 0.681 |f1 score: 0.721
      Epoch: 14 | train loss: 74.35796 |acc: 0.83 |recall: 0.802 |f1 score: 0.816 | validate loss: 92.535 |acc: 0.745 |recall: 0.777 |f1 score: 0.761
      Epoch: 15 | train loss: 73.27102 |acc: 0.818 |recall: 0.791 |f1 score: 0.804 | validate loss: 109.51 |acc: 0.68 |recall: 0.76 |f1 score: 0.717
      Epoch: 16 | train loss: 67.66725 |acc: 0.841 |recall: 0.811 |f1 score: 0.826 | validate loss: 93.047 |acc: 0.768 |recall: 0.738 |f1 score: 0.753
      Epoch: 17 | train loss: 63.75809 |acc: 0.83 |recall: 0.813 |f1 score: 0.822 | validate loss: 76.231 |acc: 0.784 |recall: 0.776 |f1 score: 0.78
      Epoch: 18 | train loss: 60.30417 |acc: 0.845 |recall: 0.829 |f1 score: 0.837 | validate loss: 76.019 |acc: 0.806 |recall: 0.758 |f1 score: 0.781
      Epoch: 19 | train loss: 60.30238 |acc: 0.849 |recall: 0.823 |f1 score: 0.836 | validate loss: 90.269 |acc: 0.748 |recall: 0.733 |f1 score: 0.741
      Epoch: 20 | train loss: 60.20072 |acc: 0.847 |recall: 0.82 |f1 score: 0.833 | validate loss: 61.756 |acc: 0.81 |recall: 0.77 |f1 score: 0.79
      Epoch: 21 | train loss: 58.98606 |acc: 0.844 |recall: 0.82 |f1 score: 0.832 | validate loss: 60.799 |acc: 0.765 |recall: 0.754 |f1 score: 0.759
      Epoch: 22 | train loss: 60.23671 |acc: 0.848 |recall: 0.828 |f1 score: 0.838 | validate loss: 65.676 |acc: 0.787 |recall: 0.781 |f1 score: 0.784
      Epoch: 23 | train loss: 58.57862 |acc: 0.849 |recall: 0.827 |f1 score: 0.838 | validate loss: 65.975 |acc: 0.794 |recall: 0.754 |f1 score: 0.774
      Epoch: 24 | train loss: 58.93968 |acc: 0.848 |recall: 0.827 |f1 score: 0.838 | validate loss: 66.994 |acc: 0.784 |recall: 0.746 |f1 score: 0.764
      Epoch: 25 | train loss: 59.91834 |acc: 0.862 |recall: 0.828 |f1 score: 0.845 | validate loss: 68.794 |acc: 0.795 |recall: 0.756 |f1 score: 0.775
      Epoch: 26 | train loss: 59.09166 |acc: 0.84 |recall: 0.823 |f1 score: 0.831 | validate loss: 68.508 |acc: 0.746 |recall: 0.758 |f1 score: 0.752
      Epoch: 27 | train loss: 58.0584 |acc: 0.856 |recall: 0.84 |f1 score: 0.848 | validate loss: 53.158 |acc: 0.802 |recall: 0.774 |f1 score: 0.788
      Epoch: 28 | train loss: 54.2857 |acc: 0.858 |recall: 0.834 |f1 score: 0.845 | validate loss: 60.243 |acc: 0.816 |recall: 0.772 |f1 score: 0.793
      Epoch: 29 | train loss: 56.44759 |acc: 0.845 |recall: 0.838 |f1 score: 0.841 | validate loss: 56.497 |acc: 0.768 |recall: 0.77 |f1 score: 0.769
      Epoch: 30 | train loss: 57.90492 |acc: 0.868 |recall: 0.832 |f1 score: 0.85 | validate loss: 75.158 |acc: 0.773 |recall: 0.762 |f1 score: 0.768
      Epoch: 31 | train loss: 56.81468 |acc: 0.861 |recall: 0.835 |f1 score: 0.847 | validate loss: 56.742 |acc: 0.796 |recall: 0.784 |f1 score: 0.79
      Epoch: 32 | train loss: 54.72623 |acc: 0.86 |recall: 0.844 |f1 score: 0.852 | validate loss: 63.175 |acc: 0.757 |recall: 0.78 |f1 score: 0.768
      Epoch: 33 | train loss: 60.10299 |acc: 0.846 |recall: 0.813 |f1 score: 0.829 | validate loss: 68.994 |acc: 0.768 |recall: 0.724 |f1 score: 0.745
      Epoch: 34 | train loss: 59.67491 |acc: 0.849 |recall: 0.826 |f1 score: 0.837 | validate loss: 58.662 |acc: 0.8 |recall: 0.739 |f1 score: 0.769
      Epoch: 35 | train loss: 65.01099 |acc: 0.857 |recall: 0.83 |f1 score: 0.844 | validate loss: 69.299 |acc: 0.772 |recall: 0.752 |f1 score: 0.762
      Epoch: 36 | train loss: 61.52783 |acc: 0.856 |recall: 0.828 |f1 score: 0.842 | validate loss: 82.373 |acc: 0.761 |recall: 0.777 |f1 score: 0.769
      Epoch: 37 | train loss: 66.19576 |acc: 0.844 |recall: 0.822 |f1 score: 0.833 | validate loss: 79.853 |acc: 0.791 |recall: 0.77 |f1 score: 0.781
      Epoch: 38 | train loss: 60.32529 |acc: 0.841 |recall: 0.828 |f1 score: 0.835 | validate loss: 69.346 |acc: 0.773 |recall: 0.755 |f1 score: 0.764
      Epoch: 39 | train loss: 63.8836 |acc: 0.837 |recall: 0.819 |f1 score: 0.828 | validate loss: 74.759 |acc: 0.732 |recall: 0.759 |f1 score: 0.745
      Epoch: 40 | train loss: 67.28363 |acc: 0.838 |recall: 0.824 |f1 score: 0.831 | validate loss: 63.027 |acc: 0.768 |recall: 0.764 |f1 score: 0.766
      Epoch: 41 | train loss: 61.40488 |acc: 0.852 |recall: 0.826 |f1 score: 0.839 | validate loss: 58.976 |acc: 0.802 |recall: 0.755 |f1 score: 0.778
      Epoch: 42 | train loss: 61.04982 |acc: 0.856 |recall: 0.817 |f1 score: 0.836 | validate loss: 58.47 |acc: 0.783 |recall: 0.74 |f1 score: 0.761
      Epoch: 43 | train loss: 64.40567 |acc: 0.849 |recall: 0.821 |f1 score: 0.835 | validate loss: 63.506 |acc: 0.764 |recall: 0.765 |f1 score: 0.765
      Epoch: 44 | train loss: 65.09746 |acc: 0.845 |recall: 0.805 |f1 score: 0.825 | validate loss: 65.535 |acc: 0.773 |recall: 0.743 |f1 score: 0.758
      Epoch: 45 | train loss: 63.26585 |acc: 0.848 |recall: 0.808 |f1 score: 0.827 | validate loss: 62.477 |acc: 0.789 |recall: 0.733 |f1 score: 0.76
      Epoch: 46 | train loss: 63.91504 |acc: 0.847 |recall: 0.812 |f1 score: 0.829 | validate loss: 59.916 |acc: 0.779 |recall: 0.751 |f1 score: 0.765
      Epoch: 47 | train loss: 62.3592 |acc: 0.845 |recall: 0.824 |f1 score: 0.835 | validate loss: 63.363 |acc: 0.775 |recall: 0.761 |f1 score: 0.768
      Epoch: 48 | train loss: 63.13221 |acc: 0.843 |recall: 0.823 |f1 score: 0.833 | validate loss: 65.71 |acc: 0.767 |recall: 0.755 |f1 score: 0.761
      Epoch: 49 | train loss: 64.9964 |acc: 0.845 |recall: 0.811 |f1 score: 0.828 | validate loss: 65.174 |acc: 0.768 |recall: 0.74 |f1 score: 0.754
      Epoch: 50 | train loss: 62.40605 |acc: 0.847 |recall: 0.817 |f1 score: 0.832 | validate loss: 60.761 |acc: 0.776 |recall: 0.746 |f1 score: 0.761
      Epoch: 51 | train loss: 63.05476 |acc: 0.845 |recall: 0.812 |f1 score: 0.828 | validate loss: 64.217 |acc: 0.764 |recall: 0.748 |f1 score: 0.756
      Epoch: 52 | train loss: 59.77727 |acc: 0.84 |recall: 0.831 |f1 score: 0.836 | validate loss: 60.48 |acc: 0.79 |recall: 0.759 |f1 score: 0.774
      Epoch: 53 | train loss: 62.7249 |acc: 0.828 |recall: 0.813 |f1 score: 0.821 | validate loss: 64.584 |acc: 0.757 |recall: 0.757 |f1 score: 0.757
      Epoch: 54 | train loss: 61.1763 |acc: 0.842 |recall: 0.832 |f1 score: 0.837 | validate loss: 61.088 |acc: 0.775 |recall: 0.768 |f1 score: 0.771
      Epoch: 55 | train loss: 64.04366 |acc: 0.835 |recall: 0.816 |f1 score: 0.826 | validate loss: 68.183 |acc: 0.784 |recall: 0.742 |f1 score: 0.762
      Epoch: 56 | train loss: 66.76939 |acc: 0.84 |recall: 0.813 |f1 score: 0.827 | validate loss: 67.284 |acc: 0.77 |recall: 0.748 |f1 score: 0.759
      Epoch: 57 | train loss: 67.85329 |acc: 0.826 |recall: 0.789 |f1 score: 0.807 | validate loss: 69.961 |acc: 0.766 |recall: 0.732 |f1 score: 0.749
      Epoch: 58 | train loss: 64.79573 |acc: 0.84 |recall: 0.812 |f1 score: 0.826 | validate loss: 73.358 |acc: 0.754 |recall: 0.735 |f1 score: 0.745
      Epoch: 59 | train loss: 65.36249 |acc: 0.862 |recall: 0.826 |f1 score: 0.844 | validate loss: 66.552 |acc: 0.783 |recall: 0.766 |f1 score: 0.774
      Epoch: 60 | train loss: 63.43061 |acc: 0.835 |recall: 0.811 |f1 score: 0.823 | validate loss: 63.138 |acc: 0.771 |recall: 0.746 |f1 score: 0.759
      Epoch: 61 | train loss: 62.34639 |acc: 0.848 |recall: 0.825 |f1 score: 0.836 | validate loss: 59.656 |acc: 0.783 |recall: 0.756 |f1 score: 0.769
      Epoch: 62 | train loss: 61.83451 |acc: 0.83 |recall: 0.814 |f1 score: 0.822 | validate loss: 60.443 |acc: 0.765 |recall: 0.751 |f1 score: 0.758
      Epoch: 63 | train loss: 64.78461 |acc: 0.854 |recall: 0.818 |f1 score: 0.836 | validate loss: 61.125 |acc: 0.786 |recall: 0.748 |f1 score: 0.767
      Epoch: 64 | train loss: 63.43409 |acc: 0.838 |recall: 0.818 |f1 score: 0.828 | validate loss: 62.396 |acc: 0.77 |recall: 0.757 |f1 score: 0.764
      Epoch: 65 | train loss: 61.20197 |acc: 0.854 |recall: 0.815 |f1 score: 0.834 | validate loss: 59.019 |acc: 0.79 |recall: 0.75 |f1 score: 0.769
      Epoch: 66 | train loss: 59.69791 |acc: 0.851 |recall: 0.82 |f1 score: 0.836 | validate loss: 55.06 |acc: 0.789 |recall: 0.754 |f1 score: 0.771
      Epoch: 67 | train loss: 63.16074 |acc: 0.836 |recall: 0.811 |f1 score: 0.823 | validate loss: 61.48 |acc: 0.764 |recall: 0.745 |f1 score: 0.755
      Epoch: 68 | train loss: 62.15521 |acc: 0.845 |recall: 0.824 |f1 score: 0.835 | validate loss: 62.407 |acc: 0.778 |recall: 0.761 |f1 score: 0.769
      Epoch: 69 | train loss: 61.90574 |acc: 0.847 |recall: 0.828 |f1 score: 0.838 | validate loss: 59.801 |acc: 0.781 |recall: 0.762 |f1 score: 0.771
      Epoch: 70 | train loss: 60.51348 |acc: 0.852 |recall: 0.827 |f1 score: 0.839 | validate loss: 56.632 |acc: 0.781 |recall: 0.761 |f1 score: 0.771
      Epoch: 71 | train loss: 62.78683 |acc: 0.856 |recall: 0.823 |f1 score: 0.84 | validate loss: 62.867 |acc: 0.796 |recall: 0.757 |f1 score: 0.776
      Epoch: 72 | train loss: 62.11708 |acc: 0.845 |recall: 0.82 |f1 score: 0.833 | validate loss: 57.211 |acc: 0.784 |recall: 0.754 |f1 score: 0.769
      Epoch: 73 | train loss: 63.2298 |acc: 0.839 |recall: 0.816 |f1 score: 0.828 | validate loss: 60.247 |acc: 0.764 |recall: 0.752 |f1 score: 0.758
      Epoch: 74 | train loss: 61.87119 |acc: 0.848 |recall: 0.828 |f1 score: 0.838 | validate loss: 59.692 |acc: 0.782 |recall: 0.765 |f1 score: 0.774
      Epoch: 75 | train loss: 59.88628 |acc: 0.851 |recall: 0.821 |f1 score: 0.836 | validate loss: 59.461 |acc: 0.78 |recall: 0.755 |f1 score: 0.767
      Epoch: 76 | train loss: 61.97182 |acc: 0.858 |recall: 0.812 |f1 score: 0.835 | validate loss: 59.748 |acc: 0.78 |recall: 0.749 |f1 score: 0.765
      Epoch: 77 | train loss: 62.2035 |acc: 0.836 |recall: 0.811 |f1 score: 0.823 | validate loss: 56.778 |acc: 0.768 |recall: 0.748 |f1 score: 0.758
      Epoch: 78 | train loss: 59.90309 |acc: 0.846 |recall: 0.823 |f1 score: 0.835 | validate loss: 59.424 |acc: 0.771 |recall: 0.76 |f1 score: 0.765
      Epoch: 79 | train loss: 62.48097 |acc: 0.844 |recall: 0.821 |f1 score: 0.833 | validate loss: 57.535 |acc: 0.769 |recall: 0.755 |f1 score: 0.762
      Epoch: 80 | train loss: 65.83723 |acc: 0.853 |recall: 0.83 |f1 score: 0.842 | validate loss: 60.798 |acc: 0.782 |recall: 0.762 |f1 score: 0.772
      Epoch: 81 | train loss: 67.69897 |acc: 0.848 |recall: 0.812 |f1 score: 0.83 | validate loss: 62.135 |acc: 0.78 |recall: 0.746 |f1 score: 0.763
      Epoch: 82 | train loss: 64.45554 |acc: 0.863 |recall: 0.845 |f1 score: 0.854 | validate loss: 62.102 |acc: 0.793 |recall: 0.775 |f1 score: 0.784
      Epoch: 83 | train loss: 59.9239 |acc: 0.857 |recall: 0.84 |f1 score: 0.848 | validate loss: 57.003 |acc: 0.788 |recall: 0.771 |f1 score: 0.779
      Epoch: 84 | train loss: 65.42567 |acc: 0.859 |recall: 0.831 |f1 score: 0.845 | validate loss: 61.993 |acc: 0.788 |recall: 0.763 |f1 score: 0.775
      Epoch: 85 | train loss: 62.69893 |acc: 0.852 |recall: 0.828 |f1 score: 0.84 | validate loss: 59.489 |acc: 0.786 |recall: 0.761 |f1 score: 0.773
      Epoch: 86 | train loss: 64.58199 |acc: 0.858 |recall: 0.831 |f1 score: 0.845 | validate loss: 60.414 |acc: 0.789 |recall: 0.764 |f1 score: 0.776
      Epoch: 87 | train loss: 58.41865 |acc: 0.875 |recall: 0.838 |f1 score: 0.856 | validate loss: 56.525 |acc: 0.805 |recall: 0.768 |f1 score: 0.786
      Epoch: 88 | train loss: 61.39529 |acc: 0.848 |recall: 0.824 |f1 score: 0.836 | validate loss: 56.678 |acc: 0.783 |recall: 0.759 |f1 score: 0.771
      Epoch: 89 | train loss: 63.69639 |acc: 0.857 |recall: 0.818 |f1 score: 0.837 | validate loss: 59.014 |acc: 0.787 |recall: 0.751 |f1 score: 0.769
      Epoch: 90 | train loss: 61.78225 |acc: 0.841 |recall: 0.84 |f1 score: 0.84 | validate loss: 59.58 |acc: 0.773 |recall: 0.775 |f1 score: 0.774
      Epoch: 91 | train loss: 58.19114 |acc: 0.845 |recall: 0.826 |f1 score: 0.836 | validate loss: 55.284 |acc: 0.776 |recall: 0.758 |f1 score: 0.767
      Epoch: 92 | train loss: 58.67227 |acc: 0.857 |recall: 0.82 |f1 score: 0.838 | validate loss: 54.982 |acc: 0.787 |recall: 0.753 |f1 score: 0.77
      Epoch: 93 | train loss: 60.79532 |acc: 0.858 |recall: 0.83 |f1 score: 0.844 | validate loss: 57.808 |acc: 0.792 |recall: 0.764 |f1 score: 0.778
      Epoch: 94 | train loss: 56.71145 |acc: 0.872 |recall: 0.851 |f1 score: 0.861 | validate loss: 53.551 |acc: 0.804 |recall: 0.785 |f1 score: 0.795
      Epoch: 95 | train loss: 58.791 |acc: 0.864 |recall: 0.83 |f1 score: 0.847 | validate loss: 54.284 |acc: 0.793 |recall: 0.765 |f1 score: 0.779
      Epoch: 96 | train loss: 60.07491 |acc: 0.849 |recall: 0.828 |f1 score: 0.839 | validate loss: 55.524 |acc: 0.78 |recall: 0.764 |f1 score: 0.772
      Epoch: 97 | train loss: 61.53479 |acc: 0.86 |recall: 0.825 |f1 score: 0.842 | validate loss: 56.891 |acc: 0.796 |recall: 0.759 |f1 score: 0.777
      Epoch: 98 | train loss: 61.94878 |acc: 0.85 |recall: 0.836 |f1 score: 0.843 | validate loss: 57.019 |acc: 0.783 |recall: 0.771 |f1 score: 0.777
      Epoch: 99 | train loss: 58.49541 |acc: 0.86 |recall: 0.834 |f1 score: 0.847 | validate loss: 56.162 |acc: 0.795 |recall: 0.767 |f1 score: 0.781
      

    • Step 6: Plot the loss and evaluation curves
      • Training vs. validation loss curves:

    [Figure 1: training and validation loss curves]


    • Analysis: the loss curves fall throughout training and drop quickly to a fairly good level from about epoch 5, which shows the model is learning patterns from the data. After roughly epoch 40 the model stabilizes, suggesting the parameters are close to their optimum. By that point, given the scheduler settings, the learning rate has been decayed about 8 times, i.e. shrunk to roughly 0.2^8 of the initial value, so the model has settled near its current best solution and the curves flatten out.
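
    • A quick check of that decay arithmetic (StepLR with the settings above: LEARNING_RATE = 0.5, step_size = 5, gamma = 0.2):
      # 40 epochs // step_size 5 = 8 decays of gamma = 0.2
      lr_after_40_epochs = 0.5 * 0.2 ** 8
      print(lr_after_40_epochs)  # 1.28e-06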

    • Training vs. validation accuracy curves:

    [Figure 2: training and validation accuracy curves]


    • Analysis:
    • First, accuracy here means the proportion of correctly recognized entities among all entities the model recognized (i.e. precision).
    • The curves trend upward overall, and their oscillation flattens out as the epochs increase. However, probably due to an uneven distribution between training and validation samples, or to noise, the final validation accuracy does not reach the training accuracy.
    • The final training and validation accuracies are around 0.85 and 0.78 respectively.

    • Training vs. validation recall curves:

    [Figure 3: training and validation recall curves]


    • Analysis:
    • Recall here means the proportion of correctly recognized entities among all entities contained in the current batch.
    • The recall curves change relatively smoothly and likewise stabilize at around epoch 40.
    • The final training and validation recalls are around 0.83 and 0.75 respectively.

    • Training vs. validation F1 curves:

    [Figure 4: training and validation F1 curves]


    • Analysis:
    • F1 measures overall training quality: it rewards raising accuracy without over-predicting entities.
    • Its formula is: F1 = 2 × accuracy × recall / (accuracy + recall)
    • The overall rise of the F1 curve tracks the loss and recall curves fairly closely, which suggests the correctness of the recognized entities is somewhat of an issue: judging from the accuracy analysis above, the model may be inflating the number of recognized entities, which causes instability. This again shows that sample imbalance and noise have a considerable impact on the model.
    • Overall, F1 also stabilizes after roughly epoch 40; the final training and validation values are around 0.85 and 0.75 respectively.

    • Section summary:
      • Learned the data preprocessing methods
        • The characters of the raw dataset are numerically encoded into vectors
        • The tags of the annotated dataset are numerically encoded into vectors
      • Learned how to generate batched training data
      • Learned the implementation of the model training code
        • The accuracy and recall evaluation code
        • All internal functions of the model class
        • The code that launches the training process