Runs training for `num_steps` steps. Also prints/logs updates about training progress, and summarizes training output (if output is returned from `self.trainer.train()`, and if `self.summary_dir` is set). Args: num_steps: An integer specifying how many steps of training to run.
(self, num_steps: int)
| 489 | return self._steps_per_loop |
| 490 | |
| 491 | def _train_n_steps(self, num_steps: int): |
| 492 | """Runs training for `num_steps` steps. |
| 493 | |
| 494 | Also prints/logs updates about training progress, and summarizes training |
| 495 | output (if output is returned from `self.trainer.train()`, and if |
| 496 | `self.summary_dir` is set). |
| 497 | |
| 498 | Args: |
| 499 | num_steps: An integer specifying how many steps of training to run. |
| 500 | |
| 501 | Note: |
| 502 | A warning is logged (no exception is raised) if `global_step` was not |
| 503 | incremented by exactly `num_steps` by `self.trainer.train(num_steps)`. |
| 504 | """ |
| 505 | if not self.step_timer: |
| 506 | self.step_timer = StepTimer(self.global_step) |
| 507 | current_step = self.global_step.numpy() |
| 508 | |
| 509 | with self.summary_manager.summary_writer().as_default(): |
| 510 | should_record = False # Allows static optimization in no-summary cases. |
| 511 | if self.summary_interval: |
| 512 | # Create a predicate to determine when summaries should be written. |
| 513 | should_record = lambda: (self.global_step % self.summary_interval == 0) |
| 514 | assert isinstance(self.trainer, runner.AbstractTrainer) |
| 515 | with tf.summary.record_if(should_record): |
| 516 | num_steps_tensor = tf.convert_to_tensor(num_steps, dtype=tf.int32) |
| 517 | train_output = self.trainer.train(num_steps_tensor) |
| 518 | |
| 519 | # Check that global_step advanced by exactly num_steps; a mismatch is only logged. |
| 520 | expected_step = current_step + num_steps |
| 521 | if self.global_step.numpy() != expected_step: |
| 522 | message = ( |
| 523 | f"`trainer.train({num_steps})` did not update `global_step` by " |
| 524 | f"{num_steps}. Old value was {current_step}, expected updated value " |
| 525 | f"to be {expected_step}, but it was {self.global_step.numpy()}.") |
| 526 | logging.warning(message) |
| 527 | |
| 528 | train_output = train_output or {} |
| 529 | for action in self.train_actions: |
| 530 | action(train_output) |
| 531 | train_output = tf.nest.map_structure(utils.get_value, train_output) |
| 532 | |
| 533 | current_step = self.global_step.numpy() |
| 534 | steps_per_second = self.step_timer.steps_per_second() |
| 535 | _log(f"train | step: {current_step: 6d} | " |
| 536 | f"steps/sec: {steps_per_second: 6.1f} | " |
| 537 | f"output: {_format_output(train_output)}") |
| 538 | |
| 539 | train_output["steps_per_second"] = steps_per_second |
| 540 | self.summary_manager.write_summaries(train_output) |
| 541 | self.summary_manager.flush() |
| 542 | |
| 543 | def _maybe_save_checkpoint(self, check_interval: bool = True): |
| 544 | """Conditionally saves a checkpoint. |
no test coverage detected