(checkpoint,
model,
model_ema=None,
optimizer=None,
lr_scheduler=None,
load_ema=False,
resume_optimizer=True,
resume_lr_scheduler=True,
max_length=120,
)
| 38 | |
| 39 | |
def load_checkpoint(checkpoint,
                    model,
                    model_ema=None,
                    optimizer=None,
                    lr_scheduler=None,
                    load_ema=False,
                    resume_optimizer=True,
                    resume_lr_scheduler=True,
                    max_length=120,
                    null_embed_path=None,
                    ):
    """Load a training checkpoint into ``model`` (and optionally EMA / optimizer / scheduler).

    Before loading, any stored positional embedding (``pos_embed``,
    ``base_model.pos_embed`` or ``model.pos_embed`` -- only the first one found)
    is removed from the weights. ``y_embedder.y_embedding`` is then overwritten
    with ``uncond_prompt_embeds[0]`` read from a separate null-embedding file.

    Args:
        checkpoint: Path of the checkpoint file; it is read with ``torch.load``
            onto the CPU. The loaded object may be a wrapper dict with a
            ``'state_dict'`` entry or a bare state dict.
        model: Module receiving the weights via ``load_state_dict(strict=False)``.
        model_ema: Optional module receiving ``checkpoint['state_dict_ema']``.
        optimizer: Optional optimizer, restored from ``checkpoint['optimizer']``
            when ``resume_optimizer`` is true. Passing one also adds the epoch
            to the return value.
        lr_scheduler: Optional scheduler, restored from ``checkpoint['scheduler']``
            when ``resume_lr_scheduler`` is true.
        load_ema: Load ``model`` from ``'state_dict_ema'`` instead of ``'state_dict'``.
        resume_optimizer: Whether to restore the optimizer state.
        resume_lr_scheduler: Whether to restore the scheduler state.
        max_length: Token length used to build the default null-embedding path.
        null_embed_path: Explicit path of the null-embedding file. Defaults to
            ``output/pretrained_models/null_embed_diffusers_{max_length}token.pth``.

    Returns:
        ``(epoch, missing, unexpected)`` when ``optimizer`` is given, otherwise
        ``(missing, unexpected)``. ``missing`` and ``unexpected`` are the key lists
        reported by ``model.load_state_dict``. ``epoch`` comes from
        ``checkpoint['epoch']`` or, failing that, from an ``epoch_<N>`` tag in the
        file name (as an ``int``).

    Raises:
        TypeError: ``checkpoint`` is not a string path.
        KeyError: ``load_ema`` (or ``model_ema``) is used but the checkpoint has
            no ``'state_dict_ema'``.
        ValueError: an epoch is required but is neither stored in the checkpoint
            nor encoded as ``epoch_<N>`` in the file name.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not isinstance(checkpoint, str):
        raise TypeError(f'checkpoint must be a file path (str), got {type(checkpoint).__name__}')
    ckpt_file = checkpoint
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    checkpoint = torch.load(ckpt_file, map_location="cpu")

    # A bare state dict (no 'state_dict' wrapper) is accepted for compatibility
    # with official checkpoints. Resolve it once so the stripping below and the
    # actual load both see the same dict.
    state_dict = checkpoint.get('state_dict', checkpoint)
    ema_state_dict = checkpoint.get('state_dict_ema')

    # Drop the stored positional embedding so the model keeps its own one
    # (presumably to allow a different input resolution -- confirm).
    for key in ('pos_embed', 'base_model.pos_embed', 'model.pos_embed'):
        if key in state_dict:
            del state_dict[key]
            if ema_state_dict is not None and key in ema_state_dict:
                del ema_state_dict[key]
            break

    if load_ema:
        if ema_state_dict is None:
            raise KeyError(f"load_ema=True but checkpoint {ckpt_file} has no 'state_dict_ema'")
        state_dict = ema_state_dict

    if null_embed_path is None:
        null_embed_path = f'output/pretrained_models/null_embed_diffusers_{max_length}token.pth'
    null_embed = torch.load(null_embed_path, map_location='cpu')
    # Presumably the unconditional (empty-prompt) text embedding -- TODO confirm.
    state_dict['y_embedder.y_embedding'] = null_embed['uncond_prompt_embeds'][0]

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if model_ema is not None:
        model_ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
    if optimizer is not None and resume_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if lr_scheduler is not None and resume_lr_scheduler:
        lr_scheduler.load_state_dict(checkpoint['scheduler'])

    logger = get_root_logger()
    if optimizer is not None:
        epoch = checkpoint.get('epoch')
        if epoch is None:
            # Only parse the file name when the checkpoint lacks an epoch. The old
            # code evaluated this eagerly and returned ``group()[0]`` (the first
            # character of the whole path) instead of the captured digits.
            # The last ``epoch_<N>`` tag wins, as with the old greedy pattern.
            tags = re.findall(r'epoch_(\d+)', ckpt_file)
            if not tags:
                raise ValueError(
                    f"checkpoint {ckpt_file} stores no 'epoch' and its name has no epoch_<N> tag")
            epoch = int(tags[-1])
        logger.info(f'Resume checkpoint of epoch {epoch} from {ckpt_file}. Load ema: {load_ema}, '
                    f'resume optimizer: {resume_optimizer}, resume lr scheduler: {resume_lr_scheduler}.')
        return epoch, missing, unexpected
    logger.info(f'Load checkpoint from {ckpt_file}. Load ema: {load_ema}.')
    return missing, unexpected
no test coverage detected