(self, epoch)
| 221 | debug=debug) |
| 222 | |
def heter_train_loop(self, epoch):
    """Run one training epoch in heterogeneous mode.

    The behaviour depends on the ``runner.reader_type`` config entry:

    * ``"QueueDataset"``: the whole epoch is delegated to
      ``self.exe.train_from_dataset`` using ``self.reader`` as the dataset.
    * ``"DataLoader"``: the default main program is run batch by batch
      until the reader raises ``paddle.core.EOFException``; the reader is
      then reset. Every ``runner.print_period`` batches a throughput
      summary is logged and the running counters are cleared.

    Any other reader type is ignored.

    Args:
        epoch: Epoch index, used only in log messages.
    """
    logger.info(
        "Epoch: {}, Running Begin. Check running metrics at heter_log".
        format(epoch))
    reader_type = self.config.get("runner.reader_type")

    if reader_type == "QueueDataset":
        # Read the flag from self.config like every other setting here,
        # instead of relying on a module-level `config` global.
        self.exe.train_from_dataset(
            program=paddle.static.default_main_program(),
            dataset=self.reader,
            debug=self.config.get("runner.dataset_debug"))
        return

    if reader_type != "DataLoader":
        return

    # Loop-invariant settings: look them up once, before the reader is
    # started, so a bad config value cannot leave the reader running.
    print_step = int(self.config.get("runner.print_period"))
    batch_size = self.config.get("runner.batch_size")

    batch_id = 0
    train_run_cost = 0.0
    total_examples = 0
    self.reader.start()
    while True:
        try:
            train_start = time.time()
            self.exe.run(program=paddle.static.default_main_program())
            train_run_cost += time.time() - train_start
            # Sample count is estimated: every batch is counted as a full
            # `runner.batch_size` examples.
            total_examples += batch_size
            batch_id += 1
            if batch_id % print_step == 0:
                profiler_string = (
                    "avg_batch_cost: {:.5f} sec, "
                    "avg_samples: {:.5f}, "
                    "ips: {:.5f} {}/sec ").format(
                        train_run_cost / print_step,
                        total_examples / print_step,
                        total_examples / train_run_cost,
                        self.count_method)
                logger.info("Epoch: {}, Batch: {}, {}".format(
                    epoch, batch_id, profiler_string))
                # Start a fresh measurement window.
                train_run_cost = 0.0
                total_examples = 0
        except paddle.core.EOFException:
            # The reader is exhausted: rewind it for the next epoch.
            self.reader.reset()
            break
| 264 | |
def record_result(self):
    """Emit the current ``self.train_result_dict`` through ``logger.info``."""
    summary = "train_result_dict: {}".format(self.train_result_dict)
    logger.info(summary)
no test coverage detected