MCPcopy
hub / github.com/microsoft/Swin-Transformer / load_checkpoint

Function load_checkpoint

utils_moe.py:31–61  ·  view source on GitHub ↗
(config, model, optimizer, lr_scheduler, loss_scaler, logger)

Source from the content-addressed store, hash-verified

29
30
def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    """Resume MoE training state from ``config.MODEL.RESUME``.

    If the configured path ends in ``.pth``, it is rewritten to the MoE file
    name that is actually loaded:

    * ``<RESUME>.global`` when ``config.TRAIN.MOE.SAVE_MASTER`` is set;
    * ``<RESUME>.rank<global_rank>`` otherwise (one file per rank).

    Any other path is loaded exactly as given.

    The model weights are always loaded, non-strictly, and the result of
    ``load_state_dict`` is logged. The following are restored only when
    ``config.EVAL_MODE`` is false and the checkpoint contains ``optimizer``,
    ``lr_scheduler`` and ``epoch``:

    * the optimizer and LR-scheduler state;
    * the start epoch;
    * the optional loss-scaler state;
    * the best accuracy seen so far.

    Args:
        config: Frozen config node. ``TRAIN.START_EPOCH`` is overwritten in
            place (between ``defrost()``/``freeze()``) when training state is
            resumed.
        model: Module that receives ``checkpoint['model']``.
        optimizer: Receives ``checkpoint['optimizer']``.
        lr_scheduler: Receives ``checkpoint['lr_scheduler']``.
        loss_scaler: Receives ``checkpoint['scaler']`` when that key exists.
        logger: Logger for progress messages.

    Returns:
        ``checkpoint['max_accuracy']`` if training state was resumed and the key
        exists, otherwise ``0.0``.
    """
    global_rank = dist.get_rank()
    logger.info(f"==============> Rank[{global_rank}] Resuming from {config.MODEL.RESUME}....................")

    if config.MODEL.RESUME.endswith('.pth'):
        # Map the nominal ".pth" name onto the suffixed MoE file that is read:
        # a single shared ".global" file, or one ".rank<N>" file per rank.
        if config.TRAIN.MOE.SAVE_MASTER:
            resume_path = config.MODEL.RESUME + '.global'
        else:
            resume_path = config.MODEL.RESUME + f'.rank{global_rank}'
        logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {resume_path}......")
    else:
        resume_path = config.MODEL.RESUME

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Since PyTorch 2.6 the default is weights_only=True, which may
    # reject checkpoints that pickle non-tensor objects; confirm what the saver
    # stores before changing this call.
    checkpoint = torch.load(resume_path, map_location='cpu')

    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)  # reports missing/unexpected keys from the non-strict load

    max_accuracy = 0.0
    has_training_state = all(key in checkpoint for key in ('optimizer', 'lr_scheduler', 'epoch'))
    if not config.EVAL_MODE and has_training_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # The config is frozen; unfreeze only long enough to set the next epoch.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        # Report the file that was actually read, not the nominal RESUME name.
        logger.info(f"=>Rank[{global_rank}] loaded successfully '{resume_path}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    # Drop the loaded checkpoint, then ask the CUDA caching allocator to
    # release unused cached memory.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
62
63
64def load_pretrained(config, model, logger):

Callers 1

mainFunction · 0.90

Calls 1

load_state_dictMethod · 0.80

Tested by

no test coverage detected