Runs interleaved training and evaluation.

This method interleaves calls to `self.train()` and `self.evaluate()`, training the model until the global step count equals `train_steps`, and running an evaluation for `eval_steps` every `eval_interval` training steps. In addition, this method will run a final evaluation at the end of the training sequence.
(
self,
train_steps: int,
eval_steps: int = -1,
eval_interval: Optional[int] = None,
)
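The interleaving described above can be sketched without TensorFlow. The helper below is hypothetical (`eval_schedule` is not part of the library); it mirrors only the step arithmetic of the loop and assumes that each training call advances the global step exactly to the requested target. It returns the global steps at which an evaluation would run.

```python
from typing import List, Optional


def eval_schedule(
    current_step: int,
    train_steps: int,
    eval_interval: Optional[int] = None,
) -> List[int]:
  """Returns the global steps at which an evaluation would run (illustrative)."""
  eval_points = []
  # With no interval given, train straight to `train_steps` and evaluate once.
  eval_interval = eval_interval or (train_steps - current_step)
  while current_step < train_steps:
    # The last chunk is shortened so training never overshoots `train_steps`.
    interval = min(train_steps - current_step, eval_interval)
    current_step += interval          # stands in for the training call
    eval_points.append(current_step)  # stands in for the evaluation call
  return eval_points


print(eval_schedule(0, 1000, 300))    # [300, 600, 900, 1000]
print(eval_schedule(0, 1000))         # [1000]
print(eval_schedule(250, 1000, 300))  # [550, 850, 1000]
print(eval_schedule(1000, 1000))      # []
```

Two things stand out in this sketch: intervals are counted from the step at which the call starts rather than aligned to multiples of `eval_interval`, and when the global step has already reached `train_steps` no evaluation runs at all, which is consistent with the `Optional` return type.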
    return eval_output

  def train_and_evaluate(
      self,
      train_steps: int,
      eval_steps: int = -1,
      eval_interval: Optional[int] = None,
  ) -> Optional[runner.Output]:
    """Runs interleaved training and evaluation.

    This method interleaves calls to `self.train()` and `self.evaluate()`,
    training the model until the global step count equals `train_steps`, and
    running an evaluation for `eval_steps` every `eval_interval` training steps.
    In addition, this method will run a final evaluation at the end of the
    training sequence.

    When async checkpointing is enabled, a sync is triggered at the end of this
    method to make sure any ongoing async checkpoint saving is finished before
    returning.

    Args:
      train_steps: The global step count to train up to.
      eval_steps: The number of steps to run during an evaluation. If -1, this
        method will evaluate over the entire evaluation dataset.
      eval_interval: The number of training steps to run between evaluations. If
        set, training will always stop every `eval_interval` steps, even if this
        results in a shorter inner loop than specified by `steps_per_loop`
        setting. If None, evaluation will only be performed after training is
        complete.

    Returns:
      The evaluation results as a dictionary mapping names to NumPy values.
    """
    self._require("trainer", for_method="train_and_evaluate")
    self._require("evaluator", for_method="train_and_evaluate")

    output = None
    current_step = self.global_step.numpy()  # Cache, since this is expensive.
    eval_interval = eval_interval or (train_steps - current_step)
    while current_step < train_steps:
      interval = min(train_steps - current_step, eval_interval)
      num_steps = current_step + interval
      self.train(steps=num_steps, checkpoint_at_completion=False)
      output = self.evaluate(steps=eval_steps)
      current_step = self.global_step.numpy()
    self._maybe_save_checkpoint(check_interval=False)
    self._sync_on_async_checkpointing()
    return output

  def evaluate_continuously(
      self,