MCPcopy Index your code
hub / github.com/microsoft/Cream / load_checkpoint

Function load_checkpoint

TinyViT/utils.py:57–114  ·  view source on GitHub ↗
(config, model, optimizer, lr_scheduler, loss_scaler, logger)

Source from the content-addressed store, hash-verified

55
56
def _read_1kto22k_mapping(fname):
    """Read the ImageNet-1k -> ImageNet-22k class-index mapping from *fname*.

    The file holds one integer per line. Entry i is the row of the 22k head to
    use for 1k class i, or ``-1`` when there is no match; ``load_checkpoint``
    substitutes the mean of the mapped rows for those. Blank lines (e.g. a
    trailing empty line at end of file) are ignored.

    Returns:
        A 1-D ``torch.long`` tensor with one entry per non-blank line.
    """
    with open(fname) as fin:
        # Build the long tensor directly instead of going through a float32
        # torch.Tensor and converting afterwards.
        return torch.tensor(
            [int(line) for line in fin if line.strip()], dtype=torch.long)


def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
    """Resume model (and optionally training) state from ``config.MODEL.RESUME``.

    Loading:
        ``config.MODEL.RESUME`` is either an ``https`` URL or a local path.
        A URL is downloaded through ``torch.hub`` with hash checking. Both are
        loaded onto CPU.

    Model weights:
        ``checkpoint['model']`` is loaded non-strictly, and the
        missing/unexpected keys are logged. When the checkpoint's
        ``head.bias`` has 21841 entries and the model's has 1000, both head
        tensors are remapped with ``./imagenet_1kto22k.txt``. That path is
        resolved relative to the current working directory.

    Training state:
        Restored only when ``config.EVAL_MODE`` is false AND the checkpoint
        has both ``'optimizer'`` and ``'lr_scheduler'``. In that case the
        optimizer, LR scheduler and loss scaler states are loaded; any of
        those arguments may be ``None`` to skip it. If the checkpoint records
        an ``'epoch'``, ``config.TRAIN.START_EPOCH`` is set to ``epoch + 1``.

    Args:
        config: Config exposing ``MODEL.RESUME``, ``EVAL_MODE``,
            ``TRAIN.START_EPOCH`` and ``defrost()`` / ``freeze()``.
        model: Module that receives ``checkpoint['model']``.
        optimizer: Optimizer to restore from ``checkpoint['optimizer']``,
            or ``None``.
        lr_scheduler: Scheduler to restore from ``checkpoint['lr_scheduler']``,
            or ``None``.
        loss_scaler: Object with ``load_state_dict`` for
            ``checkpoint['scaler']``, or ``None``.
        logger: Logger used for progress messages.

    Returns:
        ``checkpoint['max_accuracy']`` when training state was restored and
        the key is present; otherwise ``0.0``.
    """
    resume = config.MODEL.RESUME
    logger.info(
        f"==============> Resuming from {resume}....................")
    if resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            resume, map_location='cpu', check_hash=True)
    else:
        # NOTE(review): torch.load unpickles the file; only resume from
        # trusted checkpoints.
        checkpoint = torch.load(resume, map_location='cpu')

    params = checkpoint['model']
    now_model_state = model.state_dict()
    mnames = ['head.weight', 'head.bias']  # (cls, 1024), (cls, )
    if mnames[-1] in params:
        ckpt_head_bias = params[mnames[-1]]
        now_model_bias = now_model_state[mnames[-1]]
        if ckpt_head_bias.shape != now_model_bias.shape:
            num_classes = 1000

            if len(ckpt_head_bias) == 21841 and len(now_model_bias) == num_classes:
                logger.info("Convert checkpoint from 21841 to 1k")
                # convert 22k to 1k
                mapping = _read_1kto22k_mapping('./imagenet_1kto22k.txt')
                mapped_rows = mapping[mapping != -1]
                for name in mnames:
                    v = params[name]
                    # Append the mean of all mapped rows as a final row, so
                    # the -1 entries in `mapping` index that mean row.
                    mean_v = v[mapped_rows].mean(0, keepdim=True)
                    params[name] = torch.cat([v, mean_v], 0)[mapping]

    msg = model.load_state_dict(params, strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if not config.EVAL_MODE:
        if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint:
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            # Guard like optimizer/lr_scheduler: callers may pass no scaler.
            if 'scaler' in checkpoint and loss_scaler is not None:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            # 'epoch' is optional (see below), so don't index it blindly here.
            logger.info(
                f"=> loaded successfully '{resume}' (epoch {checkpoint.get('epoch')})")
            if 'max_accuracy' in checkpoint:
                max_accuracy = checkpoint['max_accuracy']

    if 'epoch' in checkpoint:
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

    # Drop the loaded checkpoint before asking the CUDA allocator to release
    # its unused cached memory.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy

Callers (2)

main (function) · 0.90
main (function) · 0.90

Calls (3)

to (method) · 0.80
state_dict (method) · 0.45
load_state_dict (method) · 0.45

Tested by

no test coverage detected