Extract the scorer/network name for a particular shuffle, training fraction, etc. If the engine is not specified, it is looked up from the shuffle's metadata (via get_shuffle_engine). kwargs: additional arguments. For torch-based shuffles, these can be used to specify: - snapshot_index - detector_snapshot_index
(
cfg: dict,
shuffle: int,
trainFraction: float,
trainingsiterations: str | int = "unknown",
modelprefix: str = "",
engine: Engine | None = None,
**kwargs,
)
| 669 | |
| 670 | |
| 671 | def get_scorer_name( |
| 672 | cfg: dict, |
| 673 | shuffle: int, |
| 674 | trainFraction: float, |
| 675 | trainingsiterations: str | int = "unknown", |
| 676 | modelprefix: str = "", |
| 677 | engine: Engine | None = None, |
| 678 | **kwargs, |
| 679 | ): |
| 680 | """Extract the scorer/network name for a particular shuffle, training fraction, etc. |
| 681 | If the engine is not specified, determines which to use from |
| 682 | kwargs: additional arguments. |
| 683 | For torch-based shuffles, can be used to specify: |
| 684 | - snapshot_index |
| 685 | - detector_snapshot_index |
| 686 | |
| 687 | Returns tuple of DLCscorer, DLCscorerlegacy (old naming convention) |
| 688 | """ |
| 689 | if engine is None: |
| 690 | from deeplabcut.generate_training_dataset.metadata import get_shuffle_engine |
| 691 | |
| 692 | engine = get_shuffle_engine( |
| 693 | cfg=cfg, |
| 694 | trainingsetindex=cfg["TrainingFraction"].index(trainFraction), |
| 695 | shuffle=shuffle, |
| 696 | modelprefix=modelprefix, |
| 697 | ) |
| 698 | |
| 699 | if engine == Engine.PYTORCH: |
| 700 | from deeplabcut.pose_estimation_pytorch.apis.utils import get_scorer_name |
| 701 | |
| 702 | snapshot_index = kwargs.get("snapshot_index", None) |
| 703 | detector_snapshot_index = kwargs.get("detector_snapshot_index", None) |
| 704 | dlc3_scorer = get_scorer_name( |
| 705 | cfg=cfg, |
| 706 | shuffle=shuffle, |
| 707 | train_fraction=trainFraction, |
| 708 | snapshot_index=snapshot_index, |
| 709 | detector_index=detector_snapshot_index, |
| 710 | modelprefix=modelprefix, |
| 711 | ) |
| 712 | return dlc3_scorer, dlc3_scorer |
| 713 | |
| 714 | Task = cfg["Task"] |
| 715 | date = cfg["date"] |
| 716 | |
| 717 | if trainingsiterations == "unknown": |
| 718 | snapshotindex = get_snapshot_index_for_scorer("snapshotindex", cfg["snapshotindex"]) |
| 719 | model_folder = get_model_folder(trainFraction, shuffle, cfg, engine=engine, modelprefix=modelprefix) |
| 720 | train_folder = Path(cfg["project_path"]) / model_folder / "train" |
| 721 | snapshot_names = get_snapshots_from_folder(train_folder) |
| 722 | snapshot_name = snapshot_names[snapshotindex] |
| 723 | trainingsiterations = (snapshot_name.split(os.sep)[-1]).split("-")[-1] |
| 724 | |
| 725 | dlc_cfg = read_plainconfig( |
| 726 | os.path.join( |
| 727 | cfg["project_path"], |
| 728 | str(get_model_folder(trainFraction, shuffle, cfg, engine=engine, modelprefix=modelprefix)), |
nothing calls this directly
no test coverage detected