hub / github.com/OpenMotionLab/MotionGPT / load_pretrained

Function load_pretrained

mGPT/utils/load_checkpoint.py:3–12
(cfg, model, logger, phase="train")

Source from the content-addressed store, hash-verified

import torch

def load_pretrained(cfg, model, logger, phase="train"):
    logger.info(f"Loading pretrain model from {cfg.TRAIN.PRETRAINED}")
    if phase == "train":
        ckpt_path = cfg.TRAIN.PRETRAINED
    elif phase == "test":
        ckpt_path = cfg.TEST.CHECKPOINTS

    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    return model


def load_pretrained_vae(cfg, model, logger):
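
Two details of the listing are worth noting. `ckpt_path` is assigned only when `phase` is "train" or "test", so any other value reaches `torch.load` with the name unbound and fails with an `UnboundLocalError`. The log line also always reports `cfg.TRAIN.PRETRAINED`, even when the test checkpoint is the one being loaded. A minimal sketch of a stricter path lookup follows. The helper `resolve_ckpt_path`, the `SimpleNamespace` stand-in for `cfg`, and the example paths are hypothetical and not part of the repository.

```python
from types import SimpleNamespace


def resolve_ckpt_path(cfg, phase="train"):
    """Return the checkpoint path for `phase` (hypothetical helper, not in the repository)."""
    if phase == "train":
        return cfg.TRAIN.PRETRAINED
    if phase == "test":
        return cfg.TEST.CHECKPOINTS
    # The listed function would continue to torch.load with ckpt_path unbound here.
    raise ValueError(f"phase must be 'train' or 'test', got {phase!r}")


# Stand-in config with made-up paths, for illustration only.
cfg = SimpleNamespace(
    TRAIN=SimpleNamespace(PRETRAINED="checkpoints/pretrained.ckpt"),
    TEST=SimpleNamespace(CHECKPOINTS="checkpoints/test.ckpt"),
)

train_path = resolve_ckpt_path(cfg, "train")
test_path = resolve_ckpt_path(cfg, "test")
```

With a helper like this, the resolved path could be logged before loading, so the message would match the file actually read in both phases.
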

Callers 2

main · Function · 0.90
main · Function · 0.90

Calls 1

load_state_dict · Method · 0.80

Tested by 1

main · Function · 0.72