Load checkpoint from a file or URI. Args: model (Module): Module to load checkpoint. filename (str): Accept local filepath, URL, ``torchvision://xxx``, ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for details. map_location (str): Same as :func:`torch.load`.
(model,
filename,
map_location='cpu',
strict=False,
logger=None)
| 308 | |
| 309 | |
| 310 | def load_checkpoint(model, |
| 311 | filename, |
| 312 | map_location='cpu', |
| 313 | strict=False, |
| 314 | logger=None): |
| 315 | """Load checkpoint from a file or URI. |
| 316 | |
| 317 | Args: |
| 318 | model (Module): Module to load checkpoint. |
| 319 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, |
| 320 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for |
| 321 | details. |
| 322 | map_location (str): Same as :func:`torch.load`. |
| 323 | strict (bool): Whether to allow different params for the model and |
| 324 | checkpoint. |
| 325 | logger (:mod:`logging.Logger` or None): The logger for error message. |
| 326 | |
| 327 | Returns: |
| 328 | dict or OrderedDict: The loaded checkpoint. |
| 329 | """ |
| 330 | checkpoint = _load_checkpoint(filename, map_location) |
| 331 | # OrderedDict is a subclass of dict |
| 332 | if not isinstance(checkpoint, dict): |
| 333 | raise RuntimeError( |
| 334 | f'No state_dict found in checkpoint file {filename}') |
| 335 | # get state_dict from checkpoint |
| 336 | if 'state_dict' in checkpoint: |
| 337 | state_dict = checkpoint['state_dict'] |
| 338 | elif 'model' in checkpoint: |
| 339 | state_dict = checkpoint['model'] |
| 340 | elif 'module' in checkpoint: |
| 341 | state_dict = checkpoint['module'] |
| 342 | else: |
| 343 | state_dict = checkpoint |
| 344 | # strip prefix of state_dict |
| 345 | if list(state_dict.keys())[0].startswith('module.'): |
| 346 | state_dict = {k[7:]: v for k, v in state_dict.items()} |
| 347 | |
| 348 | # for MoBY, load model of online branch |
| 349 | if sorted(list(state_dict.keys()))[0].startswith('encoder'): |
| 350 | state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} |
| 351 | |
| 352 | # reshape absolute position embedding for Swin |
| 353 | if state_dict.get('absolute_pos_embed') is not None: |
| 354 | absolute_pos_embed = state_dict['absolute_pos_embed'] |
| 355 | N1, L, C1 = absolute_pos_embed.size() |
| 356 | N2, C2, H, W = model.absolute_pos_embed.size() |
| 357 | if N1 != N2 or C1 != C2 or L != H*W: |
| 358 | logger.warning("Error in loading absolute_pos_embed, pass") |
| 359 | else: |
| 360 | state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) |
| 361 | |
| 362 | rank, _ = get_dist_info() |
| 363 | if "rel_pos_bias.relative_position_bias_table" in state_dict: |
| 364 | if rank == 0: |
| 365 | print("Expand the shared relative position embedding to each layers. ") |
| 366 | num_layers = model.get_num_layers() |
| 367 | rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"] |