Saves the state_dict of the module, the privacy accountant, and (optionally) the optimizer, noise scheduler, and gradient-clipping scheduler to path in a single checkpoint.
save_checkpoint(
    self,
    *,
    path: Union[str, os.PathLike, BinaryIO, IO[bytes]],
    module: Union[nn.Module, GradSampleModule],
    optimizer: Optional[DPOptimizer] = None,
    noise_scheduler: Optional[_NoiseScheduler] = None,
    grad_clip_scheduler: Optional[_GradClipScheduler] = None,
    checkpoint_dict: Optional[Dict[str, Any]] = None,
    module_state_dict_kwargs: Optional[Dict[str, Any]] = None,
    torch_save_kwargs: Optional[Dict[str, Any]] = None,
)
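To make the checkpoint layout concrete, here is a minimal, torch-free sketch that mirrors the key-assembly logic of save_checkpoint as listed in the source. It is a stand-in, not Opacus code: `_Stateful` and `build_checkpoint` are hypothetical names, the state values are made up, and `pickle` replaces `torch.save()`/`torch.load()` only so the example runs without PyTorch installed.

```python
import io
import pickle
from typing import Any, Dict, Optional


class _Stateful:
    """Hypothetical stand-in for any object exposing state_dict()."""

    def __init__(self, state: Dict[str, Any]) -> None:
        self._state = state

    def state_dict(self, **kwargs: Any) -> Dict[str, Any]:
        # Accepts (and ignores) kwargs, like a forwarded module_state_dict_kwargs.
        return dict(self._state)


def build_checkpoint(
    *,
    module: _Stateful,
    accountant: _Stateful,
    optimizer: Optional[_Stateful] = None,
    noise_scheduler: Optional[_Stateful] = None,
    grad_clip_scheduler: Optional[_Stateful] = None,
    checkpoint_dict: Optional[Dict[str, Any]] = None,
    module_state_dict_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    # Same key names and optional-entry logic as save_checkpoint.
    checkpoint_dict = checkpoint_dict or {}
    checkpoint_dict["module_state_dict"] = module.state_dict(
        **(module_state_dict_kwargs or {})
    )
    checkpoint_dict["privacy_accountant_state_dict"] = accountant.state_dict()
    if optimizer is not None:
        checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict()
    if noise_scheduler is not None:
        checkpoint_dict["noise_scheduler_state_dict"] = noise_scheduler.state_dict()
    if grad_clip_scheduler is not None:
        checkpoint_dict["grad_clip_scheduler_state_dict"] = (
            grad_clip_scheduler.state_dict()
        )
    return checkpoint_dict


# Module and accountant are always saved; each optional object adds one key.
ckpt = build_checkpoint(
    module=_Stateful({"weight": [1.0, 2.0]}),
    accountant=_Stateful({"steps": 100}),
    optimizer=_Stateful({"lr": 0.1}),
    checkpoint_dict={"epoch": 3},  # extra user entries travel with the state dicts
)

# pickle stands in for torch.save()/torch.load() in this sketch.
buf = io.BytesIO()
pickle.dump(ckpt, buf)
buf.seek(0)
restored = pickle.load(buf)
print(sorted(restored))
```

Because no schedulers were passed, the restored dict holds only `epoch`, `module_state_dict`, `optimizer_state_dict`, and `privacy_accountant_state_dict`; loading code should therefore treat the scheduler keys as optional.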
```python
# Method excerpt; relies on the enclosing module's imports
# (os, torch, torch.nn as nn, typing, and the Opacus types in the signature).
def save_checkpoint(
    self,
    *,
    path: Union[str, os.PathLike, BinaryIO, IO[bytes]],
    module: Union[nn.Module, GradSampleModule],
    optimizer: Optional[DPOptimizer] = None,
    noise_scheduler: Optional[_NoiseScheduler] = None,
    grad_clip_scheduler: Optional[_GradClipScheduler] = None,
    checkpoint_dict: Optional[Dict[str, Any]] = None,
    module_state_dict_kwargs: Optional[Dict[str, Any]] = None,
    torch_save_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Saves the state_dict of the module, optimizer, and accountant at ``path``.

    Args:
        path: File path or writable binary file object to save the checkpoint to.
        module: nn.Module or GradSampleModule to save; the wrapped module's
            state_dict is saved.
        optimizer: DPOptimizer to save; the wrapped optimizer's state_dict is saved.
        noise_scheduler: _NoiseScheduler whose state should be saved.
        grad_clip_scheduler: _GradClipScheduler whose state should be saved.
        checkpoint_dict: Optional dict of extra entries to store alongside the
            state dicts; a non-empty dict is updated in place.
        module_state_dict_kwargs: kwargs to pass to ``module.state_dict()``.
        torch_save_kwargs: kwargs to pass to ``torch.save()``.
    """
    checkpoint_dict = checkpoint_dict or {}
    # The module and the privacy accountant are always saved.
    checkpoint_dict["module_state_dict"] = module.state_dict(
        **(module_state_dict_kwargs or {})
    )
    checkpoint_dict["privacy_accountant_state_dict"] = self.accountant.state_dict()
    # Optional components are saved only when provided.
    if optimizer is not None:
        checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict()
    if noise_scheduler is not None:
        checkpoint_dict["noise_scheduler_state_dict"] = noise_scheduler.state_dict()
    if grad_clip_scheduler is not None:
        checkpoint_dict["grad_clip_scheduler_state_dict"] = (
            grad_clip_scheduler.state_dict()
        )

    torch.save(checkpoint_dict, path, **(torch_save_kwargs or {}))
```
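One consequence of the `checkpoint_dict = checkpoint_dict or {}` line in save_checkpoint is worth spelling out: a non-empty dict passed as `checkpoint_dict` is filled in place, so the caller's object gains the state-dict entries, while an empty dict is falsy, so a fresh dict is created and the caller's `{}` stays empty. The snippet below reproduces only that idiom in plain Python; `fill` is a hypothetical helper, not part of Opacus.

```python
from typing import Any, Dict, Optional


def fill(checkpoint_dict: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Same idiom as save_checkpoint: falsy inputs (None or {}) are replaced
    # by a brand-new dict; truthy inputs are mutated in place.
    checkpoint_dict = checkpoint_dict or {}
    checkpoint_dict["module_state_dict"] = {"weight": 0.0}
    return checkpoint_dict


non_empty = {"epoch": 1}
empty: Dict[str, Any] = {}

fill(non_empty)
fill(empty)

print(non_empty)  # {'epoch': 1, 'module_state_dict': {'weight': 0.0}}
print(empty)      # {}
```

Since save_checkpoint itself returns None, a caller that wants the assembled dict afterwards should pass a non-empty `checkpoint_dict` or reload the saved file, rather than rely on an empty dict being populated.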