MCPcopy
hub / github.com/zju3dv/4K4D / load_network

Function load_network

easyvolcap/utils/net_utils.py:375–437  ·  view source on GitHub ↗
(
    model: nn.Module,
    model_dir: str = '',
    resume: bool = True,  # when resume is False, will try as a fresh restart
    epoch: int = -1,
    strict: bool = True,  # report errors if something is wrong
    skips: List[str] = [],
    only: List[str] = [],
    prefix: str = '',  # will match and remove these prefix
    allow_mismatch: List[str] = [],
)

Source from the content-addressed store, hash-verified

373
374
375def load_network(
376 model: nn.Module,
377 model_dir: str = '',
378 resume: bool = True, # when resume is False, will try as a fresh restart
379 epoch: int = -1,
380 strict: bool = True, # report errors if something is wrong
381 skips: List[str] = [],
382 only: List[str] = [],
383 prefix: str = '', # will match and remove these prefix
384 allow_mismatch: List[str] = [],
385):
386 pretrained, model_path = load_pretrained(model_dir, resume, epoch,
387 remove_if_not_resuming=False,
388 warn_if_not_exist=False)
389 if pretrained is None:
390 pretrained, model_path = load_pretrained(model_dir, resume, epoch, '.pth',
391 remove_if_not_resuming=False,
392 warn_if_not_exist=False)
393 if pretrained is None:
394 pretrained, model_path = load_pretrained(model_dir, resume, epoch, '.pt',
395 remove_if_not_resuming=False,
396 warn_if_not_exist=resume)
397 if pretrained is None:
398 return 0
399
400 # log(f'Loading network: {blue(model_path)}')
401 # ordered dict cannot be mutated while iterating
402 # vanilla dict cannot change size while iterating
403 pretrained_model = pretrained['model']
404
405 if skips:
406 keys = list(pretrained_model.keys())
407 for k in keys:
408 if root_of_any(k, skips):
409 del pretrained_model[k]
410
411 if only:
412 keys = list(pretrained_model.keys()) # since the dict has been mutated, some keys might not exist
413 for k in keys:
414 if not root_of_any(k, only):
415 del pretrained_model[k]
416
417 if prefix:
418 keys = list(pretrained_model.keys()) # since the dict has been mutated, some keys might not exist
419 for k in keys:
420 if k.startswith(prefix):
421 pretrained_model[k[len(prefix):]] = pretrained_model[k]
422 del pretrained_model[k]
423
424 for key in allow_mismatch:
425 if key in model.state_dict() and key in pretrained_model and not strict:
426 model_parent = model
427 pretrained_parent = pretrained_model
428 chain = key.split('.')
429 for k in chain[:-1]: # except last one
430 model_parent = getattr(model_parent, k)
431 pretrained_parent = pretrained_parent[k]
432 last_name = chain[-1]

Callers 5

load_network (Method) · 0.90
__init__ (Method) · 0.90
__init__ (Method) · 0.90
main (Function) · 0.90
load_pcd (Function) · 0.90

Calls 7

load_pretrained (Function) · 0.85
root_of_any (Function) · 0.85
blue (Function) · 0.85
split (Method) · 0.80
log (Function) · 0.70
state_dict (Method) · 0.45
load_state_dict (Method) · 0.45

Tested by

no test coverage detected