MCPcopy
hub / github.com/thunlp/OpenKE / TrainDataLoader

Class TrainDataLoader

openke/data/TrainDataLoader.py:25–229  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

23 return self.nbatches
24
25class TrainDataLoader(object):
26
def __init__(self,
             in_path="./",
             tri_file=None,
             ent_file=None,
             rel_file=None,
             batch_size=None,
             nbatches=None,
             threads=8,
             sampling_mode="normal",
             bern_flag=False,
             filter_flag=True,
             neg_ent=1,
             neg_rel=0):
    """Load the native sampling library, record the configuration and read the data.

    The shared library is looked up at ``../release/Base.so`` relative to
    the directory containing this module.

    Args:
        in_path: Directory prefix of the dataset. When it is not ``None``
            the three file paths are derived from it by plain string
            concatenation (``in_path + "train2id.txt"`` and so on), so it
            must already end with a path separator. Because the default is
            ``"./"``, pass ``in_path=None`` explicitly to make the
            ``tri_file`` / ``ent_file`` / ``rel_file`` arguments take effect.
        tri_file: Path stored as ``self.tri_file``; overwritten when
            ``in_path`` is not ``None``.
        ent_file: Path stored as ``self.ent_file``; overwritten when
            ``in_path`` is not ``None``.
        rel_file: Path stored as ``self.rel_file``; overwritten when
            ``in_path`` is not ``None``.
        batch_size: Stored as ``self.batch_size``.
        nbatches: Stored as ``self.nbatches``.
        threads: Stored as ``self.work_threads``.
        sampling_mode: Stored as ``self.sampling_mode``.
        bern_flag: Stored as ``self.bern``.
        filter_flag: Stored as ``self.filter``.
        neg_ent: Stored as ``self.negative_ent``.
        neg_rel: Stored as ``self.negative_rel``.

    NOTE(review): the meaning of the sampling-related options is defined by
    the native library and by code outside this method -- confirm there.
    """
    base_file = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../release/Base.so")
    )
    self.lib = ctypes.cdll.LoadLibrary(base_file)

    # Declare the native ``sampling`` signature up front so ctypes converts
    # the arguments correctly: four pointers followed by seven 64-bit ints.
    self.lib.sampling.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_int64,
        ctypes.c_int64,
        ctypes.c_int64,
        ctypes.c_int64,
        ctypes.c_int64,
        ctypes.c_int64,
        ctypes.c_int64,
    ]

    self.in_path = in_path
    self.tri_file = tri_file
    self.ent_file = ent_file
    self.rel_file = rel_file
    if in_path is not None:
        # A directory prefix wins over the individually supplied files.
        self.tri_file = in_path + "train2id.txt"
        self.ent_file = in_path + "entity2id.txt"
        self.rel_file = in_path + "relation2id.txt"

    # Essential parameters.
    self.work_threads = threads
    self.nbatches = nbatches
    self.batch_size = batch_size
    self.bern = bern_flag
    self.filter = filter_flag
    self.negative_ent = neg_ent
    self.negative_rel = neg_rel
    self.sampling_mode = sampling_mode
    self.cross_sampling_flag = 0

    # Must run last: read() relies on the attributes assigned above.
    self.read()
76
77 def read(self):
78 if self.in_path != None:
79 self.lib.setInPath(ctypes.create_string_buffer(self.in_path.encode(), len(self.in_path) * 2))
80 else:
81 self.lib.setTrainPath(ctypes.create_string_buffer(self.tri_file.encode(), len(self.tri_file) * 2))
82 self.lib.setEntPath(ctypes.create_string_buffer(self.ent_file.encode(), len(self.ent_file) * 2))

Calls

no outgoing calls

Tested by

no test coverage detected