MCPcopy
hub / github.com/thunlp/OpenKE / TestDataLoader

Class TestDataLoader

openke/data/TestDataLoader.py:25–153  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

23 return self.data_total
24
25class TestDataLoader(object):
26
def __init__(self, in_path="./", sampling_mode='link', type_constrain=True):
    """Bind the native OpenKE library and load the evaluation data.

    Args:
        in_path: Path that ``read()`` passes to the native ``setInPath``
            call. Presumably this is the dataset directory, given the
            default "./" -- confirm.
        sampling_mode: Stored as ``self.sampling_mode``. This constructor
            does not read it. Presumably it selects link prediction vs.
            triple classification elsewhere in the class -- confirm.
        type_constrain: Stored as ``self.type_constrain``. ``read()`` calls
            the native ``importTypeFiles`` only when this is truthy.

    Raises:
        OSError: If ``ctypes`` cannot load ``../release/Base.so``
            (resolved relative to this module's directory).
    """
    module_dir = os.path.dirname(__file__)
    library_path = os.path.abspath(os.path.join(module_dir, "../release/Base.so"))
    self.lib = ctypes.cdll.LoadLibrary(library_path)

    # Declare every argument of these entry points as a void pointer. That
    # way ctypes converts Python ints (e.g. raw buffer addresses) to
    # full-width pointers instead of using its default C ``int`` conversion.
    pointer_signatures = (
        ("getHeadBatch", 3),  # for link prediction
        ("getTailBatch", 3),  # for link prediction
        ("getTestBatch", 6),  # for triple classification
    )
    for symbol, n_pointers in pointer_signatures:
        getattr(self.lib, symbol).argtypes = [ctypes.c_void_p] * n_pointers

    # Record the configuration, then load the data straight away via read().
    self.in_path = in_path
    self.sampling_mode = sampling_mode
    self.type_constrain = type_constrain
    self.read()
55
56 def read(self):
57 self.lib.setInPath(ctypes.create_string_buffer(self.in_path.encode(), len(self.in_path) * 2))
58 self.lib.randReset()
59 self.lib.importTestFiles()
60
61 if self.type_constrain:
62 self.lib.importTypeFiles()
63
64 self.relTotal = self.lib.getRelationTotal()
65 self.entTotal = self.lib.getEntityTotal()
66 self.testTotal = self.lib.getTestTotal()
67
68 self.test_h = np.zeros(self.entTotal, dtype=np.int64)
69 self.test_t = np.zeros(self.entTotal, dtype=np.int64)
70 self.test_r = np.zeros(self.entTotal, dtype=np.int64)
71 self.test_h_addr = self.test_h.__array_interface__["data"][0]
72 self.test_t_addr = self.test_t.__array_interface__["data"][0]
73 self.test_r_addr = self.test_r.__array_interface__["data"][0]
74
75 self.test_pos_h = np.zeros(self.testTotal, dtype=np.int64)
76 self.test_pos_t = np.zeros(self.testTotal, dtype=np.int64)
77 self.test_pos_r = np.zeros(self.testTotal, dtype=np.int64)
78 self.test_pos_h_addr = self.test_pos_h.__array_interface__["data"][0]
79 self.test_pos_t_addr = self.test_pos_t.__array_interface__["data"][0]
80 self.test_pos_r_addr = self.test_pos_r.__array_interface__["data"][0]
81 self.test_neg_h = np.zeros(self.testTotal, dtype=np.int64)
82 self.test_neg_t = np.zeros(self.testTotal, dtype=np.int64)

Calls

no outgoing calls

Tested by

no test coverage detected