MCPcopy
hub / github.com/ashleve/lightning-hydra-template / evaluate

Function evaluate

src/eval.py:33–76  ·  view source on GitHub ↗

Evaluates a given checkpoint on a datamodule's test set. This method is wrapped in the optional @task_wrapper decorator, which controls the behavior on failure. Useful for multiruns, saving info about the crash, etc. :param cfg: DictConfig configuration composed by Hydra. :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.

(cfg: DictConfig)

Source from the content-addressed store, hash-verified

31
@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluate a given checkpoint on the datamodule's test set.

    This function is wrapped in the optional ``@task_wrapper`` decorator, which controls
    the behavior on failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra. Must provide a truthy
        ``ckpt_path`` and ``data``, ``model`` and ``trainer`` nodes (each with a
        ``_target_``); the ``logger`` node is optional.
    :raises ValueError: If ``cfg.ckpt_path`` is empty / falsy.
    :return: Tuple[dict, dict] with metrics (``trainer.callback_metrics`` after testing)
        and dict with all instantiated objects.
    """
    # Explicit check rather than ``assert``: assertions are stripped when Python runs
    # with ``-O``, which would let evaluation proceed without any checkpoint.
    if not cfg.ckpt_path:
        raise ValueError("`cfg.ckpt_path` must be set to the checkpoint to evaluate!")

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    # Everything instantiated for this run; returned to the caller and handed to
    # hyperparameter logging below.
    object_dict: Dict[str, Any] = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    # Only log hyperparameters when at least one logger was instantiated.
    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    log.info("Starting testing!")
    # Passing ``ckpt_path`` makes Lightning load the checkpoint's weights into the model
    # before running the test loop.
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # For predictions instead of test metrics, call ``trainer.predict(model=model,
    # dataloaders=..., ckpt_path=cfg.ckpt_path)`` here.

    metric_dict: Dict[str, Any] = trainer.callback_metrics

    return metric_dict, object_dict
77
78
79@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")

Callers 2

test_train_evalFunction · 0.90
mainFunction · 0.85

Calls

no outgoing calls

Tested by 1

test_train_evalFunction · 0.72