| 453 | |
| 454 | |
def save_model(model: nn.Module,
               optimizer: Union[nn.Module, None] = None,
               scheduler: Union[nn.Module, None] = None,
               moderator: Union[nn.Module, None] = None,
               model_dir: str = '',
               epoch: int = -1,
               latest: bool = False,
               save_lim: int = 5,
               ):
    """Save a training checkpoint and prune old numbered checkpoints.

    The checkpoint is a dict holding the model ``state_dict`` under
    ``'model'``, the epoch under ``'epoch'`` and, for every optional component
    that is not ``None``, its ``state_dict`` under ``'optimizer'``,
    ``'scheduler'`` or ``'moderator'``. It is written with ``torch.save`` to
    ``<model_dir>/latest.pt`` when ``latest`` is truthy and to
    ``<model_dir>/<epoch>.pt`` otherwise.

    After saving, files in ``model_dir`` named ``<decimal digits>.pt`` are
    treated as numbered checkpoints. When there are more than ``save_lim`` of
    them, the ones with the smallest numbers are deleted until only
    ``save_lim`` remain. ``latest.pt`` is never pruned.

    Args:
        model: Module to save. A ``DDP`` wrapper is unwrapped through
            ``.module`` so the stored keys do not carry the wrapper prefix.
        optimizer: Optional object exposing ``state_dict()``.
        scheduler: Optional object exposing ``state_dict()``.
        moderator: Optional object exposing ``state_dict()``.
        model_dir: Target directory, created if missing. An empty string
            means the current working directory.
        epoch: Epoch stored in the checkpoint and used for the file name.
        latest: Write ``latest.pt`` instead of ``<epoch>.pt``.
        save_lim: Maximum number of numbered checkpoints to keep; negative
            values are treated as ``0``.
    """
    # Special handling for ddp modules (incorrect naming)
    network = model.module if isinstance(model, DDP) else model
    checkpoint = {
        'model': network.state_dict(),
        'epoch': epoch,
    }

    # Only state_dict() is called on these, so any object providing it works.
    optional_parts = (
        ('optimizer', optimizer),
        ('scheduler', scheduler),
        ('moderator', moderator),
    )
    for key, part in optional_parts:
        if part is not None:
            checkpoint[key] = part.state_dict()

    # os.makedirs('') raises, and an empty model_dir already means "cwd" for join().
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    model_path = join(model_dir, 'latest.pt' if latest else f'{epoch}.pt')
    torch.save(checkpoint, model_path)
    log(yellow(f'Saved model {blue(model_path)} at epoch {blue(epoch)}'))

    ext = '.pt'
    # Keep the real file name next to its parsed number: rebuilding the name
    # from the int would miss files such as '007.pt'. The whole stem must be
    # decimal, which excludes 'latest.pt' and names like '3.best.pt', and
    # guarantees that int() succeeds.
    numbered = sorted(
        (int(name[:-len(ext)]), name)
        for name in os.listdir(model_dir or '.')
        if name.endswith(ext) and name[:-len(ext)].isdecimal()
    )

    excess = len(numbered) - max(save_lim, 0)
    for _, name in numbered[:max(excess, 0)]:
        # log(red(f"Removing trained weights: {blue(removing)}"))
        os.remove(join(model_dir, name))
| 497 | |
| 498 | |
| 499 | def root_of_any(k, l): |