(config, model, logger)
| 99 | |
| 100 | |
def load_pretrained(config, model, logger):
    """Initialise ``model`` from the checkpoint at ``config.MODEL.PRETRAINED`` for fine-tuning.

    Steps:
      1. Load the checkpoint onto CPU and take its ``'model'`` entry as the state dict.
      2. If any key starts with ``encoder.``, keep only those keys and strip that
         leading prefix (other keys are discarded).
      3. For Swin / SwinV2 models, pass the state dict through
         ``remap_pretrained_keys_swin`` and use the mapping it returns.
      4. Load the result into ``model`` non-strictly and log the report returned
         by ``load_state_dict``.

    Args:
        config: Config object; reads ``config.MODEL.PRETRAINED`` (checkpoint path)
            and ``config.MODEL.TYPE`` (model family).
        model: Module receiving the weights; must provide ``load_state_dict``.
        logger: Logger used for progress messages (only ``info`` is called).

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
        NotImplementedError: If ``config.MODEL.TYPE`` is not ``'swin'`` or ``'swinv2'``.
    """
    pretrained_path = config.MODEL.PRETRAINED
    logger.info(f">>>>>>>>>> Fine-tuned from {pretrained_path} ..........")
    # NOTE(security): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    checkpoint_model = checkpoint['model']

    prefix = 'encoder.'
    # Detect and strip only a *leading* "encoder." so detection agrees with the filter.
    # A substring test could fire on keys that the startswith filter then drops entirely,
    # and str.replace would also remove "encoder." from the middle of key names.
    if any(k.startswith(prefix) for k in checkpoint_model):
        checkpoint_model = {
            k[len(prefix):]: v for k, v in checkpoint_model.items() if k.startswith(prefix)
        }
        logger.info('Detect pre-trained model, remove [encoder.] prefix.')
    else:
        logger.info('Detect non-pre-trained model, pass without doing anything.')

    if config.MODEL.TYPE in ('swin', 'swinv2'):
        logger.info(">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
        # Keep the returned mapping: binding it to another name and then loading
        # `checkpoint_model` is only correct if the remap happens to work in place.
        checkpoint_model = remap_pretrained_keys_swin(model, checkpoint_model, logger)
    else:
        raise NotImplementedError(
            f"load_pretrained does not support MODEL.TYPE={config.MODEL.TYPE!r}"
        )

    # Non-strict: the returned report lists missing / unexpected keys instead of raising.
    msg = model.load_state_dict(checkpoint_model, strict=False)
    logger.info(msg)

    # Release the raw checkpoint and the state dict before emptying the CUDA cache.
    del checkpoint, checkpoint_model
    torch.cuda.empty_cache()
    logger.info(f">>>>>>>>>> loaded successfully '{pretrained_path}'")
| 124 | |
| 125 | |
| 126 | def remap_pretrained_keys_swin(model, checkpoint_model, logger): |
no test coverage detected