Repository: github.com/microsoft/Cream

Function load_model

EfficientViT/classification/utils.py:249–285

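Loads a checkpoint from modelpath and adapts its attention parameters to model so the weights can be fine-tuned at a different input resolution. It deletes the attention_bias_idxs buffers, bicubically resizes every attention_biases table whose length differs from the model's, and returns the modified checkpoint dict without loading it into model.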

(modelpath, model)
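Two details of the listing depend on how the model lays out its attention biases. The sketch below assumes a LeViT-style layout (an assumption, not stated on this page): each head learns one bias per unique absolute (dx, dy) offset on a resolution x resolution grid, and attention_bias_idxs maps every query/key pair to one of those biases. Under that assumption each head's bias table has resolution² entries, which is why the code can recover the grid side as the square root of the table length, while the index buffer has (resolution²)² entries and cannot be reused at a new resolution. build_attention_offsets is a hypothetical helper.

```python
import itertools


def build_attention_offsets(resolution):
    """Hypothetical helper mimicking a LeViT-style attention-bias layout.

    Returns (num_offsets, idxs): num_offsets is the length of one head's
    bias table; idxs is the flattened N x N index map, N = resolution**2.
    """
    points = list(itertools.product(range(resolution), range(resolution)))
    offsets = {}
    idxs = []
    for p1 in points:
        for p2 in points:
            # One shared bias per absolute displacement between two positions.
            offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
            if offset not in offsets:
                offsets[offset] = len(offsets)
            idxs.append(offsets[offset])
    return len(offsets), idxs


for res in (7, 14):
    num_offsets, idxs = build_attention_offsets(res)
    n = res * res
    # prints "7 49 True", then "14 196 True"
    print(res, num_offsets, len(idxs) == n * n)
```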


import logging

import torch

# Assumption: a module-level logger. `logger` is not defined anywhere in
# this listing, so the head-mismatch branch would otherwise raise NameError.
logger = logging.getLogger(__name__)


def load_model(modelpath, model):
    '''
    Load a checkpoint and adapt its attention biases to `model` for
    fine-tuning at a different resolution. Returns the checkpoint dict;
    because the attention_bias_idxs entries are removed, load
    checkpoint['model'] with strict=False if the model stores them.
    '''
    checkpoint = torch.load(modelpath, map_location='cpu')
    state_dict = checkpoint['model']
    model_state_dict = model.state_dict()

    # Drop the resolution-dependent index buffers; the target model
    # already holds its own attention_bias_idxs for the new resolution.
    rpe_idx_keys = [
        k for k in state_dict.keys() if "attention_bias_idxs" in k]
    for k in rpe_idx_keys:
        print("deleting key: ", k)
        del state_dict[k]

    # Bicubic-interpolate attention_biases whose length does not match.
    relative_position_bias_table_keys = [
        k for k in state_dict.keys() if "attention_biases" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model_state_dict[k]
        nH1, L1 = relative_position_bias_table_pretrained.size()
        nH2, L2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            # The tensor is left unchanged, so load_state_dict will still
            # report a size mismatch for this key.
            logger.warning(f"Error in loading {k} due to different number of heads")
        elif L1 != L2:
            # Treat each head's table as an S x S grid and resize it.
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                relative_position_bias_table_pretrained.view(1, nH1, S1, S1), size=(S2, S2),
                mode='bicubic')
            state_dict[k] = relative_position_bias_table_pretrained_resized.view(
                nH2, L2)
    checkpoint['model'] = state_dict
    return checkpoint
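Because load_model reports its decisions only through print and the log, it can help to preview them before touching real tensors. The sketch below is a hypothetical, torch-free dry run over shape tuples: it applies the same rules in the same order (drop attention_bias_idxs, skip attention_biases whose head count differs, resize when the length differs). As a stricter alternative to int(L ** 0.5), it uses math.isqrt to reject table lengths that are not perfect squares instead of silently truncating them. plan_bias_resize, grid_side, and the example key names and shapes are illustrative assumptions, not part of the repository.

```python
import math


def grid_side(length):
    """Return the side S of a square S x S bias grid; reject non-squares."""
    side = math.isqrt(length)
    if side * side != length:
        raise ValueError(f"bias table length {length} is not a perfect square")
    return side


def plan_bias_resize(ckpt_shapes, model_shapes):
    """Hypothetical dry run of load_model's key handling on shape tuples.

    Both arguments map parameter names to shapes. Returns a list of
    (action, key, detail) tuples in the order load_model would act.
    """
    plan = []
    # Pass 1: index buffers are always dropped.
    for k in ckpt_shapes:
        if "attention_bias_idxs" in k:
            plan.append(("delete", k, None))
    # Pass 2: bias tables are compared against the target model.
    for k, shape in ckpt_shapes.items():
        if "attention_biases" not in k:
            continue
        nh1, l1 = shape
        nh2, l2 = model_shapes[k]
        if nh1 != nh2:
            # load_model only logs a warning and keeps the old tensor.
            plan.append(("skip", k, "head count differs"))
        elif l1 != l2:
            plan.append(("resize", k, (grid_side(l1), grid_side(l2))))
    return plan


# Illustrative shapes only: 4 heads, 7x7 offset grid -> 14x14 offset grid.
example_ckpt = {
    "blocks.0.attn.attention_biases": (4, 49),
    "blocks.0.attn.attention_bias_idxs": (49, 49),
    "head.weight": (1000, 192),
}
example_model = {
    "blocks.0.attn.attention_biases": (4, 196),
    "blocks.0.attn.attention_bias_idxs": (196, 196),
    "head.weight": (1000, 192),
}
# Prints the delete step, then ('resize', 'blocks.0.attn.attention_biases', (7, 14)).
for step in plan_bias_resize(example_ckpt, example_model):
    print(step)
```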

Callers

nothing calls this directly

Calls 3

print · function · 0.70
state_dict · method · 0.45
size · method · 0.45

Tested by

no test coverage detected