Repository: github.com/PaddlePaddle/PaddleRec (branch: main)

Function main(args)

tools/static_trainer.py, lines 46–214 (the listing below is truncated after line 103)

def main(args):
    paddle.seed(12345)

    # load config
    config = load_yaml(args.config_yaml)
    config["yaml_path"] = args.config_yaml
    config["config_abs_dir"] = args.abs_dir
    # override config entries from the command line ("key=value" strings);
    # the type of the value already in config decides how the string is converted
    if args.opt:
        for parameter in args.opt:
            parameter = parameter.strip()
            # split on the first "=" only, so a value may itself contain "="
            key, value = parameter.split("=", 1)
            if type(config.get(key)) is int:
                value = int(value)
            elif type(config.get(key)) is float:
                value = float(value)
            elif type(config.get(key)) is bool:
                value = value.lower() == "true"
            config[key] = value
    # load static model class
    static_model_class = load_static_model_class(config)
    input_data = static_model_class.create_feeds()
    input_data_names = [data.name for data in input_data]

    fetch_vars = static_model_class.net(input_data)

    #infer_target_var = model.infer_target_var
    logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))

    use_gpu = config.get("runner.use_gpu", True)
    use_xpu = config.get("runner.use_xpu", False)
    use_auc = config.get("runner.use_auc", False)
    use_visual = config.get("runner.use_visual", False)
    use_inference = config.get("runner.use_inference", False)
    auc_num = config.get("runner.auc_num", 1)
    train_data_dir = config.get("runner.train_data_dir", None)
    epochs = config.get("runner.epochs", None)
    print_interval = config.get("runner.print_interval", None)
    model_save_path = config.get("runner.model_save_path", "model_output")
    model_init_path = config.get("runner.model_init_path", None)
    batch_size = config.get("runner.train_batch_size", None)
    reader_type = config.get("runner.reader_type", "DataLoader")
    use_fleet = config.get("runner.use_fleet", False)
    use_save_data = config.get("runner.use_save_data", False)
    os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1))
    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
        format(use_gpu, use_xpu, use_visual, batch_size, train_data_dir,
               epochs, print_interval, model_save_path))
    logger.info("**************common.configs**********")

    if use_xpu:
        xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
        place = paddle.set_device(xpu_device)
    else:
        place = paddle.set_device('gpu' if use_gpu else 'cpu')
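The --opt loop in main() treats each override as a "key=value" string and converts the value to the type already stored under that key, so a YAML value of 4 turns "10" into the integer 10 while unknown keys stay strings. The standalone sketch below reproduces that rule, assuming a flat dict keyed by dotted paths (as the config.get("runner.use_gpu", True) calls suggest). The helper name apply_overrides is hypothetical, and the learning-rate key in the usage is purely illustrative.

```python
def apply_overrides(config, overrides):
    """Apply "key=value" override strings to a flat config dict in place.

    Hypothetical helper modeled on the loop in main(): the type of the
    value already stored under a key decides the conversion; keys that
    are missing or hold other types keep the raw string.
    """
    for item in overrides or []:
        # Split on the first "=" only, so a value may itself contain "=".
        key, value = item.strip().split("=", 1)
        current = config.get(key)
        # Exact type checks: type(True) is bool, so booleans never match int.
        if type(current) is bool:
            value = value.lower() == "true"
        elif type(current) is int:
            value = int(value)
        elif type(current) is float:
            value = float(value)
        config[key] = value
    return config


config = {
    "runner.epochs": 4,
    "runner.use_gpu": True,
    "runner.model_save_path": "model_output",
    "hyper_parameters.optimizer.learning_rate": 0.001,  # illustrative key
}
apply_overrides(config, [
    "runner.epochs=10",
    "runner.use_gpu=False",
    "runner.model_save_path=out/run=2",
    "hyper_parameters.optimizer.learning_rate=0.01",
])
print(config)
```

Because each check compares exact types, the order of the branches does not change the result; bool is tested first only to make that point visible.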
103

Callers 1

static_trainer.py · File · 0.70
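main() is invoked from file-level code in static_trainer.py and reads three attributes from args: config_yaml, abs_dir and opt. The sketch below shows one way such an args object could be built with argparse. The flag spellings -m/--config_yaml and -o/--opt, and the rule that abs_dir is the directory containing the YAML file, are assumptions rather than something this excerpt shows.

```python
import argparse
import os


def parse_args(argv=None):
    """Build an args object with the attributes main() reads (sketch).

    Assumptions: the flag spellings -m/--config_yaml and -o/--opt, and
    that abs_dir is the absolute directory containing the YAML file.
    """
    parser = argparse.ArgumentParser(description="static trainer (sketch)")
    parser.add_argument("-m", "--config_yaml", type=str, required=True,
                        help="path to the model's YAML config")
    parser.add_argument("-o", "--opt", nargs="*", type=str, default=None,
                        help="config overrides such as runner.epochs=2")
    args = parser.parse_args(argv)
    # main() copies this value into config["config_abs_dir"].
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    return args


args = parse_args(["-m", "config.yaml", "-o", "runner.epochs=2", "runner.use_gpu=False"])
print(args.config_yaml, args.opt, args.abs_dir)
# A real entry point would now call main(args).
```

With nargs="*", omitting -o leaves opt as None, which main() skips through its "if args.opt:" check.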

Calls 15

load_yaml · Function · 0.90
load_static_model_class · Function · 0.90
load_static_parameter · Function · 0.90
get_reader · Function · 0.90
create_data_loader · Function · 0.90
reset_auc · Function · 0.90
save_static_model · Function · 0.90
save_data · Function · 0.90
save_inference_model · Function · 0.90
dataloader_train · Function · 0.85
dataset_train · Function · 0.85
create_feeds · Method · 0.45
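The call list includes both dataloader_train and dataset_train, and main() reads runner.reader_type with the default "DataLoader". This suggests that the unshown part of the function picks a per-epoch trainer by reader type. The sketch below is a hypothetical illustration of that kind of dispatch with stub trainers, not the code in lines 104–214. The second reader type name, QueueDataset, and the on_epoch_start hook (standing in for something like reset_auc) are assumptions.

```python
from typing import Callable, Dict, List, Optional

Trainer = Callable[[int], object]


def select_trainer(reader_type: str, trainers: Dict[str, Trainer]) -> Trainer:
    """Return the per-epoch training callable registered for reader_type."""
    try:
        return trainers[reader_type]
    except KeyError:
        raise ValueError("unsupported reader_type {!r}; expected one of {}".format(
            reader_type, sorted(trainers))) from None


def train(reader_type: str, trainers: Dict[str, Trainer], epochs: int,
          on_epoch_start: Optional[Callable[[int], None]] = None) -> List[object]:
    # Resolve the trainer once, so a bad reader_type fails before any epoch runs.
    run_epoch = select_trainer(reader_type, trainers)
    results = []
    for epoch_id in range(epochs):
        if on_epoch_start is not None:
            on_epoch_start(epoch_id)  # e.g. reset AUC state between epochs
        results.append(run_epoch(epoch_id))
    return results


# Stub trainers standing in for dataloader_train / dataset_train.
trainers = {
    "DataLoader": lambda epoch_id: "dataloader epoch {}".format(epoch_id),
    "QueueDataset": lambda epoch_id: "dataset epoch {}".format(epoch_id),  # assumed name
}
print(train("DataLoader", trainers, epochs=2))
```

A lookup table keeps the epoch loop identical for every reader type and turns an unknown value into an immediate, descriptive error instead of a silent no-op.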

Tested by

no test coverage detected