Load checkpoint from a file or URI. Args: model (Module): Module to load checkpoint. filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str): Same as :func:`torch.load`.
(model,
filename,
map_location='cpu',
strict=False,
logger=None)
| 284 | |
| 285 | |
| 286 | def load_checkpoint(model, |
| 287 | filename, |
| 288 | map_location='cpu', |
| 289 | strict=False, |
| 290 | logger=None): |
| 291 | """Load checkpoint from a file or URI. |
| 292 | |
| 293 | Args: |
| 294 | model (Module): Module to load checkpoint. |
| 295 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, |
| 296 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for |
| 297 | details. |
| 298 | map_location (str): Same as :func:`torch.load`. |
| 299 | strict (bool): Whether to allow different params for the model and |
| 300 | checkpoint. |
| 301 | logger (:mod:`logging.Logger` or None): The logger for error message. |
| 302 | |
| 303 | Returns: |
| 304 | dict or OrderedDict: The loaded checkpoint. |
| 305 | """ |
| 306 | checkpoint = _load_checkpoint(filename, map_location) |
| 307 | # OrderedDict is a subclass of dict |
| 308 | if not isinstance(checkpoint, dict): |
| 309 | raise RuntimeError( |
| 310 | f'No state_dict found in checkpoint file {filename}') |
| 311 | # get state_dict from checkpoint |
| 312 | if 'state_dict' in checkpoint: |
| 313 | state_dict = checkpoint['state_dict'] |
| 314 | elif 'model' in checkpoint: |
| 315 | state_dict = checkpoint['model'] |
| 316 | else: |
| 317 | state_dict = checkpoint |
| 318 | # strip prefix of state_dict |
| 319 | if list(state_dict.keys())[0].startswith('module.'): |
| 320 | state_dict = {k[7:]: v for k, v in state_dict.items()} |
| 321 | |
| 322 | # for MoBY, load model of online branch |
| 323 | if sorted(list(state_dict.keys()))[0].startswith('encoder'): |
| 324 | state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} |
| 325 | |
| 326 | # reshape absolute position embedding |
| 327 | if state_dict.get('absolute_pos_embed') is not None: |
| 328 | absolute_pos_embed = state_dict['absolute_pos_embed'] |
| 329 | N1, L, C1 = absolute_pos_embed.size() |
| 330 | N2, C2, H, W = model.absolute_pos_embed.size() |
| 331 | if N1 != N2 or C1 != C2 or L != H*W: |
| 332 | logger.warning("Error in loading absolute_pos_embed, pass") |
| 333 | else: |
| 334 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) |
| 335 | |
| 336 | # interpolate position bias table if needed |
| 337 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] |
| 338 | for table_key in relative_position_bias_table_keys: |
| 339 | table_pretrained = state_dict[table_key] |
| 340 | table_current = model.state_dict()[table_key] |
| 341 | L1, nH1 = table_pretrained.size() |
| 342 | L2, nH2 = table_current.size() |
| 343 | if nH1 != nH2: |
nothing calls this directly
no test coverage detected