MCPcopy
hub / github.com/PixArt-alpha/PixArt-sigma / load_checkpoint

Function load_checkpoint

diffusion/utils/checkpoint.py:40–84  ·  view source on GitHub ↗
(checkpoint,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    load_ema=False,
                    resume_optimizer=True,
                    resume_lr_scheduler=True,
                    max_length=120,
                    )

Source from the content-addressed store, hash-verified

38
39
def load_checkpoint(checkpoint,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    load_ema=False,
                    resume_optimizer=True,
                    resume_lr_scheduler=True,
                    max_length=120,
                    ):
    """Load model weights, and optionally training state, from a checkpoint file.

    Args:
        checkpoint: Path to the checkpoint file (must be a ``str``).
        model: Object whose ``load_state_dict(state_dict, strict=False)`` receives
            the selected weights.
        model_ema: Optional object that receives ``checkpoint['state_dict_ema']``.
        optimizer: Optional optimizer. When given and ``resume_optimizer`` is true,
            it receives ``checkpoint['optimizer']``. Passing an optimizer also
            switches the return value to the 3-tuple form.
        lr_scheduler: Optional scheduler. When given and ``resume_lr_scheduler`` is
            true, it receives ``checkpoint['scheduler']``.
        load_ema: If true, load ``checkpoint['state_dict_ema']`` into ``model``
            instead of the regular state dict.
        resume_optimizer: Whether to restore the optimizer state.
        resume_lr_scheduler: Whether to restore the scheduler state.
        max_length: Selects the null-embedding file
            ``output/pretrained_models/null_embed_diffusers_{max_length}token.pth``.

    Returns:
        ``(epoch, missing, unexpect)`` when ``optimizer`` is not ``None``, otherwise
        ``(missing, unexpect)``. ``missing`` and ``unexpect`` are the two values
        returned by ``model.load_state_dict(..., strict=False)``. ``epoch`` is the
        checkpoint's ``'epoch'`` entry if present, else the integer parsed from
        ``epoch_<n>`` in the file name, else ``None``.

    Raises:
        KeyError: If a requested entry (``'state_dict_ema'``, ``'optimizer'``,
            ``'scheduler'``) is absent from the checkpoint.
    """
    assert isinstance(checkpoint, str)
    ckpt_file = checkpoint
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(ckpt_file, map_location="cpu")

    # A checkpoint is either a wrapper dict holding 'state_dict' or a bare state
    # dict. Resolve that once so the key clean-up below works for both layouts
    # instead of raising KeyError on the bare form.
    model_state = checkpoint.get('state_dict', checkpoint)
    ema_state = checkpoint.get('state_dict_ema')

    # Drop the stored positional-embedding entry so it is not loaded into the
    # model; strict=False below tolerates the missing key.
    # NOTE(review): presumably the model builds its own pos_embed - confirm.
    for key in ('pos_embed', 'base_model.pos_embed', 'model.pos_embed'):
        if key in model_state:
            del model_state[key]
            if ema_state is not None and key in ema_state:
                del ema_state[key]
            # Only the first matching key variant is removed.
            break

    state_dict = checkpoint['state_dict_ema'] if load_ema else model_state

    # Overwrite the null-caption embedding with the one stored for this token length.
    null_embed = torch.load(f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth', map_location='cpu')
    state_dict['y_embedder.y_embedding'] = null_embed['uncond_prompt_embeds'][0]

    missing, unexpect = model.load_state_dict(state_dict, strict=False)
    if model_ema is not None:
        model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
    if optimizer is not None and resume_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if lr_scheduler is not None and resume_lr_scheduler:
        lr_scheduler.load_state_dict(checkpoint['scheduler'])

    logger = get_root_logger()
    if optimizer is not None:
        epoch = checkpoint.get('epoch')
        if epoch is None:
            # Fall back to the file name only when the checkpoint has no epoch.
            # group(1) is the captured number; group()[0] would be the first
            # character of the whole path.
            epoch_match = re.match(r'.*epoch_(\d+).*\.pth', ckpt_file)
            if epoch_match is not None:
                epoch = int(epoch_match.group(1))
            else:
                logger.warning(f'Could not determine the epoch of {ckpt_file}.')
        logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
                    f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
        return epoch, missing, unexpect
    logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
    return missing, unexpect

Callers 2

train.py · File · 0.90

Calls 1

get_root_logger · Function · 0.90

Tested by

no test coverage detected