Repository: github.com/microsoft/Cream

Function load_checkpoint

EfficientViT/downstream/mmcv_custom/checkpoint.py:286–356

Load checkpoint from a file or URI.

load_checkpoint(model, filename, map_location='cpu', strict=False, logger=None)
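The signature defaults `logger` to `None`, but the body calls `logger.warning(...)` directly when the pretrained `absolute_pos_embed` does not match the model, so a caller relying on the default gets an `AttributeError` instead of a warning. A minimal sketch of the failure and of a fallback, assuming standard-library `logging`; `resolve_logger`, `warn_with`, and the logger name are hypothetical and not part of checkpoint.py.

```python
import logging
from typing import Optional


def resolve_logger(logger: Optional[logging.Logger] = None) -> logging.Logger:
    """Hypothetical helper (not in checkpoint.py): never hand back None."""
    if logger is None:
        # Assumed logger name, chosen only for illustration.
        logger = logging.getLogger("mmcv_custom.checkpoint")
    return logger


def warn_with(logger) -> str:
    """Reproduce the bare `logger.warning(...)` call made by load_checkpoint."""
    try:
        logger.warning("Error in loading absolute_pos_embed, pass")
    except AttributeError as exc:  # raised when logger is None
        return type(exc).__name__
    return "warned"


print(warn_with(None))                  # AttributeError
print(warn_with(resolve_logger(None)))  # warned
```

In practice, passing any `logging.Logger` (for example `logging.getLogger(__name__)`) as `logger=` avoids the crash without a wrapper.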


```python
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    # (k[7:] drops the 7-character 'module.' prefix; only the first key is checked)
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    # reshape absolute position embedding: (N, L, C) -> (N, C, H, W)
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H*W:
            # NOTE: logger defaults to None, so this raises AttributeError
            # unless a logger is passed in.
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = model.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            # ... listing truncated here; checkpoint.py continues through line 356
```
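Before touching any tensors, `load_checkpoint` normalizes the checkpoint's keys in three steps: unwrap the weights from `'state_dict'` or `'model'` (falling back to the dict itself), strip a `'module.'` prefix, and, for MoBY checkpoints, keep only the `'encoder.'` (online-branch) entries. The sketch below re-expresses those steps on plain dicts so they can run without PyTorch; `normalize_state_dict` is a hypothetical name and the example keys are illustrative, not taken from a real checkpoint.

```python
from typing import Any, Dict


def normalize_state_dict(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical mirror of load_checkpoint's key handling (values stand in for tensors)."""
    if not isinstance(checkpoint, dict):
        raise RuntimeError("No state_dict found in checkpoint")
    # 1. Unwrap: prefer 'state_dict', then 'model', else use the dict itself.
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    elif "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    # 2. Strip 'module.' (as in the original, only the first key is inspected).
    if list(state_dict.keys())[0].startswith("module."):
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    # 3. MoBY: if the alphabetically first key starts with 'encoder', keep only
    #    'encoder.' keys and remove every 'encoder.' substring (str.replace,
    #    exactly like the original, not a strict prefix strip).
    if sorted(state_dict.keys())[0].startswith("encoder"):
        state_dict = {k.replace("encoder.", ""): v
                      for k, v in state_dict.items() if k.startswith("encoder.")}
    return state_dict


ckpt = {"model": {"module.patch_embed.proj.weight": 1, "module.head.bias": 2}}
print(normalize_state_dict(ckpt))  # {'patch_embed.proj.weight': 1, 'head.bias': 2}
```

Because step 3 only looks at the first sorted key, a checkpoint whose keys all begin with `encoder` but use a different separator would be filtered to an empty dict; the sketch keeps that behavior rather than "fixing" it.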

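The absolute position embedding step converts a pretrained tensor of shape (N, L, C) into the model's (N, C, H, W) layout, but only when N and C agree and L == H * W. A NumPy re-expression of the `view(N2, H, W, C2).permute(0, 3, 1, 2)` call, relying on the fact that both `torch.Tensor.view` on a contiguous tensor and `numpy.reshape` use row-major order; `reshape_abs_pos_embed` is a hypothetical name.

```python
import numpy as np


def reshape_abs_pos_embed(pretrained: np.ndarray, target_shape):
    """Hypothetical helper: (N, L, C) -> (N, C, H, W), or None on a shape mismatch."""
    n1, length, c1 = pretrained.shape
    n2, c2, h, w = target_shape
    if n1 != n2 or c1 != c2 or length != h * w:
        # load_checkpoint logs a warning here and leaves the entry unchanged.
        return None
    # Tokens are laid out row by row, so token l maps to grid cell (l // W, l % W).
    return pretrained.reshape(n2, h, w, c2).transpose(0, 3, 1, 2)


emb = np.arange(1 * 4 * 3).reshape(1, 4, 3)  # N=1, L=4 tokens, C=3 channels
out = reshape_abs_pos_embed(emb, (1, 3, 2, 2))
print(out.shape)        # (1, 3, 2, 2)
print(out[0, 2, 1, 0])  # 8: channel 2 of token 2, i.e. grid row 1, column 0
```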
Callers

Nothing calls this function directly.

Calls (5)

_load_checkpoint · function · 0.70
load_state_dict · function · 0.70
get · method · 0.45
size · method · 0.45
state_dict · method · 0.45
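Two of the listed calls, `state_dict` and `size`, drive the relative position bias check: for every `relative_position_bias_table` key, the loop compares the pretrained table's shape (L1, nH1) with the current model's (L2, nH2). Assuming Swin-style tables, which hold one row per relative offset inside a window and one column per attention head, a square window of size w gives (2w - 1)^2 rows. The listing stops at `if nH1 != nH2:`, so the "skip" and "resize" outcomes below are inferred from the comment `# interpolate position bias table if needed`, not copied from the source; both helper names are hypothetical.

```python
def bias_table_rows(window_size: int) -> int:
    """Rows in a Swin-style relative position bias table for a square window."""
    # Relative offsets along each axis range over -(w-1)..(w-1): 2w - 1 values.
    return (2 * window_size - 1) ** 2


def plan_bias_table(pretrained_shape, current_shape) -> str:
    """Hypothetical classifier for the (L, nH) shapes compared in load_checkpoint."""
    l1, nh1 = pretrained_shape
    l2, nh2 = current_shape
    if nh1 != nh2:
        return "skip"    # head counts differ: the table cannot be reused
    if l1 != l2:
        return "resize"  # same heads, different window size: needs interpolation
    return "copy"        # shapes already match


print(bias_table_rows(7), bias_table_rows(12))  # 169 529
print(plan_bias_table((169, 3), (529, 3)))      # resize
print(plan_bias_table((169, 3), (169, 4)))      # skip
```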

Tested by

No test coverage detected.