MCPcopy Index your code
hub / github.com/PaddlePaddle/PaddleRec / Main

Class Main

tools/static_ps_trainer.py:69–268  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

67
68
69class Main(object):
70 def __init__(self, config):
71 self.metrics = {}
72 self.config = config
73 self.input_data = None
74 self.reader = None
75 self.exe = None
76 self.train_result_dict = {}
77 self.train_result_dict["speed"] = []
78 self.model = None
79 self.pure_bf16 = self.config['pure_bf16']
80
81 def run(self):
82 self.init_fleet_with_gloo()
83 self.network()
84 if fleet.is_server():
85 self.run_server()
86 elif fleet.is_worker():
87 self.run_worker()
88 fleet.stop_worker()
89 self.record_result()
90 logger.info("Run Success, Exit.")
91
92 def init_fleet_with_gloo(use_gloo=True):
93 if use_gloo:
94 os.environ["PADDLE_WITH_GLOO"] = "1"
95 role = role_maker.PaddleCloudRoleMaker()
96 fleet.init(role)
97 else:
98 fleet.init()
99
    def network(self):
        """Build the model, feed variables, reader, network and optimizer.

        Side effects, in this order: sets ``self.model`` from
        ``get_model(self.config)``; creates the training feeds
        (``self.input_data``) and a second feed set
        (``self.inference_feed_var``); calls ``self.init_reader()``; builds
        the network from ``self.input_data`` and stores its return value in
        ``self.metrics``; copies ``self.model.inference_target_var``; logs the
        ``CPU_NUM`` environment variable; and creates the optimizer with the
        strategy returned by ``get_strategy(self.config)``.
        """
        self.model = get_model(self.config)
        self.input_data = self.model.create_feeds()
        # NOTE(review): these "inference" feeds are requested with
        # is_infer=False, which presumably matches the default used for
        # self.input_data above — confirm whether is_infer=True was intended.
        self.inference_feed_var = self.model.create_feeds(is_infer=False)
        # Called only after the feeds exist; presumably the reader is bound to
        # self.input_data — TODO confirm in init_reader().
        self.init_reader()
        self.metrics = self.model.net(self.input_data)
        self.inference_target_var = self.model.inference_target_var
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        self.model.create_optimizer(get_strategy(self.config))
109
110 def run_server(self):
111 logger.info("Run Server Begin")
112 fleet.init_server(config.get("runner.warmup_model_path"))
113 fleet.run_server()
114
115 def run_worker(self):
116 logger.info("Run Worker Begin")
117 use_cuda = int(config.get("runner.use_gpu"))
118 place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
119 self.exe = paddle.static.Executor(place)
120
121 with open("./{}_worker_main_program.prototxt".format(
122 fleet.worker_index()), 'w+') as f:
123 f.write(str(paddle.static.default_main_program()))
124 with open("./{}_worker_startup_program.prototxt".format(
125 fleet.worker_index()), 'w+') as f:
126 f.write(str(paddle.static.default_startup_program()))

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected