In the remaining Python packages, the algorithm-related pieces are essentially the models, the loss functions (MarginLoss and so on), NegativeSampling, and the Trainer. Since they are written in plain PyTorch they are easy to follow, so I won't walk through each one; the annotated TransE model below is representative.
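For context before reading the code: TransE treats a relation as a translation in embedding space, so a true triple (h, r, t) should end up with h + r ≈ t, and the score computed below is simply the distance

    f(h, r, t) = || h + r - t ||_p

where p is the `p_norm` argument (1 or 2); a lower score means a more plausible triple.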

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from .Model import Model

    class TransE(Model):

        def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, margin = None, epsilon = None):
            super(TransE, self).__init__(ent_tot, rel_tot)
            self.dim = dim              # dimension of the embedding vectors
            self.margin = margin
            self.epsilon = epsilon
            self.norm_flag = norm_flag
            self.p_norm = p_norm
            # ent_tot is the total number of entities
            self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
            # rel_tot is the total number of relations
            self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)
            # By default, the embeddings simply get Xavier initialization
            if margin == None or epsilon == None:
                nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
                nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
            else:
                # Otherwise draw every weight uniformly from
                # [-(margin + epsilon) / dim, (margin + epsilon) / dim]
                self.embedding_range = nn.Parameter(
                    torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False
                )
                nn.init.uniform_(
                    tensor = self.ent_embeddings.weight.data,
                    a = -self.embedding_range.item(),
                    b = self.embedding_range.item()
                )
                nn.init.uniform_(
                    tensor = self.rel_embeddings.weight.data,
                    a = -self.embedding_range.item(),
                    b = self.embedding_range.item()
                )
            # Keep the margin on the model; it comes into play when using
            # SigmoidLoss. With MarginLoss it is not needed here, because the
            # margin is already configured inside the loss function itself.
            if margin != None:
                self.margin = nn.Parameter(torch.Tensor([margin]))
                self.margin.requires_grad = False
                self.margin_flag = True
            else:
                self.margin_flag = False
        def _calc(self, h, t, r, mode):
            if self.norm_flag:
                # L2-normalize the embeddings; the 2 is the order of the norm
                h = F.normalize(h, 2, -1)
                r = F.normalize(r, 2, -1)
                t = F.normalize(t, 2, -1)
            if mode != 'normal':
                # Reshape because in head-batch (or tail-batch) mode h (or t)
                # carries many negative samples per triple; without this view
                # the tensors could not be combined (a concrete shape
                # walkthrough follows after the listing)
                h = h.view(-1, r.shape[0], h.shape[-1])
                t = t.view(-1, r.shape[0], t.shape[-1])
                r = r.view(-1, r.shape[0], r.shape[-1])
            if mode == 'head_batch':
                score = h + (r - t)
            else:
                score = (h + r) - t
            # The resulting score is just the norm of the vector (h + r) - t
            score = torch.norm(score, self.p_norm, -1).flatten()
            return score
        def forward(self, data):
            batch_h = data['batch_h']
            batch_t = data['batch_t']
            batch_r = data['batch_r']
            mode = data['mode']
            # Look up the embeddings of the whole entity/relation batch
            # (each lookup returns a matrix of embedding vectors)
            h = self.ent_embeddings(batch_h)
            t = self.ent_embeddings(batch_t)
            r = self.rel_embeddings(batch_r)
            score = self._calc(h, t, r, mode)
            if self.margin_flag:
                return self.margin - score
            else:
                return score
        def regularization(self, data):
            batch_h = data['batch_h']
            batch_t = data['batch_t']
            batch_r = data['batch_r']
            h = self.ent_embeddings(batch_h)
            t = self.ent_embeddings(batch_t)
            r = self.rel_embeddings(batch_r)
            # Mean squared magnitude of the embeddings used in this batch
            regul = (torch.mean(h ** 2) +
                     torch.mean(t ** 2) +
                     torch.mean(r ** 2)) / 3
            return regul
        def predict(self, data):
            score = self.forward(data)
            # If margin_flag is set, forward() returned margin - score, so
            # subtracting from margin again recovers the raw distance
            if self.margin_flag:
                score = self.margin - score
                return score.cpu().data.numpy()
            else:
                return score.cpu().data.numpy()
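As promised above, here is a concrete look at the reshaping in `_calc`. This is my own standalone illustration (the sizes are made up; in OpenKE the batch really comes from the data loader): in 'head_batch' mode, h holds several candidate heads per triple while r and t hold one vector per triple, and the views make the subtraction broadcast.

    import torch

    batch_size, neg_per_triple, dim = 4, 8, 100        # illustrative sizes only
    h = torch.randn(batch_size * neg_per_triple, dim)  # candidate heads, flattened
    r = torch.randn(batch_size, dim)
    t = torch.randn(batch_size, dim)

    # The same views as in _calc: the -1 resolves to 8 for h but to 1 for
    # r and t, so h + (r - t) broadcasts to [8, 4, 100]
    h = h.view(-1, r.shape[0], h.shape[-1])   # [8, 4, 100]
    t = t.view(-1, r.shape[0], t.shape[-1])   # [1, 4, 100]
    r = r.view(-1, r.shape[0], r.shape[-1])   # [1, 4, 100]

    score = torch.norm(h + (r - t), 1, -1).flatten()
    print(score.shape)   # torch.Size([32]): one score per candidate head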
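Finally, a minimal usage sketch. The entity/relation counts and the hand-built batch are made up for illustration; in real OpenKE training the data dict comes from the data loader, and the model is further wrapped in NegativeSampling with a loss such as MarginLoss before being handed to the Trainer, as described at the top of this section.

    import torch

    # Made-up sizes: 1000 entities, 50 relations
    transe = TransE(ent_tot = 1000, rel_tot = 50, dim = 100, p_norm = 1, norm_flag = True)

    # A hand-built batch of four 'normal'-mode triples (IDs are arbitrary)
    data = {
        'batch_h': torch.tensor([0, 1, 2, 3]),
        'batch_t': torch.tensor([4, 5, 6, 7]),
        'batch_r': torch.tensor([0, 1, 2, 3]),
        'mode': 'normal',
    }
    scores = transe(data)           # calls forward(); lower = more plausible
    print(scores.shape)             # torch.Size([4])
    print(transe.predict(data))     # the same scores as a numpy array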