| 668 | return ema_model |
| 669 | |
def save_pretrained(self, path):
    """Save the shadow parameters as a regular ``model_cls`` checkpoint at ``path``.

    Steps, in order:
      1. Build a fresh model via ``self.model_cls.from_config(self.model_config)``.
      2. Register every ``self.state_dict()`` entry except ``"shadow_params"``
         on that model's config.
      3. Hand the model's parameters to ``self.copy_to`` (presumably copying the
         shadow / averaged weights into them — see ``copy_to``).
      4. Delegate to the model's own ``save_pretrained(path)``.

    Args:
        path: Destination forwarded unchanged to ``model.save_pretrained``.

    Raises:
        ValueError: If ``model_cls`` is None (checked first) or if
            ``model_config`` is None.
    """
    # Both are needed to instantiate a model object to write the weights into.
    if self.model_cls is None:
        raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
    if self.model_config is None:
        raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

    target = self.model_cls.from_config(self.model_config)

    # The shadow tensors travel as model weights (via copy_to below), so they
    # are kept out of the config; every other state entry is stored in the
    # config alongside the model.
    extra_config = {
        key: value
        for key, value in self.state_dict().items()
        if key != "shadow_params"
    }
    target.register_to_config(**extra_config)

    self.copy_to(target.parameters())
    target.save_pretrained(path)
| 684 | |
| 685 | def get_decay(self, optimization_step: int) -> float: |
| 686 | """ |