| 13 | |
| 14 | |
def load_pretrained_vae(cfg, model, logger):
    """Load pretrained VAE weights from a checkpoint into ``model``.

    The checkpoint at ``cfg.TRAIN.PRETRAINED_VAE`` must be a dict with a
    ``'state_dict'`` entry. Only entries whose keys start with
    ``"motion_vae."`` or ``"vae."`` are kept. That leading prefix is
    stripped, and the result is loaded with ``strict=True`` into
    ``model.vae`` if the model has that attribute, otherwise into
    ``model.motion_vae``.

    Args:
        cfg: Config object; ``cfg.TRAIN.PRETRAINED_VAE`` is the checkpoint path.
        model: Model owning a ``vae`` or ``motion_vae`` submodule.
        logger: Logger used to report which file is being loaded.

    Returns:
        The same ``model`` object, with its VAE weights updated in place.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: If the extracted keys do not exactly match the target
            VAE's parameters (raised by ``load_state_dict(..., strict=True)``).
    """
    from collections import OrderedDict

    ckpt_path = cfg.TRAIN.PRETRAINED_VAE
    # Log before loading so that a failing load can be traced to its file.
    logger.info(f"Loading pretrain vae from {ckpt_path}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    # map_location="cpu" lets checkpoints saved on GPU load on CPU-only hosts.
    state_dict = torch.load(ckpt_path, map_location="cpu")['state_dict']

    # Extract encoder/decoder weights. Match on the leading prefix only:
    # - A substring test would also pick up unrelated keys such as
    #   "denoiser.vae_proj.weight", which breaks the strict load.
    # - str.replace would strip nested occurrences too, turning
    #   "vae.vae.weight" into "weight".
    prefixes = ("motion_vae.", "vae.")
    vae_dict = OrderedDict()
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                vae_dict[key[len(prefix):]] = value
                break

    target = model.vae if hasattr(model, 'vae') else model.motion_vae
    target.load_state_dict(vae_dict, strict=True)

    return model