(self, solver: Solver, sample: Sample, rng: random.Random)
| 58 | ).strip() |
| 59 | |
def eval_sample(self, solver: Solver, sample: Sample, rng: random.Random) -> None:
    """Query the solver on one sample, score its answer and record the result.

    Args:
        solver: Callable invoked with the ``TaskState`` built for this sample.
        sample: Supplies the causal graph, hypotheses, gold label and
            ``num_not_ctrl`` used for prompting and scoring.
        rng: Not used by this method.
    """
    prompt: Message = self._build_message(sample)

    # The variable list is only meant for the Random baseline solver.
    variables = list(sample.causal_graph.nodes)
    state = TaskState(
        task_description=self.task_description,
        messages=[prompt],
        current_state={"variables": variables},
    )

    answer: SolverResult = solver(state)

    try:
        predictions = parse_solver_preds(answer)
    except ValueError:
        # Invalid solver output: the sample is scored with no predictions.
        predictions = None

    gold_label = sample.gold_label
    not_ctrl_count = sample.num_not_ctrl

    sample_metrics: Dict[str, float] = self._evaluate_sample(
        predictions, gold_label, not_ctrl_count
    )

    # hack: logviz doesn't support custom log fields, so these are logged
    # alongside the real metrics.
    extra_fields = {
        "causal_graph": nx.to_dict_of_lists(sample.causal_graph),
        "gold_answer": asdict(gold_label),
        "n_hyps": sample.hypotheses.number_of_edges(),
        "valid_hyp": gold_label.valid_hypothesis,
        "num_not_ctrl": not_ctrl_count,
    }
    record_metrics(**sample_metrics, **extra_fields)
| 89 | |
| 90 | def run(self, recorder: RecorderBase) -> Dict[str, float]: |
| 91 | samples: List[Dict] = self._get_samples() |
nothing calls this directly
no test coverage detected