| 231 | |
| 232 | |
def dataloader_train(epoch_id, train_dataloader, input_data_names, fetch_vars,
                     exe, config, use_visual, log_visual, step_num):
    """Run one pass over ``train_dataloader`` and log metrics periodically.

    Every batch is fed to ``exe.run`` on the default main program. Whenever
    ``batch_id`` is a multiple of ``runner.print_interval`` the fetched values
    and timing statistics are logged, and the timing accumulators are reset.

    Args:
        epoch_id: Epoch index; only used in the log message.
        train_dataloader: Callable returning an iterable of batches.
        input_data_names: Feed names, zipped with each batch to build the
            feed dict.
        fetch_vars: Mapping of metric name -> variable to fetch. The keys are
            used as metric labels, the values as the fetch list.
        exe: Executor exposing ``run(program=..., feed=..., fetch_list=...)``.
        config: Object exposing ``get(key, default)``. Both
            ``runner.print_interval`` and ``runner.train_batch_size`` are used
            in arithmetic, so they must be set to numbers.
        use_visual: When truthy, fetched values are also written through
            ``log_visual.add_scalar``.
        log_visual: Writer exposing ``add_scalar(tag=..., step=..., value=...)``.
        step_num: Global step counter on entry; incremented once per batch.

    Returns:
        Tuple ``(fetch_batch_var, step_num)``: the fetch result of the last
        batch and the updated step counter.

    Raises:
        AssertionError: If the dataloader yields no batch at all.
    """
    print_interval = config.get("runner.print_interval", None)
    batch_size = config.get("runner.train_batch_size", None)

    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    # Number of batches accumulated since the last log line. The first log
    # (batch 0) covers a single batch, so averaging over print_interval
    # would understate every per-batch figure.
    batches_since_log = 0
    fetch_batch_var = None
    has_batch = False
    reader_start = time.time()

    for batch_id, batch_data in enumerate(train_dataloader()):
        has_batch = True
        train_reader_cost += time.time() - reader_start
        train_start = time.time()

        fetch_batch_var = exe.run(
            program=paddle.static.default_main_program(),
            feed=dict(zip(input_data_names, batch_data)),
            fetch_list=list(fetch_vars.values()))

        train_run_cost += time.time() - train_start
        total_samples += batch_size
        batches_since_log += 1

        if batch_id % print_interval == 0:
            metric_parts = []
            for var_idx, var_name in enumerate(fetch_vars):
                metric_parts.append("{}: {}, ".format(
                    var_name, str(fetch_batch_var[var_idx]).strip("[]")))
                if use_visual:
                    log_visual.add_scalar(
                        tag="train/" + var_name,
                        step=step_num,
                        value=fetch_batch_var[var_idx])
            metric_str = "".join(metric_parts)

            total_cost = train_reader_cost + train_run_cost
            # Guard against a zero elapsed time on coarse clocks.
            ips = total_samples / total_cost if total_cost > 0 else float(
                "inf")
            timing_str = (
                "avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, "
                "avg_samples: {:.5f}, ips: {:.5f} ins/s").format(
                    train_reader_cost / batches_since_log,
                    total_cost / batches_since_log,
                    total_samples / batches_since_log, ips)
            logger.info("epoch: {}, batch_id: {}, ".format(epoch_id, batch_id)
                        + metric_str + timing_str)

            train_reader_cost = 0.0
            train_run_cost = 0.0
            total_samples = 0
            batches_since_log = 0

        reader_start = time.time()
        step_num = step_num + 1

    # The last incomplete batch is dropped when the dataset size is not
    # divisible by the batch size, so a dataset smaller than one batch yields
    # nothing. Checking here avoids opening a second, throwaway pass over the
    # dataloader, and the explicit raise is not stripped under ``python -O``.
    if not has_batch:
        raise AssertionError(
            "train_dataloader's size is null, please ensure batch size < dataset size!"
        )

    return fetch_batch_var, step_num
| 282 | |
| 283 | |
| 284 | if __name__ == "__main__": |