(
model: nn.Module,
model_dir: str = '',
resume: bool = True, # when resume is False, will try as a fresh restart
epoch: int = -1,
strict: bool = True, # report errors if something is wrong
skips: List[str] = [],
only: List[str] = [],
prefix: str = '', # will match and remove these prefix
allow_mismatch: List[str] = [],
)
| 373 | |
| 374 | |
| 375 | def load_network( |
| 376 | model: nn.Module, |
| 377 | model_dir: str = '', |
| 378 | resume: bool = True, # when resume is False, will try as a fresh restart |
| 379 | epoch: int = -1, |
| 380 | strict: bool = True, # report errors if something is wrong |
| 381 | skips: List[str] = [], |
| 382 | only: List[str] = [], |
| 383 | prefix: str = '', # will match and remove these prefix |
| 384 | allow_mismatch: List[str] = [], |
| 385 | ): |
| 386 | pretrained, model_path = load_pretrained(model_dir, resume, epoch, |
| 387 | remove_if_not_resuming=False, |
| 388 | warn_if_not_exist=False) |
| 389 | if pretrained is None: |
| 390 | pretrained, model_path = load_pretrained(model_dir, resume, epoch, '.pth', |
| 391 | remove_if_not_resuming=False, |
| 392 | warn_if_not_exist=False) |
| 393 | if pretrained is None: |
| 394 | pretrained, model_path = load_pretrained(model_dir, resume, epoch, '.pt', |
| 395 | remove_if_not_resuming=False, |
| 396 | warn_if_not_exist=resume) |
| 397 | if pretrained is None: |
| 398 | return 0 |
| 399 | |
| 400 | # log(f'Loading network: {blue(model_path)}') |
| 401 | # ordered dict cannot be mutated while iterating |
| 402 | # vanilla dict cannot change size while iterating |
| 403 | pretrained_model = pretrained['model'] |
| 404 | |
| 405 | if skips: |
| 406 | keys = list(pretrained_model.keys()) |
| 407 | for k in keys: |
| 408 | if root_of_any(k, skips): |
| 409 | del pretrained_model[k] |
| 410 | |
| 411 | if only: |
| 412 | keys = list(pretrained_model.keys()) # since the dict has been mutated, some keys might not exist |
| 413 | for k in keys: |
| 414 | if not root_of_any(k, only): |
| 415 | del pretrained_model[k] |
| 416 | |
| 417 | if prefix: |
| 418 | keys = list(pretrained_model.keys()) # since the dict has been mutated, some keys might not exist |
| 419 | for k in keys: |
| 420 | if k.startswith(prefix): |
| 421 | pretrained_model[k[len(prefix):]] = pretrained_model[k] |
| 422 | del pretrained_model[k] |
| 423 | |
| 424 | for key in allow_mismatch: |
| 425 | if key in model.state_dict() and key in pretrained_model and not strict: |
| 426 | model_parent = model |
| 427 | pretrained_parent = pretrained_model |
| 428 | chain = key.split('.') |
| 429 | for k in chain[:-1]: # except last one |
| 430 | model_parent = getattr(model_parent, k) |
| 431 | pretrained_parent = pretrained_parent[k] |
| 432 | last_name = chain[-1] |
no test coverage detected