(cls, path, model_cls)
| 115 | |
@classmethod
def from_pretrained(cls, path, model_cls) -> "EMA":
    """Build an instance of this class from a checkpoint stored at ``path``.

    Args:
        path: Location handed unchanged to both ``model_cls.load_config``
            and ``model_cls.from_pretrained``.
        model_cls: Model class exposing ``load_config`` and
            ``from_pretrained``; it is also forwarded to the constructor.

    Returns:
        A new instance created from the loaded model's parameters and
        config, after ``load_state_dict`` has been called on it.
    """
    # Only the second element (the kwargs the config loader did not consume)
    # is used here; presumably it carries the saved EMA state -- confirm
    # against save_pretrained.
    _config, leftover_kwargs = model_cls.load_config(path, return_unused_kwargs=True)

    base_model = model_cls.from_pretrained(path)
    instance = cls(
        base_model.parameters(),
        model_cls=model_cls,
        model_config=base_model.config,
    )

    instance.load_state_dict(leftover_kwargs)
    return instance
| 125 | |
| 126 | def save_pretrained(self, path): |
| 127 | if self.model_cls is None: |
no test coverage detected