MCP · Copy · Index your code
hub / github.com/microsoft/Cream / resize_pos_embed

Function resize_pos_embed

TinyCLIP/src/open_clip/model.py:1261–1296  ·  view source on GitHub ↗
(state_dict, model, interpolation: str = 'bicubic', seq_dim=1)

Source from the content-addressed store, hash-verified

1259
1260
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
    """Resize the visual positional-embedding grid in ``state_dict`` to fit ``model``.

    If the checkpoint's ``'visual.positional_embedding'`` has a different
    number of tokens than ``model.visual.grid_size`` implies, the image part
    of the embedding is reshaped to its 2-D grid, resampled with
    ``torch.nn.functional.interpolate`` and flattened back. The leading
    extra token(s) are copied unchanged. ``state_dict`` is modified in place;
    nothing is returned.

    Args:
        state_dict: Checkpoint state dict; updated in place.
        model: Target model. Nothing happens unless ``model.visual`` has a
            ``grid_size`` attribute.
        interpolation: ``mode`` forwarded to ``F.interpolate`` (e.g.
            ``'bicubic'``, ``'bilinear'``, ``'nearest'``).
        seq_dim: Unused; kept only for backward compatibility of the signature.

    Raises:
        ValueError: If the checkpoint's image tokens do not form a square grid.
    """
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    # FIXME detect different token configs (ie no class token, or more)
    extra_tokens = 1
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed

    # The old grid shape is not stored in the checkpoint, so it is assumed to
    # be square. Fail loudly instead of letting reshape pick a wrong width.
    num_img_tokens = len(pos_emb_img)
    old_side = int(math.sqrt(num_img_tokens))
    if old_side * old_side != num_img_tokens:
        raise ValueError(
            f'Cannot resize positional embedding: {num_img_tokens} image tokens '
            f'do not form a square grid')
    old_grid_size = to_2tuple(old_side)

    logging.info('Resizing position embedding grid-size from %s to %s',
                 old_grid_size, grid_size)

    orig_dtype = pos_emb_img.dtype
    # Resample in float32: interpolate kernels may lack reduced-precision
    # support. The result is cast back below.
    if orig_dtype in (torch.float16, torch.bfloat16):
        pos_emb_img = pos_emb_img.float()

    # (N, C) -> (1, C, H, W), the layout F.interpolate expects.
    pos_emb_img = pos_emb_img.reshape(
        1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    # align_corners may only be set for the linear-family modes; PyTorch
    # raises if it is given for e.g. 'nearest' or 'area'.
    if interpolation in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        align_corners = True
    else:
        align_corners = None
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        align_corners=align_corners,
    )
    # (1, C, H', W') -> (H' * W', C), restored to the checkpoint's dtype.
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(
        1, grid_size[0] * grid_size[1], -1)[0].to(orig_dtype)

    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed
1297
1298
1299@torch.no_grad()

Callers 1

load_checkpoint · Function · 0.90

Calls 1

get · Method · 0.45

Tested by

No test coverage detected.