In the remaining Python packages, only a few places are algorithm-related: the models, the loss functions (MarginLoss and the like), NegativeSampling, and the Trainer. Since everything is written in PyTorch, the code is easy to follow, so I won't walk through all of it; the annotated TransE model below is representative.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Model import Model

class TransE(Model):

    def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, margin = None, epsilon = None):
        super(TransE, self).__init__(ent_tot, rel_tot)
        self.dim = dim              # dimension of the embedding vectors
        self.margin = margin
        self.epsilon = epsilon
        self.norm_flag = norm_flag
        self.p_norm = p_norm
        # ent_tot is the total number of entities
        self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
        # rel_tot is the total number of relations
        self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)
        # by default, the embedding vectors get Xavier initialization
        if margin == None or epsilon == None:
            nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
            nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        else:
            self.embedding_range = nn.Parameter(
                torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False
            )
            nn.init.uniform_(
                tensor = self.ent_embeddings.weight.data,
                a = -self.embedding_range.item(),
                b = self.embedding_range.item()
            )
            nn.init.uniform_(
                tensor = self.rel_embeddings.weight.data,
                a = -self.embedding_range.item(),
                b = self.embedding_range.item()
            )
        # the margin stored on the model only comes into play with SigmoidLoss;
        # with MarginLoss it is not needed here, since the margin is already
        # configured inside the loss function itself
        if margin != None:
            self.margin = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False
            self.margin_flag = True
        else:
            self.margin_flag = False

    def _calc(self, h, t, r, mode):
        if self.norm_flag:
            # L2-normalize the embedding vectors (the 2 is the order of the norm)
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)
        if mode != 'normal':
            # under head or tail negative sampling, h or t carries many negative
            # samples per triple; the shapes no longer match, so reshape the
            # tensors here so they can broadcast against each other
            h = h.view(-1, r.shape[0], h.shape[-1])
            t = t.view(-1, r.shape[0], t.shape[-1])
            r = r.view(-1, r.shape[0], r.shape[-1])
        if mode == 'head_batch':
            # same value as (h + r) - t, but grouping (r - t) first lets the
            # single (r, t) pair broadcast once across all candidate heads
            score = h + (r - t)
        else:
            score = (h + r) - t
        # the score is just the p-norm of the vector (h + r) - t
        score = torch.norm(score, self.p_norm, -1).flatten()
        return score

    def forward(self, data):
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        mode = data['mode']
        # look up the embedding vectors for the whole batch of entities and
        # relations (each lookup returns an embedding matrix)
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        score = self._calc(h, t, r, mode)
        if self.margin_flag:
            return self.margin - score
        else:
            return score

    def regularization(self, data):
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        regul = (torch.mean(h ** 2) +
                 torch.mean(t ** 2) +
                 torch.mean(r ** 2)) / 3
        return regul

    def predict(self, data):
        score = self.forward(data)
        if self.margin_flag:
            score = self.margin - score
            return score.cpu().data.numpy()
        else:
            return score.cpu().data.numpy()
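
To make the data-dict interface above concrete, here is a minimal sketch of driving the model by hand. The package path openke.module.model and the toy entity/relation counts are my own assumptions for illustration; in real training the Trainer and the NegativeSampling wrapper build this dict for you.

# A minimal sketch, not OpenKE's own example script: toy entity/relation
# counts and hand-built batches, assuming the openke.module.model package path.
import torch
from openke.module.model import TransE

model = TransE(ent_tot = 5, rel_tot = 2, dim = 4, p_norm = 1, norm_flag = True)

# 'normal' mode: one (h, r, t) triple per row, no negative sampling
data = {
    'batch_h': torch.tensor([0, 1]),   # head entity IDs
    'batch_t': torch.tensor([2, 3]),   # tail entity IDs
    'batch_r': torch.tensor([0, 1]),   # relation IDs
    'mode': 'normal',
}
score = model(data)        # ||h + r - t||_1 per triple; lower = more plausible
print(score.shape)         # torch.Size([2])

# 'head_batch' mode, as used in link-prediction evaluation: score every
# entity as a candidate head for a single (r, t) pair
data = {
    'batch_h': torch.arange(5),        # all 5 entities as candidate heads
    'batch_t': torch.tensor([2]),
    'batch_r': torch.tensor([0]),
    'mode': 'head_batch',
}
print(model.predict(data).shape)       # (5,)

Because margin is left at None here, margin_flag is False and forward returns the raw distance; when pairing the model with SigmoidLoss you would pass margin and epsilon, and forward would return margin - score instead.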
