MCPcopy
hub / github.com/PaddlePaddle/PaddleRec / heter_train_loop

Method heter_train_loop

tools/static_ps_trainer.py:223–263  ·  view source on GitHub ↗
(self, epoch)

Source from the content-addressed store, hash-verified

221 debug=debug)
222
def heter_train_loop(self, epoch):
    """Run one training epoch for heterogeneous parameter-server training.

    The input pipeline is selected by ``runner.reader_type``:

    * ``"QueueDataset"``: the whole epoch is handed to
      ``self.exe.train_from_dataset`` in a single call.
    * ``"DataLoader"``: the default main program is executed batch by batch
      until the reader raises ``paddle.core.EOFException``. Throughput
      statistics are logged every ``runner.print_period`` batches.

    Any other reader type is not handled by this method. A warning is logged
    and no training is performed.

    Args:
        epoch: Index of the current epoch; used only in log messages.
    """
    logger.info(
        "Epoch: {}, Running Begin. Check running metrics at heter_log".
        format(epoch))
    reader_type = self.config.get("runner.reader_type")
    if reader_type == "QueueDataset":
        # Read settings from self.config rather than a module-level
        # ``config`` global, which only exists when the file runs as a script.
        self.exe.train_from_dataset(
            program=paddle.static.default_main_program(),
            dataset=self.reader,
            debug=self.config.get("runner.dataset_debug"))
    elif reader_type == "DataLoader":
        main_program = paddle.static.default_main_program()
        # Loop invariants: read once instead of on every batch.
        print_step = int(self.config.get("runner.print_period"))
        batch_size = self.config.get("runner.batch_size")

        batch_id = 0
        train_run_cost = 0.0
        total_examples = 0
        self.reader.start()
        while True:
            train_start = time.time()
            try:
                # --------------------------------------------------- #
                self.exe.run(program=main_program)
                # --------------------------------------------------- #
            except paddle.core.EOFException:
                # The reader is exhausted. Reset it (presumably so a later
                # epoch can start() it again) and end this epoch.
                self.reader.reset()
                break
            train_run_cost += time.time() - train_start
            # Counts configured batch size per step, not the actual size of
            # a possibly smaller final batch.
            total_examples += batch_size
            batch_id += 1
            if batch_id % print_step == 0:
                profiler_string = "avg_batch_cost: {} sec, ".format(
                    format(train_run_cost / print_step, '.5f'))
                profiler_string += "avg_samples: {}, ".format(
                    format(total_examples / print_step, '.5f'))
                profiler_string += "ips: {} {}/sec ".format(
                    format(total_examples / train_run_cost, '.5f'),
                    self.count_method)
                logger.info("Epoch: {}, Batch: {}, {}".format(
                    epoch, batch_id, profiler_string))
                # Statistics are per reporting window, so start a new one.
                train_run_cost = 0.0
                total_examples = 0
    else:
        logger.warning(
            "Epoch: {}, unsupported runner.reader_type for heter training: "
            "{}; nothing was trained".format(epoch, reader_type))
264
265 def record_result(self):
266 logger.info("train_result_dict: {}".format(self.train_result_dict))

Callers 1

run_worker (Method) · 0.95

Calls 3

start (Method) · 0.80
run (Method) · 0.45
reset (Method) · 0.45

Tested by

no test coverage detected