MCPcopy
hub / github.com/zju3dv/4K4D / save_model

Function save_model

easyvolcap/utils/net_utils.py:455–496  ·  view source on GitHub ↗
(model: nn.Module,
               optimizer: Union[nn.Module, None] = None,
               scheduler: Union[nn.Module, None] = None,
               moderator: Union[nn.Module, None] = None,
               model_dir: str = '',
               epoch: int = -1,
               latest: int = False,
               save_lim: int = 5,
               )

Source from the content-addressed store, hash-verified

453
454
def save_model(model: nn.Module,
               optimizer: Union[nn.Module, None] = None,
               scheduler: Union[nn.Module, None] = None,
               moderator: Union[nn.Module, None] = None,
               model_dir: str = '',
               epoch: int = -1,
               latest: bool = False,
               save_lim: int = 5,
               ):
    """Serialize a training checkpoint into ``model_dir`` and prune old ones.

    The checkpoint is a dict with the network weights under ``'model'`` and
    the epoch under ``'epoch'``. The ``'optimizer'``, ``'scheduler'`` and
    ``'moderator'`` entries are added only for the objects that are passed.

    Args:
        model: Network to save. A ``DDP`` wrapper is unwrapped through
            ``.module`` so the saved names match the bare network.
        optimizer, scheduler, moderator: Optional objects exposing
            ``state_dict()``. Passing ``None`` skips the entry.
        model_dir: Output directory, created if missing. ``''`` means the
            current working directory.
        epoch: Epoch recorded in the checkpoint. It is also used as the file
            name (``<epoch>.pt``) unless ``latest`` is set.
        latest: Write to ``latest.pt`` instead of ``<epoch>.pt``.
        save_lim: Maximum number of numbered ``<epoch>.pt`` files to keep.
            The files with the smallest epochs are deleted first.
            ``latest.pt`` is never pruned.
    """
    # Unwrap DDP: its state_dict() prefixes every key with 'module.'.
    network = model.module if isinstance(model, DDP) else model
    checkpoint = {
        'model': network.state_dict(),
        'epoch': epoch,
    }
    for key, component in (('optimizer', optimizer),
                           ('scheduler', scheduler),
                           ('moderator', moderator)):
        if component is not None:
            checkpoint[key] = component.state_dict()

    # os.makedirs('') raises, so only create a directory when one is given.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    ext = '.pt'
    model_path = join(model_dir, f'latest{ext}' if latest else f'{epoch}{ext}')
    # Write to a temporary file first, then atomically swap it into place.
    # This way an interrupted save cannot leave a truncated checkpoint behind.
    # That matters most for latest.pt, which is overwritten on every call.
    # The '.tmp' suffix also keeps the file out of the pruning scan below.
    tmp_path = f'{model_path}.tmp'
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, model_path)
    log(yellow(f'Saved model {blue(model_path)} at epoch {blue(epoch)}'))

    # Collect numbered checkpoints as (epoch, file name).
    # - The real file name is kept, so e.g. '007.pt' is removed as-is
    #   rather than being rebuilt as '7.pt'.
    # - Names whose stem is not a plain decimal integer are skipped
    #   (e.g. 'latest.pt', '1.5.pt').
    numbered = []
    for name in os.listdir(model_dir or os.curdir):
        stem, suffix = os.path.splitext(name)
        if suffix == ext and stem.isdecimal():
            numbered.append((int(stem), name))

    # Delete the oldest (smallest-epoch) files until at most save_lim remain.
    # Removing all of the excess also catches up after save_lim was lowered.
    numbered.sort()
    for _, name in numbered[:max(len(numbered) - save_lim, 0)]:
        os.remove(join(model_dir, name))
497
498
499def root_of_any(k, l):

Callers 2

save_network (method) · 0.90
save_model (method) · 0.90

Calls 8

yellow (function) · 0.85
blue (function) · 0.85
save (method) · 0.80
split (method) · 0.80
log (function) · 0.70
state_dict (method) · 0.45
exists (method) · 0.45
remove (method) · 0.45

Tested by

no test coverage detected