github.com/hustvl/Vim

Function load_checkpoint

seg/mmcv_custom/checkpoint.py:310–483

Load checkpoint from a file or URI.

load_checkpoint(model, filename, map_location='cpu', strict=False, logger=None)
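The first part of the function reduces whatever `_load_checkpoint` returns to a flat state dict. Below is a minimal pure-Python sketch of that step, assuming plain values stand in for tensors. `extract_state_dict` is a hypothetical name. The wrapper-key priority (`state_dict`, then `model`, then `module`), the `module.` prefix strip and the MoBY `encoder.` filter mirror the listing.

```python
from typing import Any, Dict


def extract_state_dict(checkpoint: Any) -> Dict[str, Any]:
    """Hypothetical helper: reduce a loaded checkpoint to a flat state dict.

    Mirrors the listing's key handling; plain values stand in for tensors.
    """
    # OrderedDict is a subclass of dict, so it passes this check too.
    if not isinstance(checkpoint, dict):
        raise RuntimeError('No state_dict found in checkpoint')

    # Take the first wrapper key present; otherwise the checkpoint is the state dict.
    for key in ('state_dict', 'model', 'module'):
        if key in checkpoint:
            state_dict = dict(checkpoint[key])
            break
    else:
        state_dict = dict(checkpoint)

    # (Distributed)DataParallel saves keys as 'module.<name>'. Like the listing,
    # only the first key is inspected before the prefix is cut from every key.
    if next(iter(state_dict)).startswith('module.'):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}

    # MoBY: if the alphabetically first key starts with 'encoder', keep only the
    # online-branch 'encoder.' weights and remove 'encoder.' via str.replace,
    # as the listing does.
    if sorted(state_dict)[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v
                      for k, v in state_dict.items() if k.startswith('encoder.')}
    return state_dict
```

The sketch copies the selected mapping with `dict(...)`, so the caller's checkpoint is left untouched. Inspecting only the first key (or the first sorted key) is cheap, but it assumes the checkpoint's keys are uniformly prefixed.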

```python
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'module' in checkpoint:
        state_dict = checkpoint['module']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    # reshape absolute position embedding for Swin
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H*W:
            # note: logger defaults to None, so this call raises
            # AttributeError unless a logger is passed in
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    rank, _ = get_dist_info()
    if "rel_pos_bias.relative_position_bias_table" in state_dict:
        if rank == 0:
            print("Expand the shared relative position embedding to each layers. ")
        num_layers = model.get_num_layers()
        rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"]
    # ... excerpt ends at line 367; the function continues to line 483 of checkpoint.py
```
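The `absolute_pos_embed` branch converts a checkpoint embedding stored per token, shape (N, L, C), into the (N, C, H, W) layout the model expects, via `view(N2, H, W, C2).permute(0, 3, 1, 2)`. To make that index mapping explicit, here is a dependency-free sketch over nested lists. `tokens_to_grid` is a hypothetical helper, and it assumes tokens are stored in row-major order, which is what the `view` call implies.

```python
from typing import List

Tensor3 = List[List[List[float]]]        # (N, L, C)
Tensor4 = List[List[List[List[float]]]]  # (N, C, H, W)


def tokens_to_grid(pos_embed: Tensor3, H: int, W: int) -> Tensor4:
    """Hypothetical helper equivalent to pos_embed.view(N, H, W, C).permute(0, 3, 1, 2)."""
    N = len(pos_embed)
    L = len(pos_embed[0])
    C = len(pos_embed[0][0])
    if L != H * W:
        raise ValueError(f'L={L} does not match H*W={H * W}')
    # Row-major tokens: token t sits at row t // W, column t % W,
    # so grid[n][c][h][w] = pos_embed[n][h * W + w][c].
    return [[[[pos_embed[n][h * W + w][c] for w in range(W)]
              for h in range(H)]
             for c in range(C)]
            for n in range(N)]


# Usage: one embedding (N=1) with 6 tokens on a 2x3 grid and 2 channels.
x = [[[t, 10 + t] for t in range(6)]]
grid = tokens_to_grid(x, 2, 3)
```

The listing logs a warning and keeps the checkpoint entry unchanged when the shapes disagree. This sketch raises `ValueError` instead, and it checks only L against H*W, not N and C against the model's embedding.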

Callers (1)

main · function · 0.85

Calls (13)

_load_checkpoint · function · 0.85
print · function · 0.85
geometric_progression · function · 0.85
load_state_dict · function · 0.85
f · function · 0.50
get · method · 0.45
get_num_layers · method · 0.45
clone · method · 0.45
state_dict · method · 0.45
interp2d · method · 0.45
to · method · 0.45
cat · method · 0.45
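The excerpt stops right after reading the shared `rel_pos_bias.relative_position_bias_table`. Its log message ("Expand the shared relative position embedding to each layers") and the `get_num_layers` and `clone` calls listed above suggest that the shared table is copied once per layer. The following is a minimal sketch of that idea, not the file's actual continuation. The per-layer key template is hypothetical, `copy.deepcopy` stands in for `Tensor.clone()`, and removing the shared entry afterwards is also an assumption.

```python
import copy
from typing import Any, Dict

SHARED_KEY = 'rel_pos_bias.relative_position_bias_table'


def expand_shared_rel_pos_bias(
        state_dict: Dict[str, Any],
        num_layers: int,
        # hypothetical naming template, not taken from the listing
        per_layer_key: str = 'blocks.{}.attn.relative_position_bias_table',
) -> Dict[str, Any]:
    """Hypothetical helper: give every layer its own copy of a shared bias table."""
    if SHARED_KEY not in state_dict:
        return state_dict
    # Assumption: the shared entry is dropped once it has been expanded.
    shared = state_dict.pop(SHARED_KEY)
    for i in range(num_layers):
        # deepcopy plays the role of tensor.clone(): each layer gets an
        # independent copy that can be fine-tuned separately.
        state_dict[per_layer_key.format(i)] = copy.deepcopy(shared)
    return state_dict
```

The function mutates and returns the same dict, matching how the listing updates `state_dict` in place. In the real function, `num_layers` comes from `model.get_num_layers()`.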

Tested by (1)

main · function · 0.68