(config, model, optimizer, lr_scheduler, loss_scaler, logger)
| 55 | |
| 56 | |
def _convert_head_22k_to_1k(params, model_state, logger, num_classes=1000,
                            mapping_file='./imagenet_1kto22k.txt'):
    """Remap a 21841-class classifier head in ``params`` to ``num_classes`` rows, in place.

    Acts only when the checkpoint's ``head.bias`` has 21841 entries, the model's
    ``head.bias`` has ``num_classes`` entries, and the two shapes differ; in every
    other case ``params`` is left untouched.

    ``mapping_file`` holds one integer per line (blank lines are ignored). Each
    integer is the row of the 21841-class head to copy for that output class, or
    -1 to use the mean of all mapped rows instead.
    """
    head_names = ['head.weight', 'head.bias']  # (cls, 1024), (cls, )
    ckpt_bias = params.get('head.bias')
    model_bias = model_state.get('head.bias')
    if ckpt_bias is None or model_bias is None:
        return
    if ckpt_bias.shape == model_bias.shape:
        return
    if len(ckpt_bias) != 21841 or len(model_bias) != num_classes:
        return

    logger.info("Convert checkpoint from 21841 to 1k")
    with open(mapping_file, encoding='utf-8') as fin:
        # Build the index tensor directly as int64 and skip blank lines
        # (e.g. a trailing newline) instead of failing on int('').
        mapping = torch.tensor(
            [int(line) for line in fin if line.strip()], dtype=torch.long)
    mapped_rows = mapping[mapping != -1]

    for name in head_names:
        if name not in params:
            continue
        v = params[name]
        mean_v = v[mapped_rows].mean(0, keepdim=True)
        # Append the mean as the last row: a -1 entry in ``mapping`` is a
        # negative index and therefore selects exactly that appended row.
        params[name] = torch.cat([v, mean_v], 0)[mapping]


def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    """Load model (and optionally training) state from ``config.MODEL.RESUME``.

    The source is fetched with ``torch.hub.load_state_dict_from_url`` (with hash
    check) when it starts with ``'https'``, otherwise read with ``torch.load``.
    Everything is mapped to CPU.

    Model weights are loaded with ``strict=False``. A 21841-class ``head`` in the
    checkpoint is first remapped to 1000 classes when the model expects 1000.

    Unless ``config.EVAL_MODE`` is set, and only when the checkpoint has both
    ``'optimizer'`` and ``'lr_scheduler'`` entries, the optimizer, LR scheduler
    and loss scaler states are restored. Any of those objects may be ``None`` to
    skip it. In that same case ``'max_accuracy'`` is read back.

    If the checkpoint has an ``'epoch'``, ``config.TRAIN.START_EPOCH`` is set to
    ``epoch + 1``.

    Returns:
        The checkpoint's ``max_accuracy`` when restored as described above,
        otherwise ``0.0``.
    """
    logger.info(
        f"==============> Resuming from {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    params = checkpoint['model']
    _convert_head_22k_to_1k(params, model.state_dict(), logger)

    # strict=False tolerates missing/unexpected keys (the result is logged);
    # it does not tolerate shape mismatches.
    msg = model.load_state_dict(params, strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    if not config.EVAL_MODE:
        if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint:
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            # Guard on None like optimizer/lr_scheduler above.
            if 'scaler' in checkpoint and loss_scaler is not None:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            logger.info(
                f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint.get('epoch')})")
            if 'max_accuracy' in checkpoint:
                max_accuracy = checkpoint['max_accuracy']

    if 'epoch' in checkpoint:
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

    # Drop the (possibly large) checkpoint before returning to training.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
no test coverage detected