Repository: github.com/PaddlePaddle/PaddleRec

Function main(args)

Defined in models/match/kim/trainer.py, lines 50–216. The listing below shows lines 50–107 only; the rest of the function is not included.
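main reads only three attributes from args: config_yaml, abs_dir, and opt. The excerpt does not show where args is built, so the following is a minimal argparse sketch under stated assumptions: the helper name parse_args, the flag names -m/--config_yaml and -o/--opt, and deriving abs_dir from the config file's directory are all hypothetical, not taken from trainer.py.

```python
import argparse
import os


def parse_args(argv=None):
    """Build a namespace with the attributes main(args) reads.

    Hypothetical sketch: flag names and the abs_dir rule are assumptions.
    """
    parser = argparse.ArgumentParser(description="trainer (sketch)")
    parser.add_argument("-m", "--config_yaml", type=str, required=True,
                        help="path to the model's YAML config")
    parser.add_argument("-o", "--opt", nargs="*", type=str,
                        help="config overrides in key=value form")
    args = parser.parse_args(argv)
    # Assumption: abs_dir is the directory containing the config file.
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    return args


# "config.yaml" is a hypothetical path; parsing does not open the file.
args = parse_args(["-m", "config.yaml",
                   "-o", "runner.epochs=2", "runner.use_gpu=false"])
print(args.opt)  # ['runner.epochs=2', 'runner.use_gpu=false']
```

Because --opt has no default, args.opt is None when no overrides are given, which the `if args.opt:` guard in main already handles.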


```python
def main(args):
    paddle.seed(12345)
    # load config
    config = load_yaml(args.config_yaml)
    dy_model_class = load_dy_model_class(args.abs_dir)
    config["config_abs_dir"] = args.abs_dir
    # modify config from command
    if args.opt:
        for parameter in args.opt:
            parameter = parameter.strip()
            key, value = parameter.split("=")
            if type(config.get(key)) is int:
                value = int(value)
            if type(config.get(key)) is float:
                value = float(value)
            if type(config.get(key)) is bool:
                value = (True if value.lower() == "true" else False)
            config[key] = value

    # tools.vars
    use_gpu = config.get("runner.use_gpu", True)
    use_xpu = config.get("runner.use_xpu", False)
    use_visual = config.get("runner.use_visual", False)
    train_data_dir = config.get("runner.train_data_dir", None)
    epochs = config.get("runner.epochs", None)
    print_interval = config.get("runner.print_interval", None)
    train_batch_size = config.get("runner.train_batch_size", None)
    model_save_path = config.get("runner.model_save_path", "model_output")
    model_init_path = config.get("runner.model_init_path", None)
    use_fleet = config.get("runner.use_fleet", False)

    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, use_xpu: {}, use_visual: {}, train_batch_size: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
        format(use_gpu, use_xpu, use_visual, train_batch_size, train_data_dir,
               epochs, print_interval, model_save_path))
    logger.info("**************common.configs**********")

    if use_xpu:
        xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
        place = paddle.set_device(xpu_device)
    else:
        place = paddle.set_device('gpu' if use_gpu else 'cpu')

    dy_model = dy_model_class.create_model(config)

    # Create a log_visual object and store the data in the path
    if use_visual:
        from visualdl import LogWriter
        log_visual = LogWriter(args.abs_dir + "/visualDL_log/train")

    if model_init_path is not None:
        load_model(model_init_path, dy_model)

    # to do : add optimizer function
    optimizer = dy_model_class.create_optimizer(dy_model, config)

    # use fleet run collective
```
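The override loop in main casts each command-line value to the type of the existing config entry. The sketch below isolates that loop as a standalone function; the name apply_overrides and the sample config keys are hypothetical. It deviates from the listing in one deliberate way: it splits on the first "=" only, because parameter.split("=") raises ValueError when the value itself contains "=".

```python
def apply_overrides(config, opts):
    """Apply "key=value" overrides to a flat config dict (hypothetical helper).

    The value is cast to the type of the existing entry (int, float, bool);
    new keys and other types are stored as strings.
    """
    for parameter in opts or []:
        parameter = parameter.strip()
        # Split on the first "=" only so values may contain "=".
        key, value = parameter.split("=", 1)
        current = config.get(key)
        # elif is equivalent to the listing's separate ifs: config[key]
        # is only updated after all checks, so at most one branch matches.
        if type(current) is int:
            value = int(value)
        elif type(current) is float:
            value = float(value)
        elif type(current) is bool:
            value = value.lower() == "true"
        config[key] = value
    return config


# Sample config; "hyper_parameters.learning_rate" is a hypothetical key.
config = {
    "runner.epochs": 1,
    "runner.use_gpu": True,
    "runner.print_interval": 10,
    "hyper_parameters.learning_rate": 0.001,
}
apply_overrides(config, [
    "runner.epochs=3",
    " runner.use_gpu=False ",
    "hyper_parameters.learning_rate=0.01",
    "runner.model_save_path=out=v2",
])
```

The exact-type checks matter: bool is a subclass of int, so an isinstance(current, int) test would send boolean entries through int(), and int("False") raises. Note also that any boolean value other than a case-insensitive "true" (including "1") becomes False, as in the listing.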

Callers (1)

trainer.py (file, 0.70)

Calls (14; 12 listed)

load_yaml (function, 0.90)
load_dy_model_class (function, 0.90)
load_model (function, 0.90)
create_data_loader (function, 0.90)
save_model (function, 0.90)
clear_grad (method, 0.80)
step (method, 0.80)
accumulate (method, 0.80)
numpy (method, 0.80)
create_model (method, 0.45)
create_optimizer (method, 0.45)
init (method, 0.45)

Tested by

no test coverage detected
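Because main loads a config, builds a model, and trains it, one hedged way to start covering it is to factor out pure pieces such as the device-selection branch. The helper device_string and its environ parameter are hypothetical; in the listing the branch is inline and reads FLAGS_selected_xpus through os.getenv. Passing the environment in as a dict keeps the test independent of the real process environment.

```python
import os


def device_string(use_gpu, use_xpu, environ=None):
    """Return the device string main() would pass to paddle.set_device.

    XPU takes precedence over GPU; the XPU index comes from
    FLAGS_selected_xpus and defaults to 0, as in the listing.
    """
    environ = os.environ if environ is None else environ
    if use_xpu:
        return "xpu:{0}".format(environ.get("FLAGS_selected_xpus", 0))
    return "gpu" if use_gpu else "cpu"


def test_device_string():
    assert device_string(use_gpu=True, use_xpu=False, environ={}) == "gpu"
    assert device_string(use_gpu=False, use_xpu=False, environ={}) == "cpu"
    # XPU wins even when use_gpu is also set.
    assert device_string(True, True, {}) == "xpu:0"
    assert device_string(False, True, {"FLAGS_selected_xpus": "2"}) == "xpu:2"


test_device_string()
```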