Repository: github.com/OpenMotionLab/MotionGPT

Function load_pretrained_vae

mGPT/utils/load_checkpoint.py, lines 15–34
(cfg, model, logger)


import torch
from collections import OrderedDict


def load_pretrained_vae(cfg, model, logger):
    # Load the checkpoint onto the CPU and keep only its 'state_dict' entry
    state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE,
                            map_location="cpu")['state_dict']
    logger.info(f"Loading pretrain vae from {cfg.TRAIN.PRETRAINED_VAE}")
    # Extract encoder/decoder weights and strip the submodule prefix
    vae_dict = OrderedDict()
    for k, v in state_dict.items():
        if "motion_vae" in k:
            name = k.replace("motion_vae.", "")
            vae_dict[name] = v
        elif "vae" in k:
            name = k.replace("vae.", "")
            vae_dict[name] = v
    # Load into whichever VAE attribute the model exposes; strict=True
    # raises if any key is missing or unexpected
    if hasattr(model, 'vae'):
        model.vae.load_state_dict(vae_dict, strict=True)
    else:
        model.motion_vae.load_state_dict(vae_dict, strict=True)

    return model
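The key-remapping loop can be exercised without PyTorch. The sketch below copies its logic into a standalone helper, extract_vae_weights (a hypothetical name, not part of MotionGPT), and runs it on made-up checkpoint keys, with plain integers standing in for tensors.

```python
from collections import OrderedDict


def extract_vae_weights(state_dict):
    """Copy of the key-remapping loop in load_pretrained_vae (illustrative sketch).

    Keys containing "motion_vae" have every "motion_vae." substring removed,
    other keys containing "vae" have every "vae." substring removed, and all
    remaining keys are dropped.
    """
    vae_dict = OrderedDict()
    for k, v in state_dict.items():
        if "motion_vae" in k:
            vae_dict[k.replace("motion_vae.", "")] = v
        elif "vae" in k:
            vae_dict[k.replace("vae.", "")] = v
    return vae_dict


# Hypothetical checkpoint keys; plain integers stand in for tensors.
ckpt = {
    "vae.encoder.conv.weight": 1,
    "vae.decoder.conv.weight": 2,
    "lm.language_model.shared.weight": 3,  # no "vae" in the key: dropped
}
print(list(extract_vae_weights(ckpt)))
# ['encoder.conv.weight', 'decoder.conv.weight']
```

Two details of this logic are easy to miss. str.replace removes every occurrence of the prefix, not only a leading one, so a key such as "vae.sub_vae.x" becomes "sub_x". The "vae" in k test is also a plain substring match, so an unrelated key such as a hypothetical "lm.vae_proj.weight" is kept unchanged and would most likely be rejected by the strict load. On Python 3.9+, str.removeprefix strips only a leading prefix.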

Callers (2)

main (function) · 0.90
main (function) · 0.90

Calls (2)

items (method) · 0.80
load_state_dict (method) · 0.80

Tested by (1)

main (function) · 0.72
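The last step of load_pretrained_vae picks a target submodule (vae if the model has one, otherwise motion_vae) and loads the remapped weights with strict=True. The sketch below imitates that step with pure-Python stubs; StubModule, pick_vae and MotionOnlyModel are hypothetical illustrative names, not MotionGPT or PyTorch classes. The stub only compares key sets, whereas real PyTorch loading also rejects tensors of mismatched shape and reports errors in a different format.

```python
class StubModule:
    """Hypothetical stand-in for torch.nn.Module that only compares key sets."""

    def __init__(self, expected_keys):
        self.expected_keys = set(expected_keys)

    def load_state_dict(self, state_dict, strict=True):
        missing = sorted(self.expected_keys - set(state_dict))
        unexpected = sorted(set(state_dict) - self.expected_keys)
        # With strict=True any key mismatch is an error, as in PyTorch
        if strict and (missing or unexpected):
            raise RuntimeError(f"missing={missing} unexpected={unexpected}")
        return missing, unexpected


def pick_vae(model):
    """Same rule as load_pretrained_vae: model.vae if present, else model.motion_vae."""
    return model.vae if hasattr(model, "vae") else model.motion_vae


class MotionOnlyModel:
    """Hypothetical model that exposes only a motion_vae attribute."""

    def __init__(self):
        self.motion_vae = StubModule(["encoder.w", "decoder.w"])


model = MotionOnlyModel()
target = pick_vae(model)  # no 'vae' attribute, so this falls back to motion_vae

target.load_state_dict({"encoder.w": 0, "decoder.w": 0})  # exact match: no error
try:
    target.load_state_dict({"encoder.w": 0})  # "decoder.w" is missing
except RuntimeError as err:
    print(err)  # missing=['decoder.w'] unexpected=[]
```

Because hasattr(model, 'vae') is checked first, a model exposing both attributes would receive the weights in vae and never in motion_vae.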