# coding:utf-8
import os
import ctypes
import numpy as np


class TrainDataSampler(object):
    """Single-pass iterator that produces ``nbatches`` batches.

    Each call to ``next()`` invokes ``datasampler()`` and returns its result
    until ``nbatches`` batches have been produced, after which
    ``StopIteration`` is raised.
    """

    def __init__(self, nbatches, datasampler):
        """Store the batch budget and the zero-argument sampling callable."""
        self.nbatches = nbatches
        self.datasampler = datasampler
        self.batch = 0  # number of __next__ calls made so far

    def __iter__(self):
        return self

    def __next__(self):
        self.batch += 1
        if self.batch > self.nbatches:
            raise StopIteration()
        return self.datasampler()

    def __len__(self):
        return self.nbatches


class TrainDataLoader(object):
    """Iterable over training batches filled in by a native sampling library.

    The loader opens ``../release/Base.dll`` (relative to this file) with
    ctypes, hands it the data file locations, and lets it write every batch
    directly into numpy buffers owned by this object.

    Iterating the loader yields ``nbatches`` dicts with the keys
    ``batch_h``, ``batch_t``, ``batch_r``, ``batch_y`` and ``mode``.
    The arrays in those dicts are the loader's own buffers (or views of
    them), so they are overwritten by the next sampling call.
    """

    def __init__(self,
                 in_path="./",
                 tri_file=None,
                 ent_file=None,
                 rel_file=None,
                 batch_size=None,
                 nbatches=None,
                 threads=8,
                 sampling_mode="normal",
                 bern_flag=False,
                 filter_flag=True,
                 neg_ent=1,
                 neg_rel=0):
        """Load the native library, store the configuration and read the data.

        Args:
            in_path: Directory prefix of the data set. When it is not None the
                three file names below are derived from it by plain string
                concatenation (``in_path + "train2id.txt"`` etc.), so it
                should end with a path separator.
            tri_file: Triple file; only used when ``in_path`` is None.
            ent_file: Entity file (entity <-> id mapping); only used when
                ``in_path`` is None.
            rel_file: Relation file (relation <-> id mapping); only used when
                ``in_path`` is None.
            batch_size: Positive triples per batch. Derived from ``nbatches``
                when None.
            nbatches: Batches per epoch. Derived from ``batch_size`` when None.
            threads: Worker thread count passed to the native library.
            sampling_mode: ``"normal"`` for plain sampling; any other value
                selects cross sampling (alternating head/tail corruption).
            bern_flag: Passed to the library's ``setBern``; selects Bernoulli
                instead of uniform negative sampling.
            filter_flag: Passed to every sampling call; asks the library to
                skip negatives that already exist in the training set.
            neg_ent: Negative samples generated per positive triple by
                replacing an entity.
            neg_rel: Negative samples generated per positive triple by
                replacing the relation.

        Raises:
            ValueError: If both ``batch_size`` and ``nbatches`` are None.
            OSError: If the native library cannot be loaded.
        """
        base_file = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "../release/Base.dll"))
        self.lib = ctypes.cdll.LoadLibrary(base_file)

        # Argument types of the native sampling routine: four raw buffer
        # addresses (h, t, r, y) followed by seven 64-bit integers.
        self.lib.sampling.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
        ]

        self.in_path = in_path
        self.tri_file = tri_file
        self.ent_file = ent_file
        self.rel_file = rel_file
        if in_path is not None:
            self.tri_file = in_path + "train2id.txt"
            self.ent_file = in_path + "entity2id.txt"
            self.rel_file = in_path + "relation2id.txt"

        # Essential parameters.
        self.work_threads = threads
        self.nbatches = nbatches
        self.batch_size = batch_size
        self.bern = bern_flag
        self.filter = filter_flag
        self.negative_ent = neg_ent  # negatives created by entity replacement
        self.negative_rel = neg_rel  # negatives created by relation replacement
        self.sampling_mode = sampling_mode
        self.cross_sampling_flag = 0
        self.read()

    @staticmethod
    def _path_buffer(path):
        """Return a NUL-terminated ctypes char buffer holding ``path``.

        The size is derived from the *encoded* byte length (plus one byte for
        the terminator) so paths with multi-byte characters fit; it is never
        smaller than the historical ``len(path) * 2``.
        """
        encoded = path.encode()
        size = max(len(encoded) + 1, len(path) * 2)
        return ctypes.create_string_buffer(encoded, size)

    def _resolve_batching(self):
        """Derive whichever of ``batch_size`` / ``nbatches`` was left as None.

        Raises:
            ValueError: If both values are None.
        """
        if self.batch_size is None and self.nbatches is None:
            raise ValueError(
                "either batch_size or nbatches must be specified")
        if self.batch_size is None:
            self.batch_size = self.tripleTotal // self.nbatches
        if self.nbatches is None:
            self.nbatches = self.tripleTotal // self.batch_size

    def _allocate_batch_buffers(self):
        """(Re)create the numpy buffers the native library writes into.

        One batch holds the positive triples plus all of their negatives, so
        its length is ``batch_size * (1 + negative_ent + negative_rel)``.
        """
        self.batch_seq_size = self.batch_size * (
            1 + self.negative_ent + self.negative_rel)

        # Heads, tails, relations and labels of one batch. Per the original
        # comments the label is 1 for an original triple and -1 for a
        # corrupted (negative) one.
        self.batch_h = np.zeros(self.batch_seq_size, dtype=np.int64)
        self.batch_t = np.zeros(self.batch_seq_size, dtype=np.int64)
        self.batch_r = np.zeros(self.batch_seq_size, dtype=np.int64)
        self.batch_y = np.zeros(self.batch_seq_size, dtype=np.float32)

        # Raw data addresses; these are what gets passed to the native
        # sampling routine so it can fill the arrays in place.
        self.batch_h_addr = self.batch_h.__array_interface__["data"][0]
        self.batch_t_addr = self.batch_t.__array_interface__["data"][0]
        self.batch_r_addr = self.batch_r.__array_interface__["data"][0]
        self.batch_y_addr = self.batch_y.__array_interface__["data"][0]

    def read(self):
        """Push the configuration to the native library and import the data.

        Sets ``relTotal``, ``entTotal`` and ``tripleTotal`` from the library,
        resolves ``batch_size`` / ``nbatches`` and allocates the batch
        buffers.

        Raises:
            ValueError: If both ``batch_size`` and ``nbatches`` are None.
        """
        # Tell the library where the data lives.
        if self.in_path is not None:
            self.lib.setInPath(self._path_buffer(self.in_path))
        else:
            self.lib.setTrainPath(self._path_buffer(self.tri_file))
            self.lib.setEntPath(self._path_buffer(self.ent_file))
            self.lib.setRelPath(self._path_buffer(self.rel_file))

        # Bernoulli vs. uniform negative sampling.
        self.lib.setBern(self.bern)
        # Number of worker threads.
        self.lib.setWorkThreads(self.work_threads)
        # Reset the per-thread random seeds.
        self.lib.randReset()
        # Have the native side load the training files into memory.
        self.lib.importTrainFiles()

        # Totals of relations, entities and training triples.
        self.relTotal = self.lib.getRelationTotal()
        self.entTotal = self.lib.getEntityTotal()
        self.tripleTotal = self.lib.getTrainTotal()

        self._resolve_batching()
        self._allocate_batch_buffers()

    def _run_sampling(self, mode):
        """Ask the native library to fill the batch buffers.

        ``mode`` is forwarded unchanged: this class uses 0 for normal
        sampling, -1 for head batches and 1 for tail batches.
        """
        self.lib.sampling(
            self.batch_h_addr,
            self.batch_t_addr,
            self.batch_r_addr,
            self.batch_y_addr,
            self.batch_size,
            self.negative_ent,
            self.negative_rel,
            mode,
            self.filter,
            0,
            0
        )

    def sampling(self):
        """Fill the buffers with a normal batch and return it as a dict."""
        self._run_sampling(0)
        return {
            "batch_h": self.batch_h,
            "batch_t": self.batch_t,
            "batch_r": self.batch_r,
            "batch_y": self.batch_y,
            "mode": "normal"
        }

    def sampling_head(self):
        """Fill the buffers with a head batch.

        Tails and relations are returned only for the first ``batch_size``
        entries; heads and labels are returned in full.
        """
        self._run_sampling(-1)
        return {
            "batch_h": self.batch_h,
            "batch_t": self.batch_t[:self.batch_size],
            "batch_r": self.batch_r[:self.batch_size],
            "batch_y": self.batch_y,
            "mode": "head_batch"
        }

    def sampling_tail(self):
        """Fill the buffers with a tail batch.

        Heads and relations are returned only for the first ``batch_size``
        entries; tails and labels are returned in full.
        """
        self._run_sampling(1)
        return {
            "batch_h": self.batch_h[:self.batch_size],
            "batch_t": self.batch_t,
            "batch_r": self.batch_r[:self.batch_size],
            "batch_y": self.batch_y,
            "mode": "tail_batch"
        }

    def cross_sampling(self):
        """Alternate between tail and head batches on successive calls.

        The flag starts at 0 and is flipped before it is inspected, so the
        first call yields a tail batch.
        """
        self.cross_sampling_flag = 1 - self.cross_sampling_flag
        if self.cross_sampling_flag == 0:
            return self.sampling_head()
        else:
            return self.sampling_tail()

    # Interfaces to set essential parameters.

    def set_work_threads(self, work_threads):
        """Store the thread count; it is sent to the library by ``read()``."""
        self.work_threads = work_threads

    def set_in_path(self, in_path):
        """Store the input path; it is sent to the library by ``read()``."""
        self.in_path = in_path

    def set_nbatches(self, nbatches):
        """Set how many batches one iteration over the loader yields."""
        self.nbatches = nbatches

    def set_batch_size(self, batch_size):
        """Set the batch size, recompute ``nbatches`` and resize the buffers.

        The buffers must be resized because the new size is passed to the
        native sampling routine, which writes that many entries.
        """
        self.batch_size = batch_size
        self.nbatches = self.tripleTotal // self.batch_size
        self._allocate_batch_buffers()

    def set_ent_neg_rate(self, rate):
        """Set the entity negative rate and resize the batch buffers."""
        self.negative_ent = rate
        self._allocate_batch_buffers()

    def set_rel_neg_rate(self, rate):
        """Set the relation negative rate and resize the batch buffers."""
        self.negative_rel = rate
        self._allocate_batch_buffers()

    def set_bern_flag(self, bern):
        """Store the Bernoulli flag; it is sent to the library by ``read()``."""
        self.bern = bern

    def set_filter_flag(self, filter):
        """Store the filter flag; it is passed on every sampling call."""
        self.filter = filter

    # Interfaces to get essential parameters.

    def get_batch_size(self):
        return self.batch_size

    def get_ent_tot(self):
        return self.entTotal

    def get_rel_tot(self):
        return self.relTotal

    def get_triple_tot(self):
        return self.tripleTotal

    def __iter__(self):
        """Return a fresh sampler for one pass of ``nbatches`` batches."""
        if self.sampling_mode == "normal":
            return TrainDataSampler(self.nbatches, self.sampling)
        else:
            return TrainDataSampler(self.nbatches, self.cross_sampling)

    def __len__(self):
        return self.nbatches