Runs evaluation on the test set using the given instructions. Args: output_dir: The directory to save evaluation results. instructions: The system instructions to evaluate. config: The experiment configuration.
(output_dir: str, instructions: str, config: ExperimentConfig)
| 594 | |
| 595 | |
def run_eval(output_dir: str, instructions: str, config: ExperimentConfig):
  """Runs evaluation on the test set using the given instructions.

  Writes three files into `output_dir`:
    * `prompt.txt`: the instructions being evaluated.
    * `run_config.json`: the dumped tau-bench run configuration.
    * `results.json`: every rollout result, under the key `results`.

  It also prints the fraction of rollouts whose reward equals 1.

  Args:
    output_dir: The directory to save evaluation results. It is also passed to
      the tau-bench run config as `log_dir`. This function does not create the
      directory itself before writing into it.
    instructions: The system instructions to evaluate.
    config: The experiment configuration.
  """
  eval_dataset = _get_dataset(config.eval_dataset)
  tau_bench_run_config = RunConfig(
      env=config.tau_bench_env,
      model=config.agent_model,
      model_provider=config.agent_model_provider,
      user_model=config.user_model,
      user_model_provider=config.user_model_provider,
      agent_strategy='tool-calling',
      user_strategy='llm',
      max_concurrency=config.max_concurrency,
      num_trials=config.num_eval_trials,
      task_ids=eval_dataset,
      log_dir=output_dir,
      task_split=config.eval_dataset.split,
  )

  # Persist the inputs of this run before starting the rollouts, so they are
  # on disk even if the rollouts fail part-way through.
  with open(os.path.join(output_dir, 'prompt.txt'), 'w') as f:
    f.write(instructions)

  # Use a context manager so the handle is flushed and closed deterministically
  # instead of relying on garbage collection of an anonymous file object.
  with open(os.path.join(output_dir, 'run_config.json'), 'w') as f:
    json.dump(tau_bench_run_config.model_dump(), f)

  tau_bench_results = run_tau_bench_rollouts(
      tau_bench_run_config,
      system_instruction=instructions,
      # The rater is only built when the experiment config asks for it.
      rater=_rater(config) if config.use_rater else None,
  )

  # Only a reward of exactly 1 counts as a success; the guard in the f-string
  # avoids a ZeroDivisionError when no results came back.
  total = len(tau_bench_results)
  numerator = sum(1 for res in tau_bench_results if res.reward == 1)
  print(
      f'average reward (total={total}): {numerator/total if total > 0 else 0}'
  )

  with open(os.path.join(output_dir, 'results.json'), 'w') as f:
    json.dump(dict(results=[r.model_dump() for r in tau_bench_results]), f)
no test coverage detected