MCPcopy
hub / github.com/microsoft/Swin-Transformer / load_checkpoint

Function load_checkpoint

utils_simmim.py:16–50  ·  view source on GitHub ↗
(config, model, optimizer, lr_scheduler, scaler, logger)

Source from the content-addressed store, hash-verified

14
15
def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger):
    """Restore model weights and, when resuming training, optimizer/scheduler/scaler state.

    The checkpoint location is ``config.MODEL.RESUME``. An ``http://`` or
    ``https://`` URL is fetched with ``torch.hub.load_state_dict_from_url``
    (``check_hash=True``). Anything else is treated as a local path and read
    with ``torch.load``. Both are loaded with ``map_location='cpu'``.

    Model weights are always loaded with ``strict=False``, and the resulting
    message is logged. The optimizer, LR scheduler and scaler states are
    restored only when ``config.EVAL_MODE`` is false and the checkpoint
    contains all of ``optimizer``, ``lr_scheduler``, ``scaler`` and ``epoch``.
    In that case ``config.TRAIN.START_EPOCH`` is also set to
    ``checkpoint['epoch'] + 1``.

    Args:
        config: Config node exposing ``MODEL.RESUME``, ``EVAL_MODE`` and
            ``TRAIN.START_EPOCH``, plus ``defrost()``/``freeze()`` (called
            around the ``START_EPOCH`` update).
        model: Object whose ``load_state_dict`` receives ``checkpoint['model']``.
        optimizer: Object whose ``load_state_dict`` receives ``checkpoint['optimizer']``.
        lr_scheduler: Object whose ``load_state_dict`` receives ``checkpoint['lr_scheduler']``.
        scaler: Object whose ``load_state_dict`` receives ``checkpoint['scaler']``.
        logger: Logger used for progress messages (``info``).

    Returns:
        ``checkpoint['max_accuracy']`` if training state was restored and that
        entry exists; otherwise ``0.0``.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
    """
    resume = config.MODEL.RESUME
    logger.info(f">>>>>>>>>> Resuming from {resume} ..........")
    # Match the full scheme prefix: a local file that merely starts with
    # "https" must not be sent to the downloader, and plain http URLs should be.
    if resume.startswith(('http://', 'https://')):
        checkpoint = torch.hub.load_state_dict_from_url(
            resume, map_location='cpu', check_hash=True)
    else:
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True; confirm the checkpoints loaded here unpickle
        # under the installed version's default.
        checkpoint = torch.load(resume, map_location='cpu')

    # Re-map keys due to a name change (only needed for the provided models):
    # the old "rpe_mlp" prefix is now "cpb_mlp". The matching keys are
    # snapshotted into a list first, so the dict is not mutated while iterated.
    state_dict = checkpoint['model']
    for old_key in [k for k in state_dict if 'rpe_mlp' in k]:
        state_dict[old_key.replace('rpe_mlp', 'cpb_mlp')] = state_dict.pop(old_key)

    msg = model.load_state_dict(state_dict, strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    resume_keys = ('optimizer', 'lr_scheduler', 'scaler', 'epoch')
    if not config.EVAL_MODE and all(key in checkpoint for key in resume_keys):
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])

        # The config is frozen outside this function; unfreeze just long
        # enough to record the epoch to resume from.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

        logger.info(f"=> loaded successfully '{resume}' (epoch {checkpoint['epoch']})")
        max_accuracy = checkpoint.get('max_accuracy', 0.0)

    # Drop every local reference to the loaded data (state_dict aliases
    # checkpoint['model']), then release unused cached CUDA memory.
    del checkpoint, state_dict
    torch.cuda.empty_cache()
    return max_accuracy
51
52
53def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger):

Callers 2

main (function) · 0.90
main (function) · 0.90

Calls 1

load_state_dict (method) · 0.80

Tested by

no test coverage detected