(config, ckpt, verbose=True, freeze=True)
| 284 | |
| 285 | |
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Config object; its ``model`` attribute is passed to
            ``instantiate_from_config`` to build the model.
        ckpt: Checkpoint path as a ``str`` or ``os.PathLike``. Paths ending in
            ``ckpt`` are read with ``torch.load`` (on CPU); paths ending in
            ``safetensors`` are read with ``load_safetensors``. The suffix
            match is case-insensitive.
        verbose: If true, print any missing or unexpected state-dict keys
            reported by ``load_state_dict``.
        freeze: If true, set ``requires_grad = False`` on every parameter.

    Returns:
        The instantiated model with weights loaded (``strict=False``) and
        put into eval mode via ``model.eval()``.

    Raises:
        NotImplementedError: If the checkpoint suffix is not supported.
    """
    import os  # local import: the file's top-level imports are not visible here

    # Accept pathlib.Path and other os.PathLike objects, not only str.
    ckpt_path = os.fspath(ckpt)
    print(f"Loading model from {ckpt}")

    suffix_key = ckpt_path.lower()
    if suffix_key.endswith("ckpt"):
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        pl_sd = torch.load(ckpt_path, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        # Lightning-style checkpoints nest the weights under "state_dict";
        # otherwise assume the loaded object is already the state dict.
        sd = pl_sd.get("state_dict", pl_sd)
    elif suffix_key.endswith("safetensors"):
        sd = load_safetensors(ckpt_path)
    else:
        raise NotImplementedError(
            f"Unsupported checkpoint format for {ckpt_path!r}: "
            "expected a file ending in 'ckpt' or 'safetensors'"
        )

    model = instantiate_from_config(config.model)

    # strict=False: mismatched keys are reported rather than raised.
    missing, unexpected = model.load_state_dict(sd, strict=False)

    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    model.eval()
    return model
| 315 | |
| 316 | |
| 317 | def get_configs_path() -> str: |
nothing calls this directly
no test coverage detected