Runs training until the specified global step count has been reached. This method makes calls to `self.trainer.train()` until the global step count is equal to `steps`. It will additionally save checkpoints (if a `CheckpointManager` was passed to `Controller.__init__`) and summarize training output (if `summary_dir` is set).
(self, steps: int, checkpoint_at_completion: bool = True)
| 255 | _orbit_api_gauge.get_cell().set(True) |
| 256 | |
def train(self, steps: int, checkpoint_at_completion: bool = True):
    """Trains until the global step count reaches `steps`.

    Repeatedly invokes the inner training loop, in chunks of at most
    `self.steps_per_loop` steps, until the global step is no longer below
    `steps`. After each chunk a checkpoint is saved if the checkpointing
    interval calls for one (this requires that a `CheckpointManager` was
    passed to `Controller.__init__`), and training output is summarized
    (if `summary_dir` is set).

    If async checkpointing is enabled, this method syncs before returning,
    so that any in-flight async checkpoint save has finished by the time
    the caller regains control.

    Args:
      steps: The global step count to train up to.
      checkpoint_at_completion: Whether to save a checkpoint when this
        method returns, irrespective of the checkpointing interval.
        Defaults to `True`.
    """
    self._require("trainer", for_method="train")

    # TODO(momernick): Support steps=None or -1 (training to exhaustion).
    # Reading the global step is expensive, so it is cached in a local and
    # only refreshed once per outer loop iteration.
    current_step = self.global_step.numpy()
    _log(f"train | step: {current_step: 6d} | training until step {steps}...")

    while current_step < steps:
        # Never overshoot the target: the last chunk may be shorter than a
        # full `steps_per_loop`.
        remaining = steps - current_step
        loop_steps = min(remaining, self.steps_per_loop)

        self._train_n_steps(loop_steps)
        self._maybe_save_checkpoint()

        current_step = self.global_step.numpy()

    if checkpoint_at_completion:
        # Bypass the interval check so a final checkpoint is requested.
        self._maybe_save_checkpoint(check_interval=False)

    self._sync_on_async_checkpointing()
| 290 | |
| 291 | def evaluate(self, steps: int = -1) -> Optional[runner.Output]: |
| 292 | """Runs evaluation for the given number of steps. |