(config, model, logger)
| 43 | |
| 44 | |
| 45 | def load_pretrained(config, model, logger): |
| 46 | logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") |
| 47 | checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') |
| 48 | state_dict = checkpoint['model'] |
| 49 | |
| 50 | # delete relative_position_index since we always re-init it |
| 51 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] |
| 52 | for k in relative_position_index_keys: |
| 53 | del state_dict[k] |
| 54 | |
| 55 | # delete relative_coords_table since we always re-init it |
| 56 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] |
| 57 | for k in relative_position_index_keys: |
| 58 | del state_dict[k] |
| 59 | |
| 60 | # delete attn_mask since we always re-init it |
| 61 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] |
| 62 | for k in attn_mask_keys: |
| 63 | del state_dict[k] |
| 64 | |
| 65 | # bicubic interpolate relative_position_bias_table if not match |
| 66 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] |
| 67 | for k in relative_position_bias_table_keys: |
| 68 | relative_position_bias_table_pretrained = state_dict[k] |
| 69 | relative_position_bias_table_current = model.state_dict()[k] |
| 70 | L1, nH1 = relative_position_bias_table_pretrained.size() |
| 71 | L2, nH2 = relative_position_bias_table_current.size() |
| 72 | if nH1 != nH2: |
| 73 | logger.warning(f"Error in loading {k}, passing......") |
| 74 | else: |
| 75 | if L1 != L2: |
| 76 | # bicubic interpolate relative_position_bias_table if not match |
| 77 | S1 = int(L1 ** 0.5) |
| 78 | S2 = int(L2 ** 0.5) |
| 79 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( |
| 80 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), |
| 81 | mode='bicubic') |
| 82 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) |
| 83 | |
| 84 | # bicubic interpolate absolute_pos_embed if not match |
| 85 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] |
| 86 | for k in absolute_pos_embed_keys: |
| 87 | # dpe |
| 88 | absolute_pos_embed_pretrained = state_dict[k] |
| 89 | absolute_pos_embed_current = model.state_dict()[k] |
| 90 | _, L1, C1 = absolute_pos_embed_pretrained.size() |
| 91 | _, L2, C2 = absolute_pos_embed_current.size() |
| 92 | if C1 != C1: |
| 93 | logger.warning(f"Error in loading {k}, passing......") |
| 94 | else: |
| 95 | if L1 != L2: |
| 96 | S1 = int(L1 ** 0.5) |
| 97 | S2 = int(L2 ** 0.5) |
| 98 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) |
| 99 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) |
| 100 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( |
| 101 | absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') |
| 102 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) |
no test coverage detected