hub / github.com/microsoft/Swin-Transformer / load_pretrained

Function load_pretrained

utils_simmim.py:101–123
(config, model, logger)

Source from the content-addressed store, hash-verified

def load_pretrained(config, model, logger):
    logger.info(f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    checkpoint_model = checkpoint['model']

    if any([True if 'encoder.' in k else False for k in checkpoint_model.keys()]):
        checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')}
        logger.info('Detect pre-trained model, remove [encoder.] prefix.')
    else:
        logger.info('Detect non-pre-trained model, pass without doing anything.')

    if config.MODEL.TYPE in ['swin', 'swinv2']:
        logger.info(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
        checkpoint = remap_pretrained_keys_swin(model, checkpoint_model, logger)
    else:
        raise NotImplementedError

    msg = model.load_state_dict(checkpoint_model, strict=False)
    logger.info(msg)

    del checkpoint
    torch.cuda.empty_cache()
    logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'")


def remap_pretrained_keys_swin(model, checkpoint_model, logger):
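
The prefix handling in the listing can be tried without PyTorch. The following is a minimal sketch, assuming a plain dict stands in for checkpoint['model']. The helper name strip_encoder_prefix and the sample keys are hypothetical and are not part of utils_simmim.py.

```python
def strip_encoder_prefix(state_dict, prefix="encoder."):
    """Mirror the prefix branch of load_pretrained on a plain dict (hypothetical helper)."""
    # Same detection rule as the listing: the prefix may appear anywhere in a key.
    if any(prefix in key for key in state_dict):
        # Keep only keys that start with the prefix; every other key is dropped.
        # str.replace removes every occurrence of the prefix, as in the listing.
        return {k.replace(prefix, ""): v for k, v in state_dict.items() if k.startswith(prefix)}
    # No key mentions the prefix: return the mapping unchanged.
    return state_dict


# Illustrative keys only; real checkpoints hold tensors, not integers.
pretrained = {"encoder.patch_embed.proj.weight": 1, "encoder.mask_token": 2, "decoder.0.weight": 3}
finetuned = {"patch_embed.proj.weight": 1, "head.weight": 4}

stripped = strip_encoder_prefix(pretrained)
untouched = strip_encoder_prefix(finetuned)
print(sorted(stripped))
print(sorted(untouched))
```

Keys that do not start with the prefix, such as the decoder entry above, are discarded rather than renamed.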

Callers 1

main · Function · 0.90

Calls 2

load_state_dict · Method · 0.80
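
Because load_pretrained passes strict=False, missing or unexpected keys do not raise. PyTorch instead returns a result with missing_keys and unexpected_keys fields, which is what the function logs as msg. The sketch below models only that key comparison in plain Python so it runs without PyTorch. The function name diff_keys and the sample keys are hypothetical, and tensor copying and shape checks are not modeled.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) in the spirit of a non-strict load report."""
    model_set = set(model_keys)
    ckpt_set = set(checkpoint_keys)
    # Missing: parameters the model expects but the checkpoint does not provide.
    missing = [k for k in model_keys if k not in ckpt_set]
    # Unexpected: checkpoint entries the model has no parameter for.
    unexpected = [k for k in checkpoint_keys if k not in model_set]
    return missing, unexpected


# Illustrative keys only, chosen to show one entry of each kind.
model_keys = ["patch_embed.proj.weight", "head.weight", "head.bias"]
checkpoint_keys = ["patch_embed.proj.weight", "mask_token"]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

When fine-tuning from a checkpoint that has no classification head, entries like the head keys above would be expected in the missing list, so the logged message is worth reading before treating it as an error.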

Tested by

no test coverage detected