MCPcopy
hub / github.com/PaddlePaddle/PaddleRec / main

Function main

tools/to_static.py:47–95  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

45
46
def _apply_cli_overrides(config, opts):
    """Apply ``key=value`` command-line overrides to ``config`` in place.

    Each override is split on the FIRST ``=`` so values may themselves
    contain ``=``. Surrounding whitespace on key and value is stripped.
    When the existing config entry is an ``int``, ``float`` or ``bool`` the
    new value is converted to that type (``bool`` is true only for a
    case-insensitive ``"true"``); otherwise the raw string is stored.

    Args:
        config: mutable mapping of config keys to values (modified in place).
        opts: iterable of ``key=value`` strings, or ``None``/empty for none.

    Raises:
        ValueError: if an override has no ``=``, or its value cannot be
            converted to the existing entry's numeric type.
    """
    for parameter in opts or []:
        key, sep, value = parameter.strip().partition("=")
        if not sep:
            raise ValueError(
                "invalid override {!r}: expected key=value".format(parameter))
        key, value = key.strip(), value.strip()
        # Exact type checks on purpose: bool is a subclass of int, so
        # isinstance() would route booleans through int().
        current_type = type(config.get(key))
        if current_type is int:
            value = int(value)
        elif current_type is float:
            value = float(value)
        elif current_type is bool:
            value = value.lower() == "true"
        config[key] = value


def main(args):
    """Convert a dygraph model to a static graph and save it.

    Loads the YAML config, applies ``--opt`` overrides, builds the model via
    the model class loaded from ``args.abs_dir``, loads weights from
    ``runner.model_init_path``, traces it with ``paddle.jit.to_static`` and
    saves the result with prefix ``tostatic``.

    Args:
        args: parsed command-line namespace; reads ``config_yaml``,
            ``abs_dir`` and ``opt`` (optional list of ``key=value`` strings).
    """
    paddle.seed(12345)
    # load config
    config = load_yaml(args.config_yaml)
    dy_model_class = load_dy_model_class(args.abs_dir)
    config["config_abs_dir"] = args.abs_dir
    # modify config from command
    _apply_cli_overrides(config, args.opt)

    # tools.vars
    use_gpu = config.get("runner.use_gpu", True)
    train_data_dir = config.get("runner.train_data_dir", None)
    epochs = config.get("runner.epochs", None)
    print_interval = config.get("runner.print_interval", None)
    model_save_path = config.get("runner.model_save_path", "model_output")
    model_init_path = config.get("runner.model_init_path", None)
    end_epoch = config.get("runner.infer_end_epoch", 0)
    CE = config.get("runner.CE", False)
    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
        format(use_gpu, train_data_dir, epochs, print_interval,
               model_save_path))
    logger.info("**************common.configs**********")

    # Called for its side effect (selecting the device); the returned place
    # object is not needed here.
    paddle.set_device('gpu' if use_gpu else 'cpu')

    dy_model = dy_model_class.create_model(config)
    if not CE:
        # Outside CE runs the output goes into a subdirectory named
        # str(end_epoch - 1).
        # NOTE(review): with the default infer_end_epoch of 0 this is "-1";
        # confirm configs always set runner.infer_end_epoch.
        model_save_path = os.path.join(model_save_path, str(end_epoch - 1))

    load_model(model_init_path, dy_model)
    # example dnn model forward
    # The input signature is hard-coded: a list of 26 int64 [None, 1] specs
    # followed by one float32 [None, 13] spec. Presumably these are sparse
    # slot ids and dense features; they must match the model's forward()
    # inputs — confirm per model.
    sparse_specs = [
        paddle.static.InputSpec(shape=[None, 1], dtype='int64')
        for _ in range(26)
    ]
    dense_spec = paddle.static.InputSpec(shape=[None, 13], dtype='float32')
    dy_model = paddle.jit.to_static(
        dy_model, input_spec=[sparse_specs, dense_spec])
    save_jit_model(dy_model, model_save_path, prefix='tostatic')
96
97
98if __name__ == '__main__':

Callers 1

to_static.pyFile · 0.70

Calls 5

load_yamlFunction · 0.90
load_dy_model_classFunction · 0.90
load_modelFunction · 0.90
save_jit_modelFunction · 0.90
create_modelMethod · 0.45

Tested by

no test coverage detected