(cls, path, model_cls, foreach=False)
| 659 | |
@classmethod
def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
    """Build an EMA wrapper from artifacts stored at ``path``.

    Two loaders on ``model_cls`` are consulted, in this order:

    1. ``model_cls.from_config(path, return_unused_kwargs=True)`` must return
       a pair. The first element is discarded; the second becomes the state
       handed to ``load_state_dict``. Presumably these are config entries the
       model constructor did not consume (e.g. EMA hyperparameters written
       next to the model) — confirm against the matching ``save_pretrained``.
    2. ``model_cls.from_pretrained(path)`` yields the model whose
       ``parameters()`` are tracked and whose ``config`` is recorded on the
       new wrapper.

    Args:
        path: Location forwarded unchanged to both ``model_cls`` loaders.
        model_cls: Class exposing ``from_config`` and ``from_pretrained``.
        foreach: Forwarded as-is to the ``cls`` constructor.

    Returns:
        A new ``cls`` instance on which ``load_state_dict`` has been called
        with the leftover entries from step 1.
    """
    # NOTE(review): the discarded first element is presumably a freshly built
    # model instance, so this call may be costly; kept for its second value.
    _unused_instance, leftover_state = model_cls.from_config(path, return_unused_kwargs=True)

    source_model = model_cls.from_pretrained(path)
    wrapper = cls(
        source_model.parameters(),
        model_cls=model_cls,
        model_config=source_model.config,
        foreach=foreach,
    )

    wrapper.load_state_dict(leftover_state)
    return wrapper
| 669 | |
| 670 | def save_pretrained(self, path): |
| 671 | if self.model_cls is None: |
no test coverage detected