MCPcopy
hub / github.com/microsoft/Swin-Transformer / load_checkpoint

Function load_checkpoint

utils.py:18–42  ·  view source on GitHub ↗
(config, model, optimizer, lr_scheduler, loss_scaler, logger)

Source from the content-addressed store, hash-verified

16
17
def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    """Restore model weights and, when resuming training, the training state.

    The checkpoint location is ``config.MODEL.RESUME``. An ``http://`` or
    ``https://`` URL is fetched with ``torch.hub.load_state_dict_from_url``
    (``check_hash=True``); anything else is treated as a local path and read
    with ``torch.load``. In both cases tensors are mapped to CPU.

    Model weights are always loaded non-strictly (``strict=False``) and the
    returned missing/unexpected-key report is logged. When not in eval mode
    and the checkpoint contains ``optimizer``, ``lr_scheduler`` and ``epoch``
    entries, training state is restored as well:

    * the optimizer and LR-scheduler state dicts are loaded;
    * ``config.TRAIN.START_EPOCH`` is set to ``checkpoint['epoch'] + 1``
      (the config is defrosted for the write and frozen again afterwards);
    * ``loss_scaler`` state is loaded if the checkpoint has a ``scaler`` entry;
    * ``max_accuracy`` is taken from the checkpoint if present.

    Args:
        config: experiment config. Reads ``MODEL.RESUME`` and ``EVAL_MODE``;
            may write ``TRAIN.START_EPOCH``. Must provide ``defrost()`` and
            ``freeze()``.
        model: module whose ``load_state_dict`` receives ``checkpoint['model']``.
        optimizer: optimizer restored when resuming training.
        lr_scheduler: LR scheduler restored when resuming training.
        loss_scaler: object with ``load_state_dict``; only used when the
            checkpoint contains a ``scaler`` entry.
        logger: logger used for progress messages.

    Returns:
        The checkpoint's ``max_accuracy`` value if training state was restored
        and the key exists; otherwise ``0.0``.
    """
    resume = config.MODEL.RESUME
    logger.info(f"==============> Resuming from {resume}....................")
    # Match the URL scheme explicitly: a bare 'https' prefix would also catch
    # local files such as 'https_ckpt.pth', and would miss 'http://' URLs.
    if resume.startswith(('http://', 'https://')):
        checkpoint = torch.hub.load_state_dict_from_url(
            resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(resume, map_location='cpu')

    # Non-strict so partially matching checkpoints still load; the returned
    # key report is logged so mismatches remain visible.
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    has_training_state = all(
        key in checkpoint for key in ('optimizer', 'lr_scheduler', 'epoch'))
    if not config.EVAL_MODE and has_training_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # The config is kept frozen elsewhere; unfreeze only for this write.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        logger.info(f"=> loaded successfully '{resume}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    # Drop the checkpoint reference and ask PyTorch to release cached,
    # currently unused CUDA memory.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
43
44
45def load_pretrained(config, model, logger):

Callers 1

main (Function) · 0.90

Calls 1

load_state_dict (Method) · 0.80

Tested by

no test coverage detected