- 模型预测的实现过程:
import osimport torchimport torch.nn as nn# 导入RNN模型结构from RNN_MODEL import RNN# 导入bert预训练模型编码函数from bert_chinese_encode import get_bert_encode_for_single# 预加载的模型参数路径MODEL_PATH = './BERT_RNN.pth'# 隐层节点数, 输入层尺寸, 类别数都和训练时相同即可n_hidden = 128input_size = 768n_categories = 2# 实例化RNN模型, 并加载保存模型参数rnn = RNN(input_size, n_hidden, n_categories)rnn.load_state_dict(torch.load(MODEL_PATH))def _test(line_tensor):"""模型测试函数, 它将用在模型预测函数中, 用于调用RNN模型并返回结果.它的参数line_tensor代表输入文本的张量表示"""# 初始化隐层张量hidden = rnn.initHidden()# 与训练时相同, 遍历输入文本的每一个字符for i in range(line_tensor.size()[0]):# 将其逐次输送给rnn模型output, hidden = rnn(line_tensor[i].unsqueeze(0), hidden)# 获得rnn模型最终的输出return outputdef predict(input_line):"""模型预测函数, 输入参数input_line代表需要预测的文本"""# 不自动求解梯度with torch.no_grad():# 将input_line使用bert模型进行编码output = _test(get_bert_encode_for_single(input_line))# 从output中取出最大值对应的索引, 比较的维度是1_, topi = output.topk(1, 1)# 返回结果数值return topi.item()
tensor.topk演示:
>>> tr = torch.randn(1, 2)>>> trtensor([[-0.1808, -1.4170]])>>> tr.topk(1, 1)torch.return_types.topk(values=tensor([[-0.1808]]), indices=tensor([[0]]))
- 代码位置: /data/doctor_offline/review_model/predict.py
- 输入参数:
input_line = "点瘀样尖针性发多"
- 调用:
result = predict(input_line)print("result:", result)
- 输出效果:
result: 0
- 模型批量预测的实现过程:
def batch_predict(input_path, output_path):"""批量预测函数, 以原始文本(待识别的命名实体组成的文件)输入路径和预测过滤后(去除掉非命名实体的文件)的输出路径为参数"""# 待识别的命名实体组成的文件是以疾病名称为csv文件名,# 文件中的每一行是该疾病对应的症状命名实体# 读取路径下的每一个csv文件名, 装入csv列表之中csv_list = os.listdir(input_path)# 遍历每一个csv文件for csv in csv_list:# 以读的方式打开每一个csv文件with open(os.path.join(input_path, csv), "r") as fr:# 再以写的方式打开输出路径的同名csv文件with open(os.path.join(output_path, csv), "w") as fw:# 读取csv文件的每一行input_line = fr.readline()# 使用模型进行预测res = predict(input_line)# 如果结果为1if res:# 说明审核成功, 写入到输出csv中fw.write(input_line + "\n")else:pass
- 代码位置: /data/doctor_offline/review_model/predict.py
- 输入参数:
input_path = "/data/doctor_offline/structured/noreview/"output_path = "/data/doctor_offline/structured/reviewed/"
- 调用:
batch_predict(input_path, output_path)
- 输出效果:
- 在输出路径下生成与输入路径等数量的同名csv文件, 内部的症状实体是被审核的可用实体.
- 小节总结:
- 学习并实现了模型预测的函数: predict(input_line).
- 学习并实现了模型批量预测的函数: batch_predict(input_path, output_path)
