MCPcopy
hub / github.com/PaddlePaddle/PaddleRec / run_worker

Method run_worker

tools/static_ps_trainer.py:115–176  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

113 fleet.run_server()
114
115 def run_worker(self):
116 logger.info("Run Worker Begin")
117 use_cuda = int(config.get("runner.use_gpu"))
118 place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
119 self.exe = paddle.static.Executor(place)
120
121 with open("./{}_worker_main_program.prototxt".format(
122 fleet.worker_index()), 'w+') as f:
123 f.write(str(paddle.static.default_main_program()))
124 with open("./{}_worker_startup_program.prototxt".format(
125 fleet.worker_index()), 'w+') as f:
126 f.write(str(paddle.static.default_startup_program()))
127
128 self.exe.run(paddle.static.default_startup_program())
129 if self.pure_bf16:
130 self.model.optimizer.amp_init(self.exe.place)
131 fleet.init_worker()
132
133 save_model_path = self.config.get("runner.model_save_path")
134 if save_model_path and (not os.path.exists(save_model_path)):
135 os.makedirs(save_model_path)
136
137 reader_type = self.config.get("runner.reader_type", "QueueDataset")
138 epochs = int(self.config.get("runner.epochs"))
139 sync_mode = self.config.get("runner.sync_mode")
140
141 if reader_type == "InmemoryDataset":
142 self.reader.load_into_memory()
143
144 for epoch in range(epochs):
145 epoch_start_time = time.time()
146
147 if sync_mode == "heter":
148 self.heter_train_loop(epoch)
149 elif reader_type == "QueueDataset":
150 self.dataset_train_loop(epoch)
151 elif reader_type == "InmemoryDataset":
152 self.dataset_train_loop(epoch)
153
154 epoch_time = time.time() - epoch_start_time
155 epoch_speed = self.example_nums / epoch_time
156 logger.info(
157 "Epoch: {}, using time {} second, ips {} {}/sec.".format(
158 epoch, epoch_time, epoch_speed, self.count_method))
159 self.train_result_dict["speed"].append(epoch_speed)
160
161 model_dir = "{}/{}".format(save_model_path, epoch)
162 if fleet.is_first_worker() and save_model_path:
163 if is_distributed_env():
164 fleet.save_inference_model(
165 self.exe, model_dir,
166 [feed.name for feed in self.inference_feed_var],
167 self.inference_target_var)
168 else:
169 paddle.fluid.io.save_inference_model(
170 model_dir,
171 [feed.name for feed in self.inference_feed_var],
172 [self.inference_target_var], self.exe)

Callers 1

run · Method · 0.95

Calls 4

heter_train_loop · Method · 0.95
dataset_train_loop · Method · 0.95
is_distributed_env · Function · 0.90
run · Method · 0.45

Tested by

no test coverage detected