(cfg, model, logger, phase="train")
| 1 | import torch |
| 2 | |
def load_pretrained(cfg, model, logger, phase="train"):
    """Load a checkpoint's ``"state_dict"`` into ``model`` and return it.

    The checkpoint path is selected from ``cfg`` by ``phase``:
    ``cfg.TRAIN.PRETRAINED`` for ``"train"`` and ``cfg.TEST.CHECKPOINTS``
    for ``"test"``. The file is read with ``torch.load`` onto the CPU, and
    its ``"state_dict"`` entry is applied with ``strict=True``. The keys must
    therefore match the model's state dict exactly.

    Args:
        cfg: Config object exposing ``TRAIN.PRETRAINED`` (train phase) and
            ``TEST.CHECKPOINTS`` (test phase).
        model: Module whose ``load_state_dict`` receives the weights; it is
            modified in place.
        logger: Object with an ``info(str)`` method for progress messages.
        phase: ``"train"`` or ``"test"``.

    Returns:
        The same ``model`` object, after loading the weights.

    Raises:
        ValueError: If ``phase`` is neither ``"train"`` nor ``"test"``.
    """
    if phase == "train":
        ckpt_path = cfg.TRAIN.PRETRAINED
    elif phase == "test":
        ckpt_path = cfg.TEST.CHECKPOINTS
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(f"phase must be 'train' or 'test', got {phase!r}")

    # Log the path that is actually loaded for this phase.
    logger.info(f"Loading pretrain model from {ckpt_path}")

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    return model
| 13 | |
| 14 | |
| 15 | def load_pretrained_vae(cfg, model, logger): |