Evaluates a given checkpoint on a datamodule's test set. This method is wrapped in the optional @task_wrapper decorator, which controls the behavior on failure. Useful for multiruns, saving info about the crash, etc. :param cfg: DictConfig configuration composed by Hydra. :return: Tuple[dict, dict] with metrics and a dict with all instantiated objects.
(cfg: DictConfig)
| 31 | |
@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluates a given checkpoint on a datamodule's test set.

    This method is wrapped in the optional ``@task_wrapper`` decorator, which controls the
    behavior on failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra. Must provide a non-empty
        ``ckpt_path`` pointing at the checkpoint to evaluate.
    :return: Tuple[dict, dict] with the metrics reported by the trainer after testing and a
        dict with all instantiated objects.
    :raises ValueError: If ``cfg.ckpt_path`` is empty / falsy.
    """
    # Explicit check instead of ``assert``: assert statements are removed when Python runs
    # with ``-O``, which would let ``ckpt_path=None`` reach ``trainer.test`` and silently
    # evaluate the freshly instantiated (untrained) model instead of the checkpoint.
    if not cfg.ckpt_path:
        raise ValueError("`ckpt_path` must be set to the checkpoint that should be evaluated.")

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    # The ``logger`` section is optional in the config, hence ``.get`` rather than attribute access.
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    # Everything instantiated for this run; returned to the caller and used for
    # hyperparameter logging below.
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    # Only log hyperparameters when at least one logger was configured.
    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    log.info("Starting testing!")
    # ``ckpt_path`` makes the trainer load the checkpoint weights before running the test loop.
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # NOTE: for predictions use trainer.predict(...), e.g.
    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    # Metrics the trainer collected from the test run.
    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict
| 77 | |
| 78 | |
| 79 | @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml") |
outgoing calls: hydra.utils.instantiate, utils.instantiate_loggers, utils.log_hyperparameters, Trainer.test