hub / github.com/PaddlePaddle/PaddleRec / main

Function main

tools/trainer.py:47–220
(args)

Source from the content-addressed store, hash-verified

def main(args):
    paddle.seed(12345)
    # load config
    config = load_yaml(args.config_yaml)
    dy_model_class = load_dy_model_class(args.abs_dir)
    config["config_abs_dir"] = args.abs_dir
    # modify config from command
    if args.opt:
        for parameter in args.opt:
            parameter = parameter.strip()
            key, value = parameter.split("=")
            if type(config.get(key)) is int:
                value = int(value)
            if type(config.get(key)) is float:
                value = float(value)
            if type(config.get(key)) is bool:
                value = (True if value.lower() == "true" else False)
            config[key] = value

    # tools.vars
    use_gpu = config.get("runner.use_gpu", True)
    use_auc = config.get("runner.use_auc", False)
    use_npu = config.get("runner.use_npu", False)
    use_xpu = config.get("runner.use_xpu", False)
    use_visual = config.get("runner.use_visual", False)
    train_data_dir = config.get("runner.train_data_dir", None)
    epochs = config.get("runner.epochs", None)
    print_interval = config.get("runner.print_interval", None)
    train_batch_size = config.get("runner.train_batch_size", None)
    model_save_path = config.get("runner.model_save_path", "model_output")
    model_init_path = config.get("runner.model_init_path", None)
    use_fleet = config.get("runner.use_fleet", False)

    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, use_xpu: {}, use_npu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
        format(use_gpu, use_xpu, use_npu, use_visual, train_batch_size,
               train_data_dir, epochs, print_interval, model_save_path))
    logger.info("**************common.configs**********")

    if use_xpu:
        xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
        place = paddle.set_device(xpu_device)
    elif use_npu:
        npu_device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
        place = paddle.set_device(npu_device)
    else:
        place = paddle.set_device('gpu' if use_gpu else 'cpu')

    dy_model = dy_model_class.create_model(config)

    # Create a log_visual object and store the data in the path
    if use_visual:
        from visualdl import LogWriter
        log_visual = LogWriter(args.abs_dir + "/visualDL_log/train")

    if model_init_path is not None:
        load_model(model_init_path, dy_model)
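
The command-line override step in the listing (the loop over args.opt) coerces each key=value string to the type of the value already stored in the config. A minimal standalone sketch of that idea follows. The name apply_overrides is hypothetical and not part of PaddleRec, and the sketch makes two assumptions that differ from the listing: it splits on the first "=" only, so values that contain "=" survive, and it returns a copy instead of mutating the config in place.

```python
def apply_overrides(config, overrides):
    """Return a copy of config with "key=value" overrides applied.

    Hypothetical helper: each value is coerced to the type of the
    existing config entry, as the override loop in main() does.
    """
    updated = dict(config)
    for parameter in overrides:
        # Assumption: split once, so "path=a=b" keeps "a=b" as the value.
        key, value = parameter.strip().split("=", 1)
        current = updated.get(key)
        # bool is tested first because bool is a subclass of int.
        if isinstance(current, bool):
            value = value.lower() == "true"
        elif isinstance(current, int):
            value = int(value)
        elif isinstance(current, float):
            value = float(value)
        # Keys that are missing or hold other types keep the raw string.
        updated[key] = value
    return updated


if __name__ == "__main__":
    base = {"runner.epochs": 1, "runner.use_gpu": True}
    print(apply_overrides(base, ["runner.epochs=3", "runner.use_gpu=false"]))
```

Keys that are absent from the config have no reference type, so their values stay strings in both the listing and this sketch.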

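The device selection in the listing gives XPU priority over NPU, and NPU priority over the GPU/CPU choice, reading the device index from an environment variable with a default of 0. The sketch below isolates that precedence as a pure function that only builds the device string. The name select_device and the env parameter are hypothetical; in the listing the resulting string is passed to paddle.set_device.

```python
import os


def select_device(use_gpu=True, use_npu=False, use_xpu=False, env=None):
    """Build the device string using the same precedence as main().

    Hypothetical helper: env defaults to os.environ and can be replaced
    with a plain dict for testing.
    """
    env = os.environ if env is None else env
    if use_xpu:
        # XPU wins over every other flag.
        return "xpu:{0}".format(env.get("FLAGS_selected_xpus", 0))
    if use_npu:
        return "npu:{0}".format(env.get("FLAGS_selected_npus", 0))
    # Neither accelerator flag is set: fall back to GPU or CPU.
    return "gpu" if use_gpu else "cpu"


if __name__ == "__main__":
    print(select_device(use_gpu=False))
    print(select_device(use_xpu=True, env={"FLAGS_selected_xpus": "2"}))
```

Because runner.use_gpu defaults to True in the listing, a config that sets none of these flags selects "gpu".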
Callers 1

trainer.py · File · 0.70

Calls 15

load_yaml · Function · 0.90
load_dy_model_class · Function · 0.90
load_model · Function · 0.90
create_data_loader · Function · 0.90
save_model · Function · 0.90
clear_grad · Method · 0.80
step · Method · 0.80
accumulate · Method · 0.80
numpy · Method · 0.80
create_model · Method · 0.45
create_optimizer · Method · 0.45
init · Method · 0.45

Tested by

no test coverage detected