MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / guess_inputs

Function guess_inputs

scripts/dump-model-params.py:46–68  ·  view source on GitHub ↗
(input_dir)

Source from the content-addressed store, hash-verified

44
45
def guess_inputs(input_dir):
    """Pick the graph (.meta) file and checkpoint (.index) file to load from ``input_dir``.

    Only the top level of ``input_dir`` is scanned (no recursion).

    - Graph file: among names matching ``graph-*.meta``, the lexicographically
      greatest one is chosen. (If the suffix is a sortable timestamp this is the
      newest one — an assumption about how the files were named, not checked here.)
    - Checkpoint: among names matching ``model-<step>.index``, the one with the
      largest integer ``<step>`` is chosen; equal steps are broken by filename so
      the result is deterministic. Names whose ``<step>`` part is not an integer
      (e.g. ``model-best.index``) are skipped with a warning rather than aborting.

    Args:
        input_dir (str): directory to scan.

    Returns:
        tuple(str, str): ``(model_path, meta_path)``, both joined onto ``input_dir``.
        Note the order: the checkpoint ``.index`` path comes first.

    Raises:
        ValueError: if no ``graph-*.meta`` file, or no ``model-<step>.index`` file
            with an integer step, exists in ``input_dir``.
    """
    meta_candidates = []
    model_candidates = []  # list of (filename, step) pairs
    # Sorted so that candidate lists (and the log lines below) do not depend on
    # the arbitrary order returned by os.listdir.
    for fname in sorted(os.listdir(input_dir)):
        if fname.startswith('graph-') and fname.endswith('.meta'):
            meta_candidates.append(fname)
        elif fname.startswith('model-') and fname.endswith('.index'):
            step_str = fname[len('model-'):-len('.index')]
            try:
                step = int(step_str)
            except ValueError:
                # A single oddly named file should not prevent picking a valid checkpoint.
                logger.warning("Ignoring {}: '{}' is not an integer step.".format(fname, step_str))
                continue
            model_candidates.append((fname, step))

    # Explicit errors instead of `assert`, which is stripped under `python -O`.
    if not meta_candidates:
        raise ValueError("No graph-*.meta file found in {}".format(input_dir))
    meta = max(meta_candidates)
    if len(meta_candidates) > 1:
        logger.info("Choosing {} from {} as graph file.".format(meta, meta_candidates))
    else:
        logger.info("Choosing {} as graph file.".format(meta))

    if not model_candidates:
        raise ValueError("No model-<step>.index checkpoint file found in {}".format(input_dir))
    # Order by step (then filename, for a deterministic tie-break); the last one wins.
    model_candidates.sort(key=lambda cand: (cand[1], cand[0]))
    model = model_candidates[-1][0]
    if len(model_candidates) > 1:
        logger.info("Choosing {} from {} as model file.".format(model, [c[0] for c in model_candidates]))
    else:
        logger.info("Choosing {} as model file.".format(model))
    return os.path.join(input_dir, model), os.path.join(input_dir, meta)
69
70
71if __name__ == '__main__':

Callers 1

Calls 3

appendMethod · 0.80
formatMethod · 0.80
joinMethod · 0.80

Tested by

no test coverage detected