MCPcopy
hub / github.com/meta-pytorch/opacus / save_checkpoint

Method save_checkpoint

opacus/privacy_engine.py:618–657  ·  view source on GitHub ↗

Saves the state_dict of module, optimizer, and accountant at path. Args: path: Path to save the state dict objects. module: nn.Module or GradSampleModule to save; wrapped module's state_dict is saved. optimizer: DPOptimizer to save; wrapped optimizer's state_dict is saved. …

(
        self,
        *,
        path: Union[str, os.PathLike, BinaryIO, IO[bytes]],
        module: Union[nn.Module, GradSampleModule],
        optimizer: Optional[DPOptimizer] = None,
        noise_scheduler: Optional[_NoiseScheduler] = None,
        grad_clip_scheduler: Optional[_GradClipScheduler] = None,
        checkpoint_dict: Optional[Dict[str, Any]] = None,
        module_state_dict_kwargs: Optional[Dict[str, Any]] = None,
        torch_save_kwargs: Optional[Dict[str, Any]] = None,
    )

Source from the content-addressed store, hash-verified

616 return self.accountant.get_epsilon(delta)
617
def save_checkpoint(
    self,
    *,
    path: Union[str, os.PathLike, BinaryIO, IO[bytes]],
    module: Union[nn.Module, GradSampleModule],
    optimizer: Optional[DPOptimizer] = None,
    noise_scheduler: Optional[_NoiseScheduler] = None,
    grad_clip_scheduler: Optional[_GradClipScheduler] = None,
    checkpoint_dict: Optional[Dict[str, Any]] = None,
    module_state_dict_kwargs: Optional[Dict[str, Any]] = None,
    torch_save_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Saves the state_dict of module, optimizer, and accountant at path.

    A single dict is written with ``torch.save``. It always contains
    ``"module_state_dict"`` and ``"privacy_accountant_state_dict"``. The keys
    ``"optimizer_state_dict"``, ``"noise_scheduler_state_dict"`` and
    ``"grad_clip_scheduler_state_dict"`` are added only when the matching
    argument is not ``None``.

    Args:
        path: File path, or writable binary file-like object, to save the
            state dict objects to.
        module: nn.Module or GradSampleModule to save; the result of
            ``module.state_dict()`` is saved.
        optimizer: DPOptimizer to save; the result of
            ``optimizer.state_dict()`` is saved.
        noise_scheduler: _NoiseScheduler whose state we should save.
        grad_clip_scheduler: _GradClipScheduler whose state we should save.
        checkpoint_dict: Dict[str, Any]; an already-filled checkpoint dict
            whose entries are saved alongside the ones listed above. It is
            shallow-copied, so the caller's dict is not modified. Any entries
            under the keys listed above are overwritten in the saved copy.
        module_state_dict_kwargs: dict of kwargs to pass to ``module.state_dict()``
        torch_save_kwargs: dict of kwargs to pass to ``torch.save()``
    """
    # Copy rather than reuse the caller's dict. Writing the state dicts into
    # the argument itself would be a hidden side effect on the caller's object.
    checkpoint = dict(checkpoint_dict) if checkpoint_dict is not None else {}

    checkpoint["module_state_dict"] = module.state_dict(
        **(module_state_dict_kwargs or {})
    )
    checkpoint["privacy_accountant_state_dict"] = self.accountant.state_dict()

    # Optional components are stored only when provided, so a later loader
    # can detect their absence by key.
    if optimizer is not None:
        checkpoint["optimizer_state_dict"] = optimizer.state_dict()
    if noise_scheduler is not None:
        checkpoint["noise_scheduler_state_dict"] = noise_scheduler.state_dict()
    if grad_clip_scheduler is not None:
        checkpoint["grad_clip_scheduler_state_dict"] = (
            grad_clip_scheduler.state_dict()
        )

    torch.save(checkpoint, path, **(torch_save_kwargs or {}))
658
659 def load_checkpoint(
660 self,

Callers 1

test_checkpoints Method · 0.80

Calls 1

state_dict Method · 0.45

Tested by 1

test_checkpoints Method · 0.64