| 14 | return getattr(importlib.import_module(module, package=None), cls) |
| 15 | |
def instantiate_from_config(config) -> object:
    """Build an object from a ``target``/``params`` config, optionally loading weights.

    Args:
        config: Mapping with a required ``"target"`` key (a dotted path that
            ``get_obj_from_str`` resolves to a callable), an optional
            ``"params"`` mapping of keyword arguments passed to that callable,
            and an optional ``"ckpt"`` path to a checkpoint whose state dict
            is loaded into the result with ``strict=True``.

    Returns:
        The constructed object, with checkpoint weights loaded if ``"ckpt"``
        was given.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
        NotImplementedError: If the checkpoint has no ``"ema"`` entry but has
            a ``"model"`` entry; that layout is not supported yet.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    # ``params`` may be absent or explicitly null (e.g. an empty YAML
    # section); both mean "call the target with no keyword arguments".
    params = config.get("params") or {}
    model = get_obj_from_str(config["target"])(**params)

    ckpt_path = config.get("ckpt")
    if ckpt_path is not None:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code; only point ``ckpt`` at trusted checkpoints (or pass
        # weights_only=True where the checkpoint format allows it).
        state_dict = torch.load(ckpt_path, map_location="cpu")
        # Checkpoints saved during training may nest weights under a key:
        # prefer the "ema" weights; the "model" layout is not supported yet.
        # Anything else is passed through as a bare state dict.
        if "ema" in state_dict:
            state_dict = state_dict["ema"]
        elif "model" in state_dict:
            raise NotImplementedError("Loading from 'model' key not implemented yet.")
        model.load_state_dict(state_dict, strict=True)
        print(f'target {config["target"]} loaded from {ckpt_path}')
    return model
| 32 | |