(config, model, logger)
| 115 | |
| 116 | |
| 117 | def load_pretrained(config, model, logger): |
| 118 | logger.info( |
| 119 | f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") |
| 120 | checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') |
| 121 | state_dict = checkpoint['model'] |
| 122 | |
| 123 | # delete relative_position_index since we always re-init it |
| 124 | relative_position_index_keys = [ |
| 125 | k for k in state_dict.keys() if "relative_position_index" in k] |
| 126 | for k in relative_position_index_keys: |
| 127 | del state_dict[k] |
| 128 | |
| 129 | # delete relative_coords_table since we always re-init it |
| 130 | relative_position_index_keys = [ |
| 131 | k for k in state_dict.keys() if "relative_coords_table" in k] |
| 132 | for k in relative_position_index_keys: |
| 133 | del state_dict[k] |
| 134 | |
| 135 | # delete attn_mask since we always re-init it |
| 136 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] |
| 137 | for k in attn_mask_keys: |
| 138 | del state_dict[k] |
| 139 | |
| 140 | model_state_dict = model.state_dict() |
| 141 | |
| 142 | # bicubic interpolate relative_position_bias_table if not match |
| 143 | relative_position_bias_table_keys = [ |
| 144 | k for k in state_dict.keys() if "relative_position_bias_table" in k] |
| 145 | for k in relative_position_bias_table_keys: |
| 146 | relative_position_bias_table_pretrained = state_dict[k] |
| 147 | relative_position_bias_table_current = model_state_dict[k] |
| 148 | L1, nH1 = relative_position_bias_table_pretrained.size() |
| 149 | L2, nH2 = relative_position_bias_table_current.size() |
| 150 | if nH1 != nH2: |
| 151 | logger.warning(f"Error in loading {k}, passing......") |
| 152 | else: |
| 153 | if L1 != L2: |
| 154 | # bicubic interpolate relative_position_bias_table if not match |
| 155 | S1 = int(L1 ** 0.5) |
| 156 | S2 = int(L2 ** 0.5) |
| 157 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( |
| 158 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), |
| 159 | mode='bicubic') |
| 160 | state_dict[k] = relative_position_bias_table_pretrained_resized.view( |
| 161 | nH2, L2).permute(1, 0) |
| 162 | |
| 163 | # bicubic interpolate attention_biases if not match |
| 164 | relative_position_bias_table_keys = [ |
| 165 | k for k in state_dict.keys() if "attention_biases" in k] |
| 166 | for k in relative_position_bias_table_keys: |
| 167 | relative_position_bias_table_pretrained = state_dict[k] |
| 168 | relative_position_bias_table_current = model_state_dict[k] |
| 169 | nH1, L1 = relative_position_bias_table_pretrained.size() |
| 170 | nH2, L2 = relative_position_bias_table_current.size() |
| 171 | if nH1 != nH2: |
| 172 | logger.warning(f"Error in loading {k}, passing......") |
| 173 | else: |
| 174 | if L1 != L2: |
no test coverage detected