A GEPA adapter for evaluating agent performance on the tau-bench benchmark.
| 251 | |
| 252 | |
| 253 | class TauBenchAdapter( |
| 254 | GEPAAdapter[ |
| 255 | TauBenchDataInst, |
| 256 | TauBenchTrajectory, |
| 257 | TauBenchRolloutOutput, |
| 258 | ] |
| 259 | ): |
| 260 | """A GEPA adapter for evaluating agent performance on tau-bench benchmark.""" |
| 261 | |
def __init__(
    self,
    env_name: str,
    agent_model: str = 'gemini-2.5-flash',
    agent_model_provider: str = 'vertex_ai',
    user_model: str = 'gemini-2.5-pro',
    user_model_provider: str = 'vertex_ai',
    agent_strategy: str = 'tool-calling',
    user_strategy: str = 'llm',
    system_instruction_name: str = 'system_instruction',
    max_concurrency: int = 4,
    rater: rater_lib.Rater | None = None,
    log_dir: str | None = None,
):
    """Records the adapter configuration on the instance.

    The constructor performs no work beyond storing each argument on a
    private attribute of the same name (prefixed with an underscore).

    Args:
      env_name: Name of the environment.
      agent_model: Model used for the agent.
      agent_model_provider: Provider of the agent model.
      user_model: Model used to simulate the user.
      user_model_provider: Provider of the user model.
      agent_strategy: Agent strategy to use (e.g., 'tool-calling').
      user_strategy: User simulation strategy (e.g., 'llm').
      system_instruction_name: Key in the candidate dictionary under which
        the system instruction is found.
      max_concurrency: Maximum number of tasks to run in parallel.
      rater: Optional rater used to evaluate the agent's performance.
      log_dir: Directory in which traces and other logs are saved.
    """
    # Environment selection.
    self._env_name = env_name

    # Agent-side and simulated-user-side model settings, kept in pairs.
    self._agent_model, self._agent_model_provider = (
        agent_model,
        agent_model_provider,
    )
    self._user_model, self._user_model_provider = (
        user_model,
        user_model_provider,
    )
    self._agent_strategy, self._user_strategy = agent_strategy, user_strategy

    # Candidate lookup key and execution limits.
    self._system_instruction_name = system_instruction_name
    self._max_concurrency = max_concurrency

    # Optional collaborators / output location.
    self._rater = rater
    self._log_dir = log_dir
| 303 | |
| 304 | def evaluate( |
| 305 | self, |
| 306 | batch: list[TauBenchDataInst], |
| 307 | candidate: dict[str, str], |
| 308 | capture_traces: bool = False, |
| 309 | ) -> EvaluationBatch[TauBenchTrajectory, TauBenchRolloutOutput]: |
| 310 | """Evaluates a candidate prompt on a batch of tau-bench tasks. |