| 23 | return self.data_total |
| 24 | |
| 25 | class TestDataLoader(object): |
| 26 | |
def __init__(self, in_path = "./", sampling_mode = 'link', type_constrain = True):
    """Load the native library, declare its batch signatures, and read the data.

    Args:
        in_path: Stored as ``self.in_path``; ``read()`` hands it to the
            library's ``setInPath``.
        sampling_mode: Stored as ``self.sampling_mode``; not otherwise used
            inside this constructor.
        type_constrain: Stored as ``self.type_constrain``; when truthy,
            ``read()`` additionally calls the library's ``importTypeFiles``.
    """
    # The shared object is resolved relative to this module's directory.
    module_dir = os.path.dirname(__file__)
    library_path = os.path.abspath(os.path.join(module_dir, "../release/Base.so"))
    self.lib = ctypes.cdll.LoadLibrary(library_path)

    pointer = ctypes.c_void_p

    # Link prediction: each batch getter takes three pointer arguments.
    self.lib.getHeadBatch.argtypes = [pointer, pointer, pointer]
    self.lib.getTailBatch.argtypes = [pointer, pointer, pointer]

    # Triple classification: the test-batch getter takes six pointer arguments.
    self.lib.getTestBatch.argtypes = [pointer] * 6

    # Essential parameters; they must be set before read() runs, which uses them.
    self.in_path = in_path
    self.sampling_mode = sampling_mode
    self.type_constrain = type_constrain
    self.read()
| 55 | |
| 56 | def read(self): |
| 57 | self.lib.setInPath(ctypes.create_string_buffer(self.in_path.encode(), len(self.in_path) * 2)) |
| 58 | self.lib.randReset() |
| 59 | self.lib.importTestFiles() |
| 60 | |
| 61 | if self.type_constrain: |
| 62 | self.lib.importTypeFiles() |
| 63 | |
| 64 | self.relTotal = self.lib.getRelationTotal() |
| 65 | self.entTotal = self.lib.getEntityTotal() |
| 66 | self.testTotal = self.lib.getTestTotal() |
| 67 | |
| 68 | self.test_h = np.zeros(self.entTotal, dtype=np.int64) |
| 69 | self.test_t = np.zeros(self.entTotal, dtype=np.int64) |
| 70 | self.test_r = np.zeros(self.entTotal, dtype=np.int64) |
| 71 | self.test_h_addr = self.test_h.__array_interface__["data"][0] |
| 72 | self.test_t_addr = self.test_t.__array_interface__["data"][0] |
| 73 | self.test_r_addr = self.test_r.__array_interface__["data"][0] |
| 74 | |
| 75 | self.test_pos_h = np.zeros(self.testTotal, dtype=np.int64) |
| 76 | self.test_pos_t = np.zeros(self.testTotal, dtype=np.int64) |
| 77 | self.test_pos_r = np.zeros(self.testTotal, dtype=np.int64) |
| 78 | self.test_pos_h_addr = self.test_pos_h.__array_interface__["data"][0] |
| 79 | self.test_pos_t_addr = self.test_pos_t.__array_interface__["data"][0] |
| 80 | self.test_pos_r_addr = self.test_pos_r.__array_interface__["data"][0] |
| 81 | self.test_neg_h = np.zeros(self.testTotal, dtype=np.int64) |
| 82 | self.test_neg_t = np.zeros(self.testTotal, dtype=np.int64) |
no outgoing calls
no test coverage detected