MCPcopy
hub / github.com/PaddlePaddle/PaddleRec / main

Function main

tools/infer.py:47–194  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

45
46
47def main(args):
48 paddle.seed(12345)
49 # load config
50 config = load_yaml(args.config_yaml)
51 dy_model_class = load_dy_model_class(args.abs_dir)
52 config["config_abs_dir"] = args.abs_dir
53 # modify config from command
54 if args.opt:
55 for parameter in args.opt:
56 parameter = parameter.strip()
57 key, value = parameter.split("=")
58 if type(config.get(key)) is int:
59 value = int(value)
60 if type(config.get(key)) is float:
61 value = float(value)
62 if type(config.get(key)) is bool:
63 value = (True if value.lower() == "true" else False)
64 config[key] = value
65
66 # tools.vars
67 use_gpu = config.get("runner.use_gpu", True)
68 use_auc = config.get("runner.use_auc", False)
69 use_xpu = config.get("runner.use_xpu", False)
70 use_npu = config.get("runner.use_npu", False)
71 use_visual = config.get("runner.use_visual", False)
72 test_data_dir = config.get("runner.test_data_dir", None)
73 print_interval = config.get("runner.print_interval", None)
74 infer_batch_size = config.get("runner.infer_batch_size", None)
75 model_load_path = config.get("runner.infer_load_path", "model_output")
76 start_epoch = config.get("runner.infer_start_epoch", 0)
77 end_epoch = config.get("runner.infer_end_epoch", 10)
78
79 logger.info("**************common.configs**********")
80 logger.info(
81 "use_gpu: {}, use_xpu: {}, use_npu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}".
82 format(use_gpu, use_xpu, use_npu, use_visual, infer_batch_size,
83 test_data_dir, start_epoch, end_epoch, print_interval,
84 model_load_path))
85 logger.info("**************common.configs**********")
86
87 if use_xpu:
88 xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
89 place = paddle.set_device(xpu_device)
90 elif use_npu:
91 npu_device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
92 place = paddle.set_device(npu_device)
93 else:
94 place = paddle.set_device('gpu' if use_gpu else 'cpu')
95
96 dy_model = dy_model_class.create_model(config)
97
98 # Create a log_visual object and store the data in the path
99 if use_visual:
100 from visualdl import LogWriter
101 log_visual = LogWriter(args.abs_dir + "/visualDL_log/infer")
102
103 # to do : add optimizer function
104 #optimizer = dy_model_class.create_optimizer(dy_model, config)

Callers 1

infer.py (File) · 0.70

Calls 10

load_yaml (Function) · 0.90
load_dy_model_class (Function) · 0.90
create_data_loader (Function) · 0.90
load_model (Function) · 0.90
numpy (Method) · 0.80
accumulate (Method) · 0.80
create_model (Method) · 0.45
create_metrics (Method) · 0.45
infer_forward (Method) · 0.45
reset (Method) · 0.45

Tested by

no test coverage detected