| 23 | return self.nbatches |
| 24 | |
| 25 | class TrainDataLoader(object): |
| 26 | |
def __init__(self,
        in_path="./",
        tri_file=None,
        ent_file=None,
        rel_file=None,
        batch_size=None,
        nbatches=None,
        threads=8,
        sampling_mode="normal",
        bern_flag=False,
        filter_flag=True,
        neg_ent=1,
        neg_rel=0):
    """Load the native sampling library, record the configuration, then call ``read()``.

    Args:
        in_path: Dataset directory prefix. When it is not ``None`` it takes
            precedence: ``self.tri_file``, ``self.ent_file`` and
            ``self.rel_file`` are overwritten with ``in_path + "train2id.txt"``,
            ``in_path + "entity2id.txt"`` and ``in_path + "relation2id.txt"``.
            Because the default is ``"./"``, callers must pass
            ``in_path=None`` for the explicit file arguments to be used.
            The prefix is joined by plain string concatenation, so it should
            end with a path separator.
        tri_file: Path to the training-triple file; only honoured when
            ``in_path`` is ``None``.
        ent_file: Path to the entity file; only honoured when ``in_path``
            is ``None``.
        rel_file: Path to the relation file; only honoured when ``in_path``
            is ``None``.
        batch_size: Stored as ``self.batch_size``. Presumably ``read()``
            derives it from ``nbatches`` when left as ``None`` -- TODO confirm.
        nbatches: Stored as ``self.nbatches``. Presumably derived from
            ``batch_size`` in ``read()`` when ``None`` -- TODO confirm.
        threads: Stored as ``self.work_threads``.
        sampling_mode: Stored as ``self.sampling_mode``.
        bern_flag: Stored as ``self.bern``.
        filter_flag: Stored as ``self.filter``.
        neg_ent: Stored as ``self.negative_ent``. Presumably the number of
            entity-corrupted negatives per positive triple -- confirm against
            the native sampler.
        neg_rel: Stored as ``self.negative_rel``. Presumably the number of
            relation-corrupted negatives per positive triple -- confirm.

    Raises:
        OSError: If ``../release/Base.so`` (resolved relative to this module)
            cannot be loaded by ``ctypes``.
    """
    # The compiled sampler is expected at ../release/Base.so relative to this file.
    base_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../release/Base.so"))
    self.lib = ctypes.cdll.LoadLibrary(base_file)

    # Declare the C signature of `sampling` so ctypes marshals arguments
    # correctly: four raw pointers (presumably the output batch buffers --
    # confirm against the C source) followed by seven 64-bit integers.
    self.lib.sampling.argtypes = [ctypes.c_void_p] * 4 + [ctypes.c_int64] * 7

    self.in_path = in_path
    self.tri_file = tri_file
    self.ent_file = ent_file
    self.rel_file = rel_file
    if in_path is not None:
        # A dataset directory overrides any explicitly passed file paths.
        self.tri_file = in_path + "train2id.txt"
        self.ent_file = in_path + "entity2id.txt"
        self.rel_file = in_path + "relation2id.txt"

    # Essential sampling parameters.
    self.work_threads = threads
    self.nbatches = nbatches
    self.batch_size = batch_size
    self.bern = bern_flag
    self.filter = filter_flag
    self.negative_ent = neg_ent
    self.negative_rel = neg_rel
    self.sampling_mode = sampling_mode
    # Starts at 0; how it is updated is outside this method.
    self.cross_sampling_flag = 0

    # Hand the configured paths to the native library (see read()).
    self.read()
| 76 | |
| 77 | def read(self): |
| 78 | if self.in_path != None: |
| 79 | self.lib.setInPath(ctypes.create_string_buffer(self.in_path.encode(), len(self.in_path) * 2)) |
| 80 | else: |
| 81 | self.lib.setTrainPath(ctypes.create_string_buffer(self.tri_file.encode(), len(self.tri_file) * 2)) |
| 82 | self.lib.setEntPath(ctypes.create_string_buffer(self.ent_file.encode(), len(self.ent_file) * 2)) |
no outgoing calls
no test coverage detected