| 1259 | |
| 1260 | |
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Rescale the visual position-embedding grid in ``state_dict`` to fit ``model``.

    Reads ``state_dict['visual.positional_embedding']`` (indexed as
    ``(seq_len, dim)``), leaves the leading class-token row untouched,
    interpolates the remaining rows as a square 2D grid to
    ``model.visual.grid_size`` and stores the result back under the same key.
    The function returns ``None``; ``state_dict`` is modified in place.

    Nothing is changed when the key is missing, when ``model.visual`` has no
    ``grid_size`` attribute, or when the sequence length already matches.

    Args:
        state_dict: Checkpoint mapping; updated in place.
        model: Object exposing ``model.visual.grid_size`` (an int or a pair).
        interpolation: ``mode`` forwarded to ``torch.nn.functional.interpolate``.
        seq_dim: Accepted for interface compatibility; not read by this
            function.

    Raises:
        RuntimeError: If the number of non-class-token rows in the checkpoint
            is zero or not a perfect square, so no square grid can be formed.
    """
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    # FIXME detect different token configs (ie no class token, or more)
    extra_tokens = 1
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    # Plain slicing also covers extra_tokens == 0: the token part is then an
    # empty tensor and the final concatenation is a no-op for it.
    pos_emb_tok = old_pos_embed[:extra_tokens]
    pos_emb_img = old_pos_embed[extra_tokens:]

    # Exact integer square root; a truncated float sqrt would let a non-square
    # grid through and fail later with an opaque reshape error.
    num_img_tokens = pos_emb_img.shape[0]
    old_side = math.isqrt(num_img_tokens)
    if old_side == 0 or old_side * old_side != num_img_tokens:
        raise RuntimeError(
            'Cannot resize position embedding: %d grid tokens do not form a '
            'square grid' % num_img_tokens
        )
    old_grid_size = to_2tuple(old_side)

    logging.info('Resizing position embedding grid-size from %s to %s',
                 old_grid_size, grid_size)

    # F.interpolate rejects align_corners for non-interpolating modes
    # ('nearest', 'area', ...), so only pass it where it is accepted.
    interp_kwargs = {}
    if interpolation in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        interp_kwargs['align_corners'] = True

    # Interpolate half-precision embeddings in float32 and cast back, so the
    # resampling itself is not carried out in reduced precision.
    orig_dtype = pos_emb_img.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        pos_emb_img = pos_emb_img.float()

    # (N, D) -> (1, D, H, W), the layout F.interpolate expects for 2D resizing.
    pos_emb_img = pos_emb_img.reshape(
        1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        **interp_kwargs,
    )
    # (1, D, H', W') -> (H' * W', D)
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(
        grid_size[0] * grid_size[1], -1)
    pos_emb_img = pos_emb_img.to(orig_dtype)

    state_dict['visual.positional_embedding'] = torch.cat(
        [pos_emb_tok, pos_emb_img], dim=0)
| 1297 | |
| 1298 | |
| 1299 | @torch.no_grad() |