(model, state_dict)
| 240 | |
| 241 | # if 'pos_embed' in state_dict: |
def interpolate_pos_embed(model, state_dict):
    """Resize a checkpoint's ``pos_embed`` in place to fit ``model``'s patch grid.

    The checkpoint position embedding is split into two parts:

    * the leading "extra" tokens (e.g. class / distillation tokens), which are
      copied unchanged;
    * the patch-position tokens, which are treated as a square
      ``orig_size x orig_size`` grid and resized with bicubic interpolation to
      ``new_size x new_size``. ``new_size`` is derived from
      ``model.patch_embed.num_patches``.

    Args:
        model: Object exposing ``patch_embed.num_patches`` and a ``pos_embed``
            tensor. The number of extra tokens is taken as
            ``model.pos_embed.shape[-2] - num_patches``.
        state_dict: Mapping of parameter names to tensors. The
            ``'pos_embed'`` entry is presumably shaped
            ``(1, num_tokens, embed_dim)`` (TODO confirm against callers); the
            slicing below requires at least three dimensions.

    Returns:
        None. ``state_dict`` is modified in place, and only when the grid side
        lengths differ. A ``state_dict`` without a ``'pos_embed'`` entry is
        left untouched.
    """
    # Nothing to resize if the checkpoint carries no absolute position embedding.
    if 'pos_embed' not in state_dict:
        return

    pos_embed_checkpoint = state_dict['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    # Tokens in the model's embedding that are not patch positions.
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # Side length (height == width) of the checkpoint's patch grid ...
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # ... and of the target model's patch grid.
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return

    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    # Extra tokens are kept unchanged; only the position tokens are interpolated.
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    # (B, H*W, D) -> (B, D, H, W) so interpolate treats D as channels.
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    # Bicubic interpolation is not implemented for half precision on every
    # backend / PyTorch version; upcast only those dtypes and restore the
    # original dtype afterwards so the stored tensor's dtype is unchanged.
    orig_dtype = pos_tokens.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        pos_tokens = pos_tokens.float()
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.to(orig_dtype)
    # (B, D, H', W') -> (B, H'*W', D)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    state_dict['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)
no test coverage detected