Introduced the new DuIE dataset and retrained on it.
The data format is converted with the script below.
The conversion of the test data was described earlier.

    import json
    import random

    def find_all(sub, s):
        """Return every occurrence of `sub` in `s` as inclusive [start, end] character spans."""
        index_list = []
        index = s.find(sub)
        while index != -1:
            index_list.append([index, index + len(sub) - 1])
            index = s.find(sub, index + 1)
        return index_list  # empty list if the mention is not found

    # sen = "生生不息CSOL生化狂潮让你填弹狂扫"
    # print(find_all("CSOL", sen))

    def func1():
        # Entity-type whitelist: only these 17 categories are kept after pruning.
        l = ['机构', '影视作品', '电视综艺', '网络小说', '歌曲', '生物', '历史人物', '人物',
             '图书作品', '学校', '音乐专辑', '企业', '书籍', '国家', '出版社', '目', '地点']

        # ---- training data ----
        labels = {i: 0 for i in l}  # per-class mention counts (train)
        info = []                   # converted records that survive the pruning rule
        total = 0
        with open("D:\\New_desktop\\train_data.json", 'r', encoding='utf-8') as load_f:
            for line in load_f.readlines():
                total += 1
                dic = json.loads(line)
                after = {'text': dic['text'], 'label': {}}
                append_flag = 1
                for j in dic['spo_list']:
                    # object side
                    if j['object_type'] in l:
                        labels[j['object_type']] += 1
                        after['label'][j['object_type']] = {}
                        result1 = find_all(j['object'], dic['text'])
                        if len(result1) != 0:
                            after['label'][j['object_type']][j['object']] = result1
                        else:
                            append_flag = 0  # mention not found in the text
                    else:
                        append_flag = 0      # type outside the whitelist
                    # subject side
                    if j['subject_type'] in l:
                        labels[j['subject_type']] += 1
                        after['label'][j['subject_type']] = {}
                        result2 = find_all(j['subject'], dic['text'])
                        if len(result2) != 0:
                            after['label'][j['subject_type']][j['subject']] = result2
                        else:
                            append_flag = 0
                    else:
                        append_flag = 0
                if append_flag:              # keep the sentence only if every SPO survived
                    info.append(after)

        # Downsample to 20k records (random.choice samples with replacement).
        sub_train = [random.choice(info) for _ in range(20000)]
        print(len(info), total)
        with open("D:\\New_desktop\\train.json", "w", encoding='utf-8') as dump_f:
            for i in sub_train:
                dump_f.write(json.dumps(i, ensure_ascii=False))
                dump_f.write("\n")

        # ---- dev data: same conversion and pruning as for train ----
        labels2 = {i: 0 for i in l}  # per-class mention counts (dev)
        info = []
        total = 0
        with open("D:\\New_desktop\\dev_data.json", 'r', encoding='utf-8') as load_f:
            for line in load_f.readlines():
                total += 1
                dic = json.loads(line)
                after = {'text': dic['text'], 'label': {}}
                append_flag = 1
                for j in dic['spo_list']:
                    if j['object_type'] in l:
                        labels2[j['object_type']] += 1
                        after['label'][j['object_type']] = {}
                        result1 = find_all(j['object'], dic['text'])
                        if len(result1) != 0:
                            after['label'][j['object_type']][j['object']] = result1
                        else:
                            append_flag = 0
                    else:
                        append_flag = 0
                    if j['subject_type'] in l:
                        labels2[j['subject_type']] += 1
                        after['label'][j['subject_type']] = {}
                        result2 = find_all(j['subject'], dic['text'])
                        if len(result2) != 0:
                            after['label'][j['subject_type']][j['subject']] = result2
                        else:
                            append_flag = 0
                    else:
                        append_flag = 0
                if append_flag:
                    info.append(after)

        print(len(info), total)
        # Side-by-side per-class counts for train and dev.
        for i in labels:
            print(i, labels[i], labels2[i])

        # Downsample the dev set to 2k records.
        sub_info = [random.choice(info) for _ in range(2000)]
        with open("D:\\New_desktop\\dev.json", "w", encoding='utf-8') as dump_f:
            for i in sub_info:
                dump_f.write(json.dumps(i, ensure_ascii=False))
                dump_f.write("\n")

    func1()
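As a quick sanity check of the span format: with find_all and json from the script above in scope, spans are inclusive [start, end] character offsets into text, and each converted record maps an entity type to its mentions and their spans. The record below is purely illustrative (hypothetical text and a hypothetical 游戏 type, not taken from DuIE); it only shows the output shape:

    sen = "生生不息CSOL生化狂潮让你填弹狂扫"
    print(find_all("CSOL", sen))   # [[4, 7]] -- 4-character mention starting at offset 4
    print(find_all("不存在", sen))  # [] -- a missing mention causes the whole record to be dropped

    # Hypothetical converted record, shown only to illustrate the text/label layout.
    example = {
        "text": sen,
        "label": {"游戏": {"CSOL": [[4, 7]]}}
    }
    print(json.dumps(example, ensure_ascii=False))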

Data characteristics

The data covers 28 entity classes in total.
Several classes are extremely imbalanced in sample count.
The corresponding results show a large bias, which may well be caused by this imbalance.
QQ图片20200924001753.png
QQ图片20200924001859.png
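The imbalance can be measured directly on the raw DuIE files before any conversion. A minimal sketch (assuming the same line-per-JSON layout and the D:\\New_desktop\\train_data.json path used above), tallying how often each subject_type and object_type occurs:

    import json
    from collections import Counter

    type_counts = Counter()
    with open("D:\\New_desktop\\train_data.json", 'r', encoding='utf-8') as f:
        for line in f:
            dic = json.loads(line)
            for spo in dic['spo_list']:
                type_counts[spo['subject_type']] += 1
                type_counts[spo['object_type']] += 1

    # Print classes from most to least frequent to expose the imbalance.
    for name, cnt in type_counts.most_common():
        print(name, cnt)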

Solution

The set of entity categories was therefore pruned.
In this multi-class setting, too many classes combined with huge per-class count differences noticeably hurt performance, so the data was trimmed to the following 17 categories (a sketch of the pruning rule follows the list):

  • 机构
  • 影视作品
  • 电视综艺
  • 网络小说
  • 歌曲
  • 生物
  • 历史人物
  • 人物
  • 图书作品
  • 学校
  • 音乐专辑
  • 企业
  • 书籍
  • 国家
  • 出版社
  • 目
  • 地点
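The pruning rule applied by the conversion script above can be stated compactly: a sentence is kept only when every subject and object in its spo_list has a whitelisted type and its mention can be located in the text; otherwise the whole record is dropped. A minimal restatement of that check (reusing the whitelist l and find_all defined in the script):

    def keep_record(dic, whitelist):
        """Return True if every SPO in the record survives the pruning rule."""
        for spo in dic['spo_list']:
            # Both types must be in the whitelist.
            if spo['subject_type'] not in whitelist or spo['object_type'] not in whitelist:
                return False
            # Both mentions must also be findable in the sentence text.
            if not find_all(spo['subject'], dic['text']) or not find_all(spo['object'], dic['text']):
                return False
        return True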

In addition, because the original data was very large (170k+ train and 20k+ dev examples),
it was downsampled to 20k train and 2k dev examples.
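One caveat: the script above downsamples with random.choice inside a loop, which samples with replacement, so the 20k/2k subsets may contain duplicate records. If duplicates are undesirable, a sketch of sampling without replacement instead (valid only when the filtered list is at least as large as the requested size):

    import random

    # Each surviving record is used at most once.
    sub_train = random.sample(info, 20000)   # assumes len(info) >= 20000
    sub_dev = random.sample(dev_info, 2000)  # `dev_info`: hypothetical name for the dev-side list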
The resulting data:
QQ图片20200924002255.png
QQ图片20200924002330.png

Actual testing

image.png