1. #include "Setting.h"
    2. #include "Random.h"
    3. #include "Reader.h"
    4. #include "Corrupt.h"
    5. #include "Test.h"
    6. #include <cstdlib>
    7. #include <pthread.h>
// Exported C API consumed from the Python side via ctypes/cffi.
// NOTE(review): __declspec(dllexport) is Windows-specific; combined with
// <pthread.h> this presumably targets a MinGW/pthreads-win32 build — confirm,
// otherwise these should be guarded by an #ifdef.

// --- Path configuration for the dataset files ---
extern "C" __declspec(dllexport)
void setInPath(char *path);
extern "C" __declspec(dllexport)
void setTrainPath(char *path);
extern "C" __declspec(dllexport)
void setValidPath(char *path);
extern "C" __declspec(dllexport)
void setTestPath(char *path);
extern "C" __declspec(dllexport)
void setEntPath(char *path);
extern "C" __declspec(dllexport)
void setRelPath(char *path);
extern "C" __declspec(dllexport)
void setOutPath(char *path);

// --- Runtime configuration ---
extern "C" __declspec(dllexport)
void setWorkThreads(INT threads);
// Enable/disable Bernoulli negative sampling (bernFlag).
extern "C" __declspec(dllexport)
void setBern(INT con);

// --- Dataset-size accessors ---
extern "C" __declspec(dllexport)
INT getWorkThreads();
extern "C" __declspec(dllexport)
INT getEntityTotal();
extern "C" __declspec(dllexport)
INT getRelationTotal();
extern "C" __declspec(dllexport)
INT getTripleTotal();
extern "C" __declspec(dllexport)
INT getTrainTotal();
extern "C" __declspec(dllexport)
INT getTestTotal();
extern "C" __declspec(dllexport)
INT getValidTotal();

// Reset the per-thread random number generators.
extern "C" __declspec(dllexport)
void randReset();
// Load the training triples from disk into memory.
extern "C" __declspec(dllexport)
void importTrainFiles();
// Per-thread argument record passed to getBatch() through pthread_create.
// Each worker receives its own Parameter; the batch_* pointers are shared
// output buffers, and `id` selects the worker's slice of them.
struct Parameter {
	INT id;           // worker thread index in [0, workThreads)
	INT *batch_h;     // output: head entity ids
	INT *batch_t;     // output: tail entity ids
	INT *batch_r;     // output: relation ids
	REAL *batch_y;    // output: labels (+1 positive, -1 negative)
	INT batchSize;    // number of positive triples per batch
	INT negRate;      // negative entity samples per positive triple
	INT negRelRate;   // negative relation samples per positive triple
	bool p;           // forwarded to corrupt_rel (semantics defined there)
	bool val_loss;    // false = sample training negatives; true = copy validation triples
	INT mode;         // 0 = corrupt head or tail; -1 = corrupt head; otherwise corrupt tail
	bool filter_flag; // NOTE(review): stored but never read in getBatch — confirm intent
};
// Worker-thread entry point: fills this thread's slice of the shared batch
// buffers with positive triples and their negative samples.
//
// Buffer layout: slots [0, batchSize) hold positive triples; the k-th
// negative for positive slot `batch` lives at index batch + k * batchSize
// (see `last = batchSize; ...; last += batchSize` below). The caller must
// therefore allocate batchSize * (1 + negRate + negRelRate) entries per
// array — TODO confirm against the allocating caller.
void* getBatch(void* con) {
	// Reinterpret the opaque pthread argument as our Parameter record.
	Parameter *para = (Parameter *)(con);
	// Copy the fields into locals for readability.
	INT id = para -> id;
	INT *batch_h = para -> batch_h;
	INT *batch_t = para -> batch_t;
	INT *batch_r = para -> batch_r;
	REAL *batch_y = para -> batch_y;
	INT batchSize = para -> batchSize;
	INT negRate = para -> negRate;
	INT negRelRate = para -> negRelRate;
	bool p = para -> p;
	// val_loss == false means "training": sample negatives below.
	bool val_loss = para -> val_loss;
	INT mode = para -> mode;
	// NOTE(review): filter_flag is read here but never used in this function.
	bool filter_flag = para -> filter_flag;
	// Partition the batch across workThreads workers: this thread handles
	// positive slots [lef, rig).
	INT lef, rig;
	if (batchSize % workThreads == 0) {
		lef = id * (batchSize / workThreads);
		rig = (id + 1) * (batchSize / workThreads);
	} else {
		// Uneven split: round the chunk size up and clamp the last chunk.
		lef = id * (batchSize / workThreads + 1);
		rig = (id + 1) * (batchSize / workThreads + 1);
		if (rig > batchSize) rig = batchSize;
	}
	// Threshold (out of 1000) deciding head- vs tail-corruption;
	// 500 means a 50/50 split unless bernFlag overrides it per relation.
	REAL prob = 500;
	// Training mode: emit positives plus sampled negatives.
	if (val_loss == false) {
		for (INT batch = lef; batch < rig; batch++) {
			// rand_max uses a per-thread seed so workers draw independent
			// random streams (avoids identical time-based seeds).
			INT i = rand_max(id, trainTotal);
			// i is a uniform random index into the training triple list.
			batch_h[batch] = trainList[i].h;
			batch_t[batch] = trainList[i].t;
			batch_r[batch] = trainList[i].r;
			batch_y[batch] = 1;
			// `last` walks through the negative-sample groups, one group
			// of batchSize slots per negative sample.
			INT last = batchSize;
			// Draw negRate corrupted-entity negatives.
			for (INT times = 0; times < negRate; times ++) {
				// mode == 0: corrupt head or tail, chosen randomly.
				if (mode == 0){
					// Bernoulli trick: bias the head/tail choice by the
					// relation's average degree statistics.
					if (bernFlag)
						prob = 1000 * right_mean[trainList[i].r] / (right_mean[trainList[i].r] + left_mean[trainList[i].r]);
					if (randd(id) % 1000 < prob) {
						// Keep the head, corrupt the tail.
						batch_h[batch + last] = trainList[i].h;
						batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r);
						batch_r[batch + last] = trainList[i].r;
					} else {
						// Keep the tail, corrupt the head.
						batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r);
						batch_t[batch + last] = trainList[i].t;
						batch_r[batch + last] = trainList[i].r;
					}
					batch_y[batch + last] = -1;
					last += batchSize;
				}
				// mode != 0: corruption side is fixed by the sign of mode.
				else {
					// mode == -1: corrupt the head only.
					if(mode == -1){
						batch_h[batch + last] = corrupt_tail(id, trainList[i].t, trainList[i].r);
						batch_t[batch + last] = trainList[i].t;
						batch_r[batch + last] = trainList[i].r;
					}
					// any other non-zero mode: corrupt the tail only.
					else {
						batch_h[batch + last] = trainList[i].h;
						batch_t[batch + last] = corrupt_head(id, trainList[i].h, trainList[i].r);
						batch_r[batch + last] = trainList[i].r;
					}
					batch_y[batch + last] = -1;
					last += batchSize;
				}
			}
			// Draw negRelRate corrupted-relation negatives.
			for (INT times = 0; times < negRelRate; times++) {
				batch_h[batch + last] = trainList[i].h;
				batch_t[batch + last] = trainList[i].t;
				batch_r[batch + last] = corrupt_rel(id, trainList[i].h, trainList[i].t, trainList[i].r, p);
				batch_y[batch + last] = -1;
				last += batchSize;
			}
		}
	}
	// Validation mode: copy this thread's slice of the validation triples
	// verbatim, all labeled positive.
	else
	{
		for (INT batch = lef; batch < rig; batch++)
		{
			batch_h[batch] = validList[batch].h;
			batch_t[batch] = validList[batch].t;
			batch_r[batch] = validList[batch].r;
			batch_y[batch] = 1;
		}
	}
	pthread_exit(NULL);
	// Unreachable after pthread_exit; kept to satisfy the non-void signature.
	return ((void*)0);
}
    156. extern "C" __declspec(dllexport)
    157. void sampling(
    158. INT *batch_h,
    159. INT *batch_t,
    160. INT *batch_r,
    161. REAL *batch_y,
    162. INT batchSize,
    163. INT negRate = 1,
    164. INT negRelRate = 0,
    165. // mode=0代表普通采样,mode=1代表尾部采样,mode=-1代表头部采样
    166. INT mode = 0,
    167. bool filter_flag = true,
    168. bool p = false,
    169. // val_loss应该代表的是是否为训练模式
    170. bool val_loss = false
    171. ) {
    172. // 根据线程数量,向内存分配指定的大小
    173. pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t));
    174. // 根据线程数量,以及Parameter结构体的大小,向内存分配指定的大小
    175. Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter));
    176. //初始化para结构体
    177. for (INT threads = 0; threads < workThreads; threads++) {
    178. para[threads].id = threads;
    179. para[threads].batch_h = batch_h;
    180. para[threads].batch_t = batch_t;
    181. para[threads].batch_r = batch_r;
    182. para[threads].batch_y = batch_y;
    183. para[threads].batchSize = batchSize;
    184. para[threads].negRate = negRate;
    185. para[threads].negRelRate = negRelRate;
    186. para[threads].p = p;
    187. para[threads].val_loss = val_loss;
    188. para[threads].mode = mode;
    189. para[threads].filter_flag = filter_flag;
    190. /*
    191. 创建线程
    192. int pthread_create(
    193. pthread_t *restrict tidp, //新创建的线程ID指向的内存单元。
    194. const pthread_attr_t *restrict attr, //线程属性,默认为NULL
    195. void *(*start_rtn)(void *), //新创建的线程从start_rtn函数的地址开始运行
    196. void *restrict arg //默认为NULL。若上述函数需要参数,将参数放入结构中并将地址作为arg传入。
    197. );
    198. */
    199. pthread_create(&pt[threads], NULL, getBatch, (void*)(para+threads));
    200. }
    201. // 收工
    202. for (INT threads = 0; threads < workThreads; threads++)
    203. pthread_join(pt[threads], NULL);
    204. free(pt);
    205. free(para);
    206. }
    207. int main() {
    208. importTrainFiles();
    209. return 0;
    210. }