1. # coding:utf-8
    2. import os
    3. import ctypes
    4. import numpy as np
    5. class TrainDataSampler(object):
    6. def __init__(self, nbatches, datasampler):
    7. self.nbatches = nbatches
    8. self.datasampler = datasampler
    9. self.batch = 0
    10. def __iter__(self):
    11. return self
    12. def __next__(self):
    13. self.batch += 1
    14. if self.batch > self.nbatches:
    15. raise StopIteration()
    16. return self.datasampler()
    17. def __len__(self):
    18. return self.nbatches
    19. class TrainDataLoader(object):
    20. def __init__(self,
    21. in_path = "./",
    22. tri_file = None, # 三元组文件
    23. ent_file = None, # 实体文件记录实体与对应id之间的映射关系
    24. rel_file = None, # 关系文件记录关系与对应id之间的映射关系
    25. batch_size = None,
    26. nbatches = None, # 根据三元组的总量计算有多少个batch
    27. threads = 8, # 读取数据的线程数量
    28. sampling_mode = "normal", # 采样方法,分为普通采样及交叉采样
    29. bern_flag = False, # 负采样时进行伯努利采样还是均匀采样
    30. filter_flag = True, # 进行负采样时是否要过滤掉存在于训练集中的三元组
    31. neg_ent = 1, # 负采样时一个正样本生成负样本的数量,默认情况一个正样本只生成一个负样本
    32. neg_rel = 0): # 负采样时一个关系生成负样本的数量
    33. base_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../release/Base.dll"))
    34. self.lib = ctypes.cdll.LoadLibrary(base_file)
    35. """argtypes"""
    36. self.lib.sampling.argtypes = [ # C与Python数据类型的转换,回调
    37. ctypes.c_void_p,
    38. ctypes.c_void_p,
    39. ctypes.c_void_p,
    40. ctypes.c_void_p,
    41. ctypes.c_int64,
    42. ctypes.c_int64,
    43. ctypes.c_int64,
    44. ctypes.c_int64,
    45. ctypes.c_int64,
    46. ctypes.c_int64,
    47. ctypes.c_int64
    48. ]
    49. self.in_path = in_path
    50. self.tri_file = tri_file
    51. self.ent_file = ent_file
    52. self.rel_file = rel_file
    53. if in_path != None:
    54. self.tri_file = in_path + "train2id.txt"
    55. self.ent_file = in_path + "entity2id.txt"
    56. self.rel_file = in_path + "relation2id.txt"
    57. """set essential parameters"""
    58. self.work_threads = threads
    59. self.nbatches = nbatches
    60. self.batch_size = batch_size
    61. self.bern = bern_flag
    62. self.filter = filter_flag
    63. self.negative_ent = neg_ent # 生成负例实体数量
    64. self.negative_rel = neg_rel # 生成负例关系数量
    65. self.sampling_mode = sampling_mode
    66. self.cross_sampling_flag = 0
    67. self.read()
    68. def read(self):
    69. # 将文件路径之类的信息传入动态链接库
    70. if self.in_path != None:
    71. self.lib.setInPath(ctypes.create_string_buffer(self.in_path.encode(), len(self.in_path) * 2))
    72. else:
    73. self.lib.setTrainPath(ctypes.create_string_buffer(self.tri_file.encode(), len(self.tri_file) * 2))
    74. self.lib.setEntPath(ctypes.create_string_buffer(self.ent_file.encode(), len(self.ent_file) * 2))
    75. self.lib.setRelPath(ctypes.create_string_buffer(self.rel_file.encode(), len(self.rel_file) * 2))
    76. # lib.method均表示调用动态链接库中的函数
    77. # 设置是否进行伯努利采样
    78. self.lib.setBern(self.bern)
    79. # 设置工作线程
    80. self.lib.setWorkThreads(self.work_threads)
    81. # 重置线程的随机种子
    82. self.lib.randReset()
    83. # 底层c++开始将数据读入内存
    84. self.lib.importTrainFiles()
    85. # 获取实体、关系以及训练三元组的总数
    86. self.relTotal = self.lib.getRelationTotal()
    87. self.entTotal = self.lib.getEntityTotal()
    88. self.tripleTotal = self.lib.getTrainTotal()
    89. if self.batch_size == None:
    90. self.batch_size = self.tripleTotal // self.nbatches
    91. if self.nbatches == None:
    92. self.nbatches = self.tripleTotal // self.batch_size
    93. # batch_seq_size是每次训练时一个batch的数据,其中包括原始及负采样之后得到的三元组,所以在长度上要乘以负采样实体和关系的次数
    94. self.batch_seq_size = self.batch_size * (1 + self.negative_ent + self.negative_rel)
    95. # 初始化一个batch的数据,包含头实体、尾实体、关系、标签,其中batch_y为标签,1表示原始三元组,-1表示负采样替换后的三元组
    96. self.batch_h = np.zeros(self.batch_seq_size, dtype=np.int64)
    97. self.batch_t = np.zeros(self.batch_seq_size, dtype=np.int64)
    98. self.batch_r = np.zeros(self.batch_seq_size, dtype=np.int64)
    99. self.batch_y = np.zeros(self.batch_seq_size, dtype=np.float32)
    100. # 获得tensor的地址,用于将地址传给c++,让底层代码完成数据的写入工作。这样就完成了数据的传递(居然还能这样)
    101. self.batch_h_addr = self.batch_h.__array_interface__["data"][0]
    102. self.batch_t_addr = self.batch_t.__array_interface__["data"][0]
    103. self.batch_r_addr = self.batch_r.__array_interface__["data"][0]
    104. self.batch_y_addr = self.batch_y.__array_interface__["data"][0]
    105. # 普通采样
    106. def sampling(self):
    107. # 这里通过刚才获取的tensor地址就可以传给c++进行采样操作了
    108. self.lib.sampling(
    109. self.batch_h_addr,
    110. self.batch_t_addr,
    111. self.batch_r_addr,
    112. self.batch_y_addr,
    113. self.batch_size,
    114. self.negative_ent,
    115. self.negative_rel,
    116. 0,
    117. self.filter,
    118. 0,
    119. 0
    120. )
    121. return {
    122. "batch_h": self.batch_h,
    123. "batch_t": self.batch_t,
    124. "batch_r": self.batch_r,
    125. "batch_y": self.batch_y,
    126. "mode": "normal"
    127. }
    128. def sampling_head(self):
    129. self.lib.sampling(
    130. self.batch_h_addr,
    131. self.batch_t_addr,
    132. self.batch_r_addr,
    133. self.batch_y_addr,
    134. self.batch_size,
    135. self.negative_ent,
    136. self.negative_rel,
    137. -1,
    138. self.filter,
    139. 0,
    140. 0
    141. )
    142. return {
    143. "batch_h": self.batch_h,
    144. "batch_t": self.batch_t[:self.batch_size],
    145. "batch_r": self.batch_r[:self.batch_size],
    146. "batch_y": self.batch_y,
    147. "mode": "head_batch"
    148. }
    149. def sampling_tail(self):
    150. self.lib.sampling(
    151. self.batch_h_addr,
    152. self.batch_t_addr,
    153. self.batch_r_addr,
    154. self.batch_y_addr,
    155. self.batch_size,
    156. self.negative_ent,
    157. self.negative_rel,
    158. 1,
    159. self.filter,
    160. 0,
    161. 0
    162. )
    163. return {
    164. "batch_h": self.batch_h[:self.batch_size],
    165. "batch_t": self.batch_t,
    166. "batch_r": self.batch_r[:self.batch_size],
    167. "batch_y": self.batch_y,
    168. "mode": "tail_batch"
    169. }
    170. # 交叉采样,每遍历到一个batch交替进行头部采样以及尾部采样
    171. def cross_sampling(self):
    172. self.cross_sampling_flag = 1 - self.cross_sampling_flag
    173. if self.cross_sampling_flag == 0:
    174. return self.sampling_head()
    175. else:
    176. return self.sampling_tail()
    177. """interfaces to set essential parameters"""
    178. def set_work_threads(self, work_threads):
    179. self.work_threads = work_threads
    180. def set_in_path(self, in_path):
    181. self.in_path = in_path
    182. def set_nbatches(self, nbatches):
    183. self.nbatches = nbatches
    184. def set_batch_size(self, batch_size):
    185. self.batch_size = batch_size
    186. self.nbatches = self.tripleTotal // self.batch_size
    187. def set_ent_neg_rate(self, rate):
    188. self.negative_ent = rate
    189. def set_rel_neg_rate(self, rate):
    190. self.negative_rel = rate
    191. def set_bern_flag(self, bern):
    192. self.bern = bern
    193. def set_filter_flag(self, filter):
    194. self.filter = filter
    195. """interfaces to get essential parameters"""
    196. def get_batch_size(self):
    197. return self.batch_size
    198. def get_ent_tot(self):
    199. return self.entTotal
    200. def get_rel_tot(self):
    201. return self.relTotal
    202. def get_triple_tot(self):
    203. return self.tripleTotal
    204. def __iter__(self):
    205. if self.sampling_mode == "normal":
    206. return TrainDataSampler(self.nbatches, self.sampling)
    207. else:
    208. return TrainDataSampler(self.nbatches, self.cross_sampling)
    209. def __len__(self):
    210. return self.nbatches