1. #ifndef CORRUPT_H
    2. #define CORRUPT_H
    3. #include "Random.h"
    4. #include "Triple.h"
    5. #include "Reader.h"
    6. #include "iso646.h"
    7. INT corrupt_head(INT id, INT h, INT r, bool filter_flag = true) {
    8. INT lef, rig, mid, ll, rr;
    9. // filter_flag是指要不要进行过滤,在这里t是随机生成的,它会带来什么问题呢?负采样得到的h,r,tmp很有可能也在三元组中出现过,这样效果就会适得其反
    10. if (not filter_flag) {
    11. INT tmp = rand_max(id, entityTotal - 1);
    12. if (tmp < h)
    13. return tmp;
    14. else
    15. return tmp + 1;
    16. }
    17. // trainHead是依据三元组头部的ID大小进行排序的列表,可以进行二分查找,因为头实体ID相同的话它是按关系RelationID大小排列的
    18. // lef即当前头实体在有序TrainHead列表第一次出现的位置(从这里开始查找)
    19. // rig即当前头实体在有序TrainHead列表最后一次出现的位置(查找区间尽头)
    20. lef = lefHead[h] - 1;
    21. rig = rigHead[h];
    22. while (lef + 1 < rig) {
    23. mid = (lef + rig) >> 1;
    24. if (trainHead[mid].r >= r) rig = mid; else
    25. lef = mid;
    26. }
    27. // ll是在区间lef~rig中第一次出现关系为r的三元组下标
    28. ll = rig;
    29. lef = lefHead[h];
    30. rig = rigHead[h] + 1;
    31. while (lef + 1 < rig) {
    32. mid = (lef + rig) >> 1;
    33. if (trainHead[mid].r <= r) lef = mid; else
    34. rig = mid;
    35. }
    36. // rr是在区间lef~rig中最后一次出现关系为r的三元组下标
    37. rr = lef;
    38. // 那么我们最后想返回什么呢?就是一个没有在任何头实体为h、关系为r的三元组中出现过的t,其实就是很简单的一个事情
    39. INT tmp = rand_max(id, entityTotal - (rr - ll + 1));
    40. if (tmp < trainHead[ll].t) return tmp;
    41. if (tmp > trainHead[rr].t - rr + ll - 1) return tmp + rr - ll + 1;
    42. lef = ll, rig = rr + 1;
    43. // 随机到的tmp恰好落入relation为r的tail区间里了,这种情况有点麻烦
    44. // 但是好在所有头实体为h、关系为r的三元组中也是按尾实体ID升序排列的,所以还可以进行一次二分查找,找到未在tail区间中出现过的尾实体ID
    45. while (lef + 1 < rig) {
    46. mid = (lef + rig) >> 1;
    47. if (trainHead[mid].t - mid + ll - 1 < tmp)
    48. lef = mid;
    49. else
    50. rig = mid;
    51. }
    52. return tmp + lef - ll + 1;
    53. }
    54. INT corrupt_tail(INT id, INT t, INT r, bool filter_flag = true) {
    55. INT lef, rig, mid, ll, rr;
    56. if (not filter_flag) {
    57. INT tmp = rand_max(id, entityTotal - 1);
    58. if (tmp < t)
    59. return tmp;
    60. else
    61. return tmp + 1;
    62. }
    63. lef = lefTail[t] - 1;
    64. rig = rigTail[t];
    65. while (lef + 1 < rig) {
    66. mid = (lef + rig) >> 1;
    67. if (trainTail[mid].r >= r) rig = mid; else
    68. lef = mid;
    69. }
    70. ll = rig;
    71. lef = lefTail[t];
    72. rig = rigTail[t] + 1;
    73. while (lef + 1 < rig) {
    74. mid = (lef + rig) >> 1;
    75. if (trainTail[mid].r <= r) lef = mid; else
    76. rig = mid;
    77. }
    78. rr = lef;
    79. INT tmp = rand_max(id, entityTotal - (rr - ll + 1));
    80. if (tmp < trainTail[ll].h) return tmp;
    81. if (tmp > trainTail[rr].h - rr + ll - 1) return tmp + rr - ll + 1;
    82. lef = ll, rig = rr + 1;
    83. while (lef + 1 < rig) {
    84. mid = (lef + rig) >> 1;
    85. if (trainTail[mid].h - mid + ll - 1 < tmp)
    86. lef = mid;
    87. else
    88. rig = mid;
    89. }
    90. return tmp + lef - ll + 1;
    91. }
    92. INT corrupt_rel(INT id, INT h, INT t, INT r, bool p = false, bool filter_flag = true) {
    93. INT lef, rig, mid, ll, rr;
    94. if (not filter_flag) {
    95. INT tmp = rand_max(id, relationTotal - 1);
    96. if (tmp < r)
    97. return tmp;
    98. else
    99. return tmp + 1;
    100. }
    101. lef = lefRel[h] - 1;
    102. rig = rigRel[h];
    103. while (lef + 1 < rig) {
    104. mid = (lef + rig) >> 1;
    105. if (trainRel[mid].t >= t) rig = mid; else
    106. lef = mid;
    107. }
    108. ll = rig;
    109. lef = lefRel[h];
    110. rig = rigRel[h] + 1;
    111. while (lef + 1 < rig) {
    112. mid = (lef + rig) >> 1;
    113. if (trainRel[mid].t <= t) lef = mid; else
    114. rig = mid;
    115. }
    116. rr = lef;
    117. INT tmp;
    118. if(p == false) {
    119. tmp = rand_max(id, relationTotal - (rr - ll + 1));
    120. }
    121. else {
    122. INT start = r * (relationTotal - 1);
    123. REAL sum = 1;
    124. bool *record = (bool *)calloc(relationTotal - 1, sizeof(bool));
    125. for (INT i = ll; i <= rr; ++i){
    126. if (trainRel[i].r > r){
    127. sum -= prob[start + trainRel[i].r-1];
    128. record[trainRel[i].r-1] = true;
    129. }
    130. else if (trainRel[i].r < r){
    131. sum -= prob[start + trainRel[i].r];
    132. record[trainRel[i].r] = true;
    133. }
    134. }
    135. REAL *prob_tmp = (REAL *)calloc(relationTotal-(rr-ll+1), sizeof(REAL));
    136. INT cnt = 0;
    137. REAL rec = 0;
    138. for (INT i = start; i < start + relationTotal - 1; ++i) {
    139. if (record[i-start])
    140. continue;
    141. rec += prob[i] / sum;
    142. prob_tmp[cnt++] = rec;
    143. }
    144. REAL m = rand_max(id, 10000) / 10000.0;
    145. lef = 0;
    146. rig = cnt - 1;
    147. while (lef < rig) {
    148. mid = (lef + rig) >> 1;
    149. if (prob_tmp[mid] < m)
    150. lef = mid + 1;
    151. else
    152. rig = mid;
    153. }
    154. tmp = rig;
    155. free(prob_tmp);
    156. free(record);
    157. }
    158. if (tmp < trainRel[ll].r) return tmp;
    159. if (tmp > trainRel[rr].r - rr + ll - 1) return tmp + rr - ll + 1;
    160. lef = ll, rig = rr + 1;
    161. while (lef + 1 < rig) {
    162. mid = (lef + rig) >> 1;
    163. if (trainRel[mid].r - mid + ll - 1 < tmp)
    164. lef = mid;
    165. else
    166. rig = mid;
    167. }
    168. return tmp + lef - ll + 1;
    169. }
    170. bool _find(INT h, INT t, INT r) {
    171. INT lef = 0;
    172. INT rig = tripleTotal - 1;
    173. INT mid;
    174. while (lef + 1 < rig) {
    175. INT mid = (lef + rig) >> 1;
    176. if ((tripleList[mid]. h < h) || (tripleList[mid]. h == h && tripleList[mid]. r < r) || (tripleList[mid]. h == h && tripleList[mid]. r == r && tripleList[mid]. t < t)) lef = mid; else rig = mid;
    177. }
    178. if (tripleList[lef].h == h && tripleList[lef].r == r && tripleList[lef].t == t) return true;
    179. if (tripleList[rig].h == h && tripleList[rig].r == r && tripleList[rig].t == t) return true;
    180. return false;
    181. }
    182. INT corrupt(INT h, INT r){
    183. INT ll = tail_lef[r];
    184. INT rr = tail_rig[r];
    185. INT loop = 0;
    186. INT t;
    187. while(true) {
    188. t = tail_type[rand(ll, rr)];
    189. if (not _find(h, t, r)) {
    190. return t;
    191. } else {
    192. loop ++;
    193. if (loop >= 1000) {
    194. return corrupt_head(0, h, r);
    195. }
    196. }
    197. }
    198. }
    199. #endif