Returns Tau-bench dataset splits.
(
config: ExperimentConfig,
)
| 458 | |
| 459 | |
def _get_datasets(
    config: ExperimentConfig,
) -> dict[str, list[int]]:
  """Loads the Tau-bench task-id splits selected by `config`.

  The global `random` module is seeded with `config.rnd_seed` before any split
  is loaded. Whether `_get_dataset` actually draws from that RNG is not visible
  here, so the load order (train, dev, test) is kept fixed.

  Args:
    config: Experiment settings; this function reads `rnd_seed`,
      `feedback_dataset`, `pareto_dataset` and `eval_dataset`.

  Returns:
    A dict with keys 'train', 'dev' and 'test', populated from
    `feedback_dataset`, `pareto_dataset` and `eval_dataset` respectively.
  """
  random.seed(config.rnd_seed)

  # Dict literals evaluate their values top to bottom, so the three loads
  # happen in the same train -> dev -> test order every time.
  splits = {
      'train': _get_dataset(config.feedback_dataset),
      'dev': _get_dataset(config.pareto_dataset),
      'test': _get_dataset(config.eval_dataset),
  }

  # Note: the log message labels the 'dev' split as "eval".
  logging.info(
      'Using datasets of size: train=%d, eval=%d, test=%d',
      len(splits['train']),
      len(splits['dev']),
      len(splits['test']),
  )
  return splits
| 479 | |
| 480 | |
| 481 | SEED_SYSTEM_INSTRUCTION = ( |
no test coverage detected