1. #ifndef READER_H
    2. #define READER_H
    #include "Setting.h"
    #include "Triple.h"

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>
    #include <string>
    9. INT *freqRel, *freqEnt;
    10. INT *lefHead, *rigHead;
    11. INT *lefTail, *rigTail;
    12. INT *lefRel, *rigRel;
    13. REAL *left_mean, *right_mean;
    14. REAL *prob;
    15. Triple *trainList;
    16. Triple *trainHead;
    17. Triple *trainTail;
    18. Triple *trainRel;
    19. INT *testLef, *testRig;
    20. INT *validLef, *validRig;
    21. extern "C" __declspec(dllexport)
    22. // 不知道干什么用的,提前确定阈值?
    23. void importProb(REAL temp){
    24. if (prob != NULL)
    25. free(prob);
    26. FILE *fin;
    27. fin = fopen((inPath + "kl_prob.txt").c_str(), "r");
    28. printf("Current temperature:%f\n", temp);
    29. prob = (REAL *)calloc(relationTotal * (relationTotal - 1), sizeof(REAL));
    30. INT tmp;
    31. for (INT i = 0; i < relationTotal * (relationTotal - 1); ++i){
    32. tmp = fscanf(fin, "%f", &prob[i]);
    33. }
    34. REAL sum = 0.0;
    35. for (INT i = 0; i < relationTotal; ++i) {
    36. for (INT j = 0; j < relationTotal-1; ++j){
    37. REAL tmp = exp(-prob[i * (relationTotal - 1) + j] / temp);
    38. sum += tmp;
    39. prob[i * (relationTotal - 1) + j] = tmp;
    40. }
    41. for (INT j = 0; j < relationTotal-1; ++j){
    42. prob[i*(relationTotal-1)+j] /= sum;
    43. }
    44. sum = 0;
    45. }
    46. fclose(fin);
    47. }
    48. extern "C" __declspec(dllexport)
    49. void importTrainFiles() {
    50. printf("The toolkit is importing datasets.\n");
    51. FILE *fin;
    52. int tmp;
    53. // 打开关系数据集,第一行存有所有关系的数量
    54. if (rel_file == "")
    55. fin = fopen((inPath + "relation2id.txt").c_str(), "r");
    56. else
    57. fin = fopen(rel_file.c_str(), "r");
    58. tmp = fscanf(fin, "%ld", &relationTotal);
    59. printf("The total of relations is %ld.\n", relationTotal);
    60. fclose(fin);
    61. // 打开实体数据集,第一行存有所有实体的数量
    62. if (ent_file == "")
    63. fin = fopen((inPath + "entity2id.txt").c_str(), "r");
    64. else
    65. fin = fopen(ent_file.c_str(), "r");
    66. tmp = fscanf(fin, "%ld", &entityTotal);
    67. printf("The total of entities is %ld.\n", entityTotal);
    68. fclose(fin);
    69. // 打开三元组数据集,第一行存有所有三元组的数量
    70. if (train_file == "")
    71. fin = fopen((inPath + "train2id.txt").c_str(), "r");
    72. else
    73. fin = fopen(train_file.c_str(), "r");
    74. tmp = fscanf(fin, "%ld", &trainTotal);
    75. // train开头的都分配与训练集一样的长度
    76. trainList = (Triple *)calloc(trainTotal, sizeof(Triple));
    77. trainHead = (Triple *)calloc(trainTotal, sizeof(Triple));
    78. trainTail = (Triple *)calloc(trainTotal, sizeof(Triple));
    79. trainRel = (Triple *)calloc(trainTotal, sizeof(Triple));
    80. // freq开头的长度与实体/关系数量一致,用于存储每个实体/关系出现的频率
    81. freqRel = (INT *)calloc(relationTotal, sizeof(INT));
    82. freqEnt = (INT *)calloc(entityTotal, sizeof(INT));
    83. // 将三元组数据集中的三元组存在trainList中
    84. for (INT i = 0; i < trainTotal; i++) {
    85. tmp = fscanf(fin, "%ld", &trainList[i].h);
    86. tmp = fscanf(fin, "%ld", &trainList[i].t);
    87. tmp = fscanf(fin, "%ld", &trainList[i].r);
    88. }
    89. fclose(fin);
    90. // 按照头实体ID的大小,对trainList进行排序,若头实体ID相等,则判断关系ID;若头实体、关系都相等,则判断尾实体ID;并以升序的方式排列
    91. std::sort(trainList, trainList + trainTotal, Triple::cmp_head);
    92. // 遍历一遍三元组训练集,统计每个实体及关系出现的频率,并把相同的内容复制到trainHead、trainTail、trainRel中
    93. tmp = trainTotal; trainTotal = 1;
    94. trainHead[0] = trainTail[0] = trainRel[0] = trainList[0];
    95. freqEnt[trainList[0].t] += 1;
    96. freqEnt[trainList[0].h] += 1;
    97. freqRel[trainList[0].r] += 1;
    98. for (INT i = 1; i < tmp; i++)
    99. if (trainList[i].h != trainList[i - 1].h || trainList[i].r != trainList[i - 1].r || trainList[i].t != trainList[i - 1].t) {
    100. trainHead[trainTotal] = trainTail[trainTotal] = trainRel[trainTotal] = trainList[trainTotal] = trainList[i];
    101. trainTotal++;
    102. freqEnt[trainList[i].t]++;
    103. freqEnt[trainList[i].h]++;
    104. freqRel[trainList[i].r]++;
    105. }
    106. // 现在为什么要额外三个train开头的列表就明了了,原来trainHead、trainTail、trainRel分别是按照三者ID大小来排序的列表,以便进行二分查找
    107. std::sort(trainHead, trainHead + trainTotal, Triple::cmp_head);
    108. std::sort(trainTail, trainTail + trainTotal, Triple::cmp_tail);
    109. std::sort(trainRel, trainRel + trainTotal, Triple::cmp_rel);
    110. printf("The total of train triples is %ld.\n", trainTotal);
    111. // 初始化这六个列表是因为三元组列表中同一实体或关系可能同时出现多次,那么我们希望快速确定实体或关系在排序好的三元组列表中出现的位置
    112. lefHead = (INT *)calloc(entityTotal, sizeof(INT));
    113. rigHead = (INT *)calloc(entityTotal, sizeof(INT));
    114. lefTail = (INT *)calloc(entityTotal, sizeof(INT));
    115. rigTail = (INT *)calloc(entityTotal, sizeof(INT));
    116. lefRel = (INT *)calloc(entityTotal, sizeof(INT));
    117. rigRel = (INT *)calloc(entityTotal, sizeof(INT));
    118. // 对rigHead、rigTail、rigRel数组全部初始化为-1
    119. memset(rigHead, -1, sizeof(INT)*entityTotal);
    120. memset(rigTail, -1, sizeof(INT)*entityTotal);
    121. memset(rigRel, -1, sizeof(INT)*entityTotal);
    122. // lefTail存储的是对应实体在trainTail有序列表中第一个下标位置,相反rigTail存储的是对应实体在trainTail有序列表中最后一个下标位置
    123. // 其他如Head、Rel同理
    124. for (INT i = 1; i < trainTotal; i++) {
    125. if (trainTail[i].t != trainTail[i - 1].t) {
    126. rigTail[trainTail[i - 1].t] = i - 1;
    127. lefTail[trainTail[i].t] = i;
    128. }
    129. if (trainHead[i].h != trainHead[i - 1].h) {
    130. rigHead[trainHead[i - 1].h] = i - 1;
    131. lefHead[trainHead[i].h] = i;
    132. }
    133. if (trainRel[i].h != trainRel[i - 1].h) {
    134. rigRel[trainRel[i - 1].h] = i - 1;
    135. lefRel[trainRel[i].h] = i;
    136. }
    137. }
    138. lefHead[trainHead[0].h] = 0;
    139. rigHead[trainHead[trainTotal - 1].h] = trainTotal - 1;
    140. lefTail[trainTail[0].t] = 0;
    141. rigTail[trainTail[trainTotal - 1].t] = trainTotal - 1;
    142. lefRel[trainRel[0].h] = 0;
    143. rigRel[trainRel[trainTotal - 1].h] = trainTotal - 1;
    144. // 以下内容只有在进行伯努利采样时才用得到
    145. left_mean = (REAL *)calloc(relationTotal,sizeof(REAL));
    146. right_mean = (REAL *)calloc(relationTotal,sizeof(REAL));
    147. // 遍历所有entity
    148. for (INT i = 0; i < entityTotal; i++) {
    149. for (INT j = lefHead[i] + 1; j <= rigHead[i]; j++)
    150. // 相邻训练头实体对应的关系不等情况下,对头实体的入边+1
    151. // 其实由于trainHead是根据Head与Relation一起排列的,所以Head相同的一段其relation也是有序的,所以这个操作就是将在当前Head为指定Entity时出现过的关系次数统一加一
    152. if (trainHead[j].r != trainHead[j - 1].r)
    153. left_mean[trainHead[j].r] += 1.0;
    154. if (lefHead[i] <= rigHead[i])
    155. left_mean[trainHead[lefHead[i]].r] += 1.0;
    156. for (INT j = lefTail[i] + 1; j <= rigTail[i]; j++)
    157. // 如果左实体的大小小于等于右实体的大小,则以左实体对应的出边+1
    158. // 与前者相似,只不过这次看的是尾实体,即计算的是关系的出度
    159. if (trainTail[j].r != trainTail[j - 1].r)
    160. right_mean[trainTail[j].r] += 1.0;
    161. if (lefTail[i] <= rigTail[i])
    162. right_mean[trainTail[lefTail[i]].r] += 1.0;
    163. }
    164. // 左均值即关系的个数除以关系的入边数量,右均值即关系的个数除以关系的出边数量
    165. for (INT i = 0; i < relationTotal; i++) {
    166. left_mean[i] = freqRel[i] / left_mean[i];
    167. right_mean[i] = freqRel[i] / right_mean[i];
    168. }
    169. }
    170. Triple *testList;
    171. Triple *validList;
    172. Triple *tripleList;
    173. extern "C" __declspec(dllexport)
    174. void importTestFiles() {
    175. FILE *fin;
    176. INT tmp;
    177. if (rel_file == "")
    178. fin = fopen((inPath + "relation2id.txt").c_str(), "r");
    179. else
    180. fin = fopen(rel_file.c_str(), "r");
    181. tmp = fscanf(fin, "%ld", &relationTotal);
    182. fclose(fin);
    183. if (ent_file == "")
    184. fin = fopen((inPath + "entity2id.txt").c_str(), "r");
    185. else
    186. fin = fopen(ent_file.c_str(), "r");
    187. tmp = fscanf(fin, "%ld", &entityTotal);
    188. fclose(fin);
    189. FILE* f_kb1, * f_kb2, * f_kb3;
    190. if (train_file == "")
    191. f_kb2 = fopen((inPath + "train2id.txt").c_str(), "r");
    192. else
    193. f_kb2 = fopen(train_file.c_str(), "r");
    194. if (test_file == "")
    195. f_kb1 = fopen((inPath + "test2id.txt").c_str(), "r");
    196. else
    197. f_kb1 = fopen(test_file.c_str(), "r");
    198. if (valid_file == "")
    199. f_kb3 = fopen((inPath + "valid2id.txt").c_str(), "r");
    200. else
    201. f_kb3 = fopen(valid_file.c_str(), "r");
    202. tmp = fscanf(f_kb1, "%ld", &testTotal);
    203. tmp = fscanf(f_kb2, "%ld", &trainTotal);
    204. tmp = fscanf(f_kb3, "%ld", &validTotal);
    205. tripleTotal = testTotal + trainTotal + validTotal;
    206. testList = (Triple *)calloc(testTotal, sizeof(Triple));
    207. validList = (Triple *)calloc(validTotal, sizeof(Triple));
    208. tripleList = (Triple *)calloc(tripleTotal, sizeof(Triple));
    209. for (INT i = 0; i < testTotal; i++) {
    210. tmp = fscanf(f_kb1, "%ld", &testList[i].h);
    211. tmp = fscanf(f_kb1, "%ld", &testList[i].t);
    212. tmp = fscanf(f_kb1, "%ld", &testList[i].r);
    213. tripleList[i] = testList[i];
    214. }
    215. for (INT i = 0; i < trainTotal; i++) {
    216. tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].h);
    217. tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].t);
    218. tmp = fscanf(f_kb2, "%ld", &tripleList[i + testTotal].r);
    219. }
    220. for (INT i = 0; i < validTotal; i++) {
    221. tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].h);
    222. tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].t);
    223. tmp = fscanf(f_kb3, "%ld", &tripleList[i + testTotal + trainTotal].r);
    224. validList[i] = tripleList[i + testTotal + trainTotal];
    225. }
    226. fclose(f_kb1);
    227. fclose(f_kb2);
    228. fclose(f_kb3);
    229. std::sort(tripleList, tripleList + tripleTotal, Triple::cmp_head);
    230. std::sort(testList, testList + testTotal, Triple::cmp_rel2);
    231. std::sort(validList, validList + validTotal, Triple::cmp_rel2);
    232. printf("The total of test triples is %ld.\n", testTotal);
    233. printf("The total of valid triples is %ld.\n", validTotal);
    234. testLef = (INT *)calloc(relationTotal, sizeof(INT));
    235. testRig = (INT *)calloc(relationTotal, sizeof(INT));
    236. memset(testLef, -1, sizeof(INT) * relationTotal);
    237. memset(testRig, -1, sizeof(INT) * relationTotal);
    238. for (INT i = 1; i < testTotal; i++) {
    239. if (testList[i].r != testList[i-1].r) {
    240. testRig[testList[i-1].r] = i - 1;
    241. testLef[testList[i].r] = i;
    242. }
    243. }
    244. testLef[testList[0].r] = 0;
    245. testRig[testList[testTotal - 1].r] = testTotal - 1;
    246. validLef = (INT *)calloc(relationTotal, sizeof(INT));
    247. validRig = (INT *)calloc(relationTotal, sizeof(INT));
    248. memset(validLef, -1, sizeof(INT)*relationTotal);
    249. memset(validRig, -1, sizeof(INT)*relationTotal);
    250. for (INT i = 1; i < validTotal; i++) {
    251. if (validList[i].r != validList[i-1].r) {
    252. validRig[validList[i-1].r] = i - 1;
    253. validLef[validList[i].r] = i;
    254. }
    255. }
    256. validLef[validList[0].r] = 0;
    257. validRig[validList[validTotal - 1].r] = validTotal - 1;
    258. }
    259. INT* head_lef;
    260. INT* head_rig;
    261. INT* tail_lef;
    262. INT* tail_rig;
    263. INT* head_type;
    264. INT* tail_type;
    265. extern "C" __declspec(dllexport)
    266. void importTypeFiles() {
    267. head_lef = (INT *)calloc(relationTotal, sizeof(INT));
    268. head_rig = (INT *)calloc(relationTotal, sizeof(INT));
    269. tail_lef = (INT *)calloc(relationTotal, sizeof(INT));
    270. tail_rig = (INT *)calloc(relationTotal, sizeof(INT));
    271. INT total_lef = 0;
    272. INT total_rig = 0;
    273. FILE* f_type = fopen((inPath + "type_constrain.txt").c_str(),"r");
    274. INT tmp;
    275. tmp = fscanf(f_type, "%ld", &tmp);
    276. for (INT i = 0; i < relationTotal; i++) {
    277. INT rel, tot;
    278. tmp = fscanf(f_type, "%ld %ld", &rel, &tot);
    279. for (INT j = 0; j < tot; j++) {
    280. tmp = fscanf(f_type, "%ld", &tmp);
    281. total_lef++;
    282. }
    283. tmp = fscanf(f_type, "%ld%ld", &rel, &tot);
    284. for (INT j = 0; j < tot; j++) {
    285. tmp = fscanf(f_type, "%ld", &tmp);
    286. total_rig++;
    287. }
    288. }
    289. fclose(f_type);
    290. head_type = (INT *)calloc(total_lef, sizeof(INT));
    291. tail_type = (INT *)calloc(total_rig, sizeof(INT));
    292. total_lef = 0;
    293. total_rig = 0;
    294. f_type = fopen((inPath + "type_constrain.txt").c_str(),"r");
    295. tmp = fscanf(f_type, "%ld", &tmp);
    296. for (INT i = 0; i < relationTotal; i++) {
    297. INT rel, tot;
    298. tmp = fscanf(f_type, "%ld%ld", &rel, &tot);
    299. head_lef[rel] = total_lef;
    300. for (INT j = 0; j < tot; j++) {
    301. tmp = fscanf(f_type, "%ld", &head_type[total_lef]);
    302. total_lef++;
    303. }
    304. head_rig[rel] = total_lef;
    305. std::sort(head_type + head_lef[rel], head_type + head_rig[rel]);
    306. tmp = fscanf(f_type, "%ld%ld", &rel, &tot);
    307. tail_lef[rel] = total_rig;
    308. for (INT j = 0; j < tot; j++) {
    309. tmp = fscanf(f_type, "%ld", &tail_type[total_rig]);
    310. total_rig++;
    311. }
    312. tail_rig[rel] = total_rig;
    313. std::sort(tail_type + tail_lef[rel], tail_type + tail_rig[rel]);
    314. }
    315. fclose(f_type);
    316. }
    317. #endif