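Restores the model from a checkpoint. Args: checkpoint_path: An optional string specifying the checkpoint path to restore from. If `None`, will restore from the most recent checkpoint (or initialize the model using a custom `init_fn` if no checkpoints can be found) using `self.checkpoint_manager.restore_or_initialize()`. Returns: The path to the restored checkpoint if a restore happened, or `None` if no restore occurred.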
(self, checkpoint_path: Optional[str] = None)
| 439 | return output |
| 440 | |
def restore_checkpoint(self, checkpoint_path: Optional[str] = None):
    """Restores model state from a checkpoint, or initializes it if none exists.

    Args:
      checkpoint_path: Optional explicit checkpoint path. When given, that
        checkpoint is restored directly via `checkpoint.restore`. When `None`,
        `self.checkpoint_manager.restore_or_initialize()` is used instead: it
        restores the latest checkpoint, or falls back to initialization (e.g.
        a custom `init_fn`) when no checkpoint can be found.

    Returns:
      The path of the checkpoint that was restored, or `None` when nothing was
      restored.
    """
    self._require("checkpoint_manager", for_method="restore_checkpoint")

    manager = self.checkpoint_manager
    # Narrows the attribute's type for static checkers; `_require` above has
    # already rejected a missing manager.
    assert isinstance(manager, tf.train.CheckpointManager)

    # Restoring must happen inside the distribution strategy's scope
    # (b/139450638).
    with self.strategy.scope():
        if checkpoint_path is None:
            _log("restoring or initializing model...")
            # Yields None when no checkpoint existed and the model was
            # initialized instead.
            checkpoint_path = manager.restore_or_initialize()
        else:
            _log(f"restoring model from {checkpoint_path}...")
            manager.checkpoint.restore(checkpoint_path)

    if checkpoint_path is not None:
        _log(f"restored model from {checkpoint_path}.")
    return checkpoint_path
| 470 | |
| 471 | def save_checkpoint(self): |
| 472 | """Saves the model to a checkpoint. |