| 381 | indirect=True, |
| 382 | ) |
def test_custom_schema(env_variables):
    """AgentTool should honor the wrapped agent's input/output schemas.

    One mock model is shared by both agents and scripted so that:
      1. the root agent emits the ``function_call_custom`` tool call,
      2. the tool agent answers with JSON matching ``CustomOutput``,
      3. the root agent then replies with plain text.
    """

    class CustomInput(BaseModel):
        custom_input: str

    class CustomOutput(BaseModel):
        custom_output: str

    # Replies are handed out in order: root call -> tool agent -> root again.
    scripted_replies = [
        function_call_custom,
        '{"custom_output": "response1"}',
        'response2',
    ]
    shared_model = testing_utils.MockModel.create(responses=scripted_replies)

    inner_agent = Agent(
        name='tool_agent',
        model=shared_model,
        input_schema=CustomInput,
        output_schema=CustomOutput,
        output_key='tool_output',
    )
    outer_agent = Agent(
        name='root_agent',
        model=shared_model,
        tools=[AgentTool(agent=inner_agent)],
    )

    runner = testing_utils.InMemoryRunner(outer_agent)
    runner.session.state['state_1'] = 'state1_value'

    observed = testing_utils.simplify_events(runner.run('test1'))
    assert observed == [
        ('root_agent', function_call_custom),
        ('root_agent', function_response_custom),
        ('root_agent', 'response2'),
    ]

    # The tool agent's JSON reply ends up in session state under its
    # output_key as a dict, not as the raw string.
    assert runner.session.state['tool_output'] == {'custom_output': 'response1'}

    recorded = shared_model.requests
    assert len(recorded) == 3
    # Index 1 is the request issued by the tool agent. It must carry the
    # output schema and ask for a JSON response.
    tool_config = recorded[1].config
    assert tool_config.response_schema == CustomOutput
    assert tool_config.response_mime_type == 'application/json'
| 427 | |
| 428 | |
| 429 | @mark.parametrize( |