A function that loads a model checkpoint for fine-tuning at a different resolution.
(modelpath, model)
| 247 | replace_layernorm(child) |
| 248 | |
def load_model(modelpath, model):
    '''
    Load a checkpoint and adapt its attention-bias tables to ``model``.

    Used for fine-tuning on a different resolution: every
    ``*attention_biases*`` entry of the checkpoint whose second dimension
    differs from the matching entry of ``model.state_dict()`` is treated as
    a square grid and resized with bicubic interpolation.

    Args:
        modelpath: Checkpoint location, passed straight to ``torch.load``.
        model: Object whose ``state_dict()`` supplies the target shapes.

    Returns:
        The loaded checkpoint dict, with ``checkpoint['model']`` holding the
        adjusted state dict. Other checkpoint entries are left untouched.

    Raises:
        KeyError: If the checkpoint has no ``'model'`` entry.
        RuntimeError: If a table that needs resizing does not have a
            perfect-square length on either side.
    '''
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source (or pass
    # weights_only=True where the installed torch supports it).
    checkpoint = torch.load(modelpath, map_location='cpu')
    state_dict = checkpoint['model']
    model_state_dict = model.state_dict()

    # Drop the index buffers from the checkpoint. Presumably they depend on
    # the input resolution and are rebuilt by the model -- TODO confirm.
    # The key list is materialised first so the dict is not mutated while
    # it is being iterated.
    for k in [k for k in state_dict if "attention_bias_idxs" in k]:
        print("deleting key: ", k)
        del state_dict[k]

    for k in [k for k in state_dict if "attention_biases" in k]:
        if k not in model_state_dict:
            # Nothing to compare against; leave the entry untouched and let
            # load_state_dict() report it, instead of dying on a KeyError.
            print("skipping key: ", k)
            continue

        pretrained_table = state_dict[k]
        nH1, L1 = pretrained_table.size()
        nH2, L2 = model_state_dict[k].size()

        if nH1 != nH2:
            # NOTE(review): `logger` is not defined in this function; it
            # must exist at module level -- confirm.
            logger.warning(f"Error in loading {k} due to different number of heads")
            continue
        if L1 == L2:
            continue

        # Round rather than truncate so float error in the square root
        # cannot yield a side length that is one too small, then verify
        # that both lengths really are perfect squares.
        S1 = int(round(L1 ** 0.5))
        S2 = int(round(L2 ** 0.5))
        if S1 * S1 != L1 or S2 * S2 != L2:
            # RuntimeError matches what the failing .view() used to raise.
            raise RuntimeError(
                f"Cannot resize {k}: table lengths {L1} and {L2} "
                f"must both be perfect squares")

        # Bicubic interpolation of the (heads, S1*S1) table viewed as a
        # one-sample batch of S1 x S1 images with one channel per head.
        resized_table = torch.nn.functional.interpolate(
            pretrained_table.view(1, nH1, S1, S1), size=(S2, S2),
            mode='bicubic')
        state_dict[k] = resized_table.view(nH2, L2)

    checkpoint['model'] = state_dict
    return checkpoint
nothing calls this directly
no test coverage detected