MCPcopy Index your code
hub / github.com/PaddlePaddle/PaddleRec / dataloader_train

Function dataloader_train

tools/static_trainer.py:233–281  ·  view source on GitHub ↗
(epoch_id, train_dataloader, input_data_names, fetch_vars,
                     exe, config, use_visual, log_visual, step_num)

Source from the content-addressed store, hash-verified

231
232
def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
                     exe, config, use_visual, log_visual, step_num):
    """Run one training epoch over ``train_dataloader`` with a static-graph executor.

    Each batch is fed to ``paddle.static.default_main_program()`` through
    ``exe.run`` and the variables in ``fetch_vars`` are fetched. Every
    ``runner.print_interval`` batches the fetched values plus reader/batch
    timing statistics are logged and, when ``use_visual`` is true, written
    to ``log_visual`` as scalars.

    Args:
        epoch_id: Index of the current epoch; used only in log output.
        train_dataloader: Callable returning an iterable of batches. Each
            batch is zipped positionally with ``input_data_names`` to build
            the feed dict.
        input_data_names: Feed variable names, in the same order as the
            elements of a batch.
        fetch_vars: Mapping of display name -> variable to fetch; iteration
            order determines the fetch order.
        exe: Executor whose ``run(program=, feed=, fetch_list=)`` is called
            once per batch.
        config: Mapping read for ``runner.print_interval`` and
            ``runner.train_batch_size``.
        use_visual: Whether to write fetched scalars to ``log_visual``.
        log_visual: Object exposing ``add_scalar(tag=, step=, value=)``;
            only used when ``use_visual`` is true.
        step_num: Global step counter at the start of this epoch.

    Returns:
        Tuple ``(fetch_batch_var, step_num)``: the values fetched for the
        last batch, and ``step_num`` advanced by one per processed batch.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches (e.g. the
            dataset is smaller than one batch and the last incomplete batch
            is dropped).
    """
    print_interval = config.get("runner.print_interval", None)
    batch_size = config.get("runner.train_batch_size", None)

    # The fetch list never changes across batches; build it once. Its order
    # matches the iteration order of ``fetch_vars`` used when logging.
    fetch_names = list(fetch_vars)
    fetch_list = [fetch_vars[name] for name in fetch_names]

    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    # Batches accumulated since the last log line; averages divide by this
    # rather than by print_interval so the very first log (batch 0) is right.
    interval_batches = 0
    num_batches = 0
    fetch_batch_var = None
    reader_start = time.time()

    for batch_id, batch_data in enumerate(train_dataloader()):
        train_reader_cost += time.time() - reader_start
        train_start = time.time()

        fetch_batch_var = exe.run(
            program=paddle.static.default_main_program(),
            feed=dict(zip(input_data_names, batch_data)),
            fetch_list=fetch_list)

        train_run_cost += time.time() - train_start
        total_samples += batch_size
        interval_batches += 1
        num_batches += 1

        if batch_id % print_interval == 0:
            metric_parts = []
            for var_idx, var_name in enumerate(fetch_names):
                value = fetch_batch_var[var_idx]
                metric_parts.append("{}: {}, ".format(
                    var_name, str(value).strip("[]")))
                if use_visual:
                    log_visual.add_scalar(
                        tag="train/" + var_name, step=step_num, value=value)
            metric_str = "".join(metric_parts)

            batch_cost = train_reader_cost + train_run_cost
            # Guard against a zero-length timing window (coarse clocks).
            ips = total_samples / batch_cost if batch_cost > 0 else 0.0
            logger.info(
                "epoch: {}, batch_id: {}, ".format(epoch_id, batch_id) +
                metric_str +
                "avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} ins/s".
                format(train_reader_cost / interval_batches,
                       batch_cost / interval_batches,
                       total_samples / interval_batches, ips))

            train_reader_cost = 0.0
            train_run_cost = 0.0
            total_samples = 0
            interval_batches = 0
        reader_start = time.time()
        step_num = step_num + 1

    # Checked after the loop instead of pre-iterating the loader: this avoids
    # spinning up a second data iterator, and unlike ``assert`` it is not
    # stripped under ``python -O``.
    if num_batches == 0:
        raise ValueError(
            "train_dataloader's size is null, please ensure batch size < dataset size!"
        )
    return fetch_batch_var, step_num
282
283
284if __name__ == "__main__":

Callers 1

main (function) · 0.85

Calls 1

run (method) · 0.45

Tested by

no test coverage detected