hub / github.com/hustvl/Vim / interpolate_pos_embed

Function interpolate_pos_embed

vim/utils.py:242–263
(model, state_dict)

Source from the content-addressed store, hash-verified

# if 'pos_embed' in state_dict:
def interpolate_pos_embed(model, state_dict):
    pos_embed_checkpoint = state_dict['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # import ipdb; ipdb.set_trace()
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)
    # class_token and dist_token are kept unchanged
    if orig_size != new_size:
        print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        state_dict['pos_embed'] = new_pos_embed
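
The size arithmetic at the top of the function can be checked by hand. The sketch below is not part of the repository: `grid_sizes` is a hypothetical helper, and the token counts in the usage line are assumed values for a checkpoint trained at 224x224 and a model built for 384x384, both with 16x16 patches and one extra token. It uses `math.isqrt` with an explicit check because `int(x ** 0.5)` truncates silently when a patch grid is not a perfect square.

```python
import math


def grid_sizes(ckpt_tokens: int, model_tokens: int, num_patches: int):
    """Return (num_extra_tokens, orig_size, new_size) for square patch grids.

    ckpt_tokens  -- stands in for pos_embed_checkpoint.shape[-2]
    model_tokens -- stands in for model.pos_embed.shape[-2]
    num_patches  -- stands in for model.patch_embed.num_patches
    """
    # Extra tokens (e.g. class token) are whatever the model has beyond patches.
    num_extra_tokens = model_tokens - num_patches
    ckpt_patches = ckpt_tokens - num_extra_tokens
    orig_size = math.isqrt(ckpt_patches)
    new_size = math.isqrt(num_patches)
    # Unlike int(x ** 0.5), fail loudly when a grid is not a perfect square.
    if orig_size * orig_size != ckpt_patches or new_size * new_size != num_patches:
        raise ValueError("patch grid is not square")
    return num_extra_tokens, orig_size, new_size


# Assumed shapes: 196 + 1 tokens in the checkpoint, 576 + 1 in the model.
extra, orig, new = grid_sizes(ckpt_tokens=197, model_tokens=577, num_patches=576)
print(extra, orig, new)  # 1 14 24
```

With these assumed shapes the function would interpolate the 14x14 grid of position tokens up to 24x24 and leave the single extra token unchanged.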

Callers 2

init_weights · Method · 0.90
init_weights · Method · 0.90

Calls 3

print · Function · 0.85
flatten · Method · 0.45
cat · Method · 0.45

Tested by

no test coverage detected