在做文本预处理和调研实验时,单进程太慢了,尤其是在处理几十万条数据的数据集时。
解决如下:

基本框架

  1. import json
  2. from itertools import combinations
  3. import threading
  4. import time
  5. import datetime
  6. import multiprocessing as mp
  7. import copy
  8. import json
  9. import random
  10. random.seed(413)
  11. def Function(name,json_data):
  12. result = {}
  13. # Process Data
  14. # Return reuslt in Dict format
  15. return result
  16. def prepare_data():
  17. input_data = []
  18. # Prepare Data
  19. # Single Sample is A Dict
  20. return input_data
  21. def extract_result(results):
  22. # Extract result and Display
  23. # Result is a List According to Input result task division
  24. def multi_process_tag(target_data):
  25. num_cores = int(mp.cpu_count())
  26. print("本地计算机有: " + str(num_cores) + " 核心")
  27. pool = mp.Pool(num_cores)
  28. # Split the task based on CPU core in Computer
  29. param_dict = {}
  30. start = 0
  31. end = len(target_data)
  32. step = int((end - start)/num_cores)
  33. print("per Task Step: ",step)
  34. # Construct Param Dict for multi-processing
  35. for i in range(num_cores):
  36. param_dict['task{}'.format(i)]= target_data[start:start+step]
  37. start = start+step
  38. param_dict['task{}'.format(num_cores)]= target_data[start:]
  39. start_t = datetime.datetime.now()
  40. # Run and get result
  41. results = [pool.apply_async(Function, args=(name, param)) for name, param in param_dict.items()]
  42. results = [p.get() for p in results]
  43. end_t = datetime.datetime.now()
  44. elapsed_sec = (end_t - start_t).total_seconds()
  45. print("多进程计算 共消耗: " + "{:.2f}".format(elapsed_sec) + " 秒")
  46. return results
  47. if __name__ == '__main__':
  48. data = prepare_data()
  49. results = multi_process_tag(data)
  50. extract_result(results)

实例

对 2000 条数据计算 Oracle,单进程原始耗时大约 10 分钟(8 核 i7 上);
使用多进程之后,每个核处理 125 条,不到 3 分钟即可完成。

  1. import json
  2. from itertools import combinations
  3. from rouge import Rouge
  4. import threading
  5. import time
  6. import math
  7. import datetime
  8. import multiprocessing as mp
  9. from rouge import Rouge
  10. import copy
  11. import json
  12. import random
  13. random.seed(413)
  14. rouge = Rouge()
  15. def get_score(hyp,ref):
  16. try:
  17. temp_rouge = rouge.get_scores(hyp, ref)
  18. cur_score = (temp_rouge[0]["rouge-1"]['f'] + temp_rouge[0]["rouge-2"]['f'] + temp_rouge[0]["rouge-l"]['f'])/3
  19. except :
  20. cur_score = 0
  21. return cur_score
  22. def get_oracle(sent_list,summary):
  23. Chosen_idx = []
  24. best_score = 0
  25. cal_count = 0
  26. while 1:
  27. best_choice = -1
  28. best_sub_score = 0
  29. for i in range(len(sent_list)):
  30. if i not in Chosen_idx and len(sent_list[i]) != 0 :
  31. cal_count += 1
  32. temp_chosen = copy.deepcopy(Chosen_idx)
  33. temp_chosen.append(i)
  34. temp_chosen_sents = [sent_list[i] for i in temp_chosen]
  35. #print(temp_chosen)
  36. #print(temp_chosen_sents)
  37. cur_score = get_score(" ".join(temp_chosen_sents),summary)
  38. cur_sub_score = cur_score - best_score
  39. if cur_sub_score > best_sub_score:
  40. best_sub_score = cur_sub_score
  41. best_choice = i
  42. if best_choice == -1:
  43. break
  44. Chosen_idx.append(best_choice)
  45. best_sents = [sent_list[i] for i in Chosen_idx]
  46. best_score = get_score(" ".join(best_sents),summary)
  47. best_sents = [sent_list[i] for i in Chosen_idx]
  48. #print(len(sent_list))
  49. #print(len(best_sents))
  50. #print(cal_count)
  51. try:
  52. temp_rouge = rouge.get_scores(" ".join(best_sents), summary)
  53. except :
  54. return 0,0,0
  55. return temp_rouge[0]["rouge-1"]['f'],temp_rouge[0]["rouge-2"]['f'],temp_rouge[0]["rouge-l"]['f']
  56. def Function(name,json_data):
  57. result = {}
  58. result['r1'] = []
  59. result['r2'] = []
  60. result['rl'] = []
  61. for i in range(len(json_data)):
  62. doc = json_data[i]['doc']
  63. summary = json_data[i]['summary']
  64. r1,r2,rl = get_oracle(doc,summary)
  65. result['r1'].append(r1)
  66. result['r2'].append(r2)
  67. result['rl'].append(rl)
  68. return result
  69. def prepare_data():
  70. input_data = []
  71. f = open('test.extract.source','r',encoding = 'utf-8')
  72. f2 = open('test.target','r',encoding = 'utf-8')
  73. f3 = open('QueryResult.txt','r',encoding = 'utf-8')
  74. query = f3.readlines()
  75. query = [[int (j) for j in i.strip().split()] for i in query]
  76. summarys = f2.readlines()
  77. summarys = [i.strip() for i in summarys]
  78. import random
  79. data_index = []
  80. while len(data_index) < 2000:
  81. random_index = random.randint(0,len(summarys)-1)
  82. if random_index not in data_index:
  83. data_index.append(random_index)
  84. print(data_index[:10])
  85. assert data_index[0] == 10455
  86. lines = f.readlines()
  87. ftrain = open('train.extract.source','r',encoding = 'utf-8')
  88. assist_lines = ftrain.readlines()
  89. for i in range(len(lines)):
  90. data = lines[i].strip()
  91. data_dict = json.loads(data)
  92. doc = data_dict['text']
  93. for j in query[i][:1]:
  94. assist = assist_lines[j].strip()
  95. assist_dict = json.loads(assist)
  96. assist_doc = assist_dict['text']
  97. doc = assist_doc + doc
  98. temp_data = {}
  99. temp_data['doc'] = doc
  100. temp_data['summary'] = summarys[data_index[i]]
  101. input_data.append(temp_data)
  102. return input_data
  103. def extract_result(results):
  104. total_samples = 0
  105. Sum1 = 0
  106. Sum2 = 0
  107. SumL = 0
  108. for i in results:
  109. total_samples += len(i['r1'])
  110. Sum1 += sum(i['r1'])
  111. Sum2 += sum(i['r2'])
  112. SumL += sum(i['rl'])
  113. print(total_samples)
  114. print(Sum1/total_samples)
  115. print(Sum2/total_samples)
  116. print(SumL/total_samples)
  117. def multi_process_tag(target_data):
  118. num_cores = int(mp.cpu_count())
  119. print("本地计算机有: " + str(num_cores) + " 核心")
  120. pool = mp.Pool(num_cores)
  121. param_dict = {}
  122. start = 0
  123. end = len(target_data)
  124. step = int((end - start)/num_cores)
  125. print("per Task Step: ",step)
  126. for i in range(num_cores):
  127. param_dict['task{}'.format(i)]= target_data[start:start+step]
  128. start = start+step
  129. param_dict['task{}'.format(num_cores)]= target_data[start:]
  130. start_t = datetime.datetime.now()
  131. results = [pool.apply_async(Function, args=(name, param)) for name, param in param_dict.items()]
  132. results = [p.get() for p in results]
  133. end_t = datetime.datetime.now()
  134. elapsed_sec = (end_t - start_t).total_seconds()
  135. print("多进程计算 共消耗: " + "{:.2f}".format(elapsed_sec) + " 秒")
  136. return results
  137. if __name__ == '__main__':
  138. data = prepare_data()
  139. results = multi_process_tag(data)
  140. extract_result(results)