Test whether the `turns`, `max_cost`, and `max_tokens` params work as expected. The latter two are not really tested (doing so would require turning off caching, etc.); we just make sure they don't break anything.
(
test_settings: Settings,
sequential: bool,
batch_size: Optional[int],
use_done_tool: bool,
)
@pytest.mark.parametrize("sequential", [True, False])
@pytest.mark.parametrize("use_done_tool", [True, False])
def test_task_batch_turns(
    test_settings: Settings,
    sequential: bool,
    batch_size: Optional[int],
    use_done_tool: bool,
):
    """Run a batch of tasks with `turns`, `max_cost` and `max_tokens` set.

    Only the returned answers are checked. The cost and token limits are
    passed merely to confirm they don't break anything; genuinely testing
    them would require turning off caching etc.
    """
    # NOTE(review): no parametrize decorator for `batch_size` is visible in
    # this listing; presumably it is supplied just above -- confirm.
    set_global(test_settings)

    class _TestChatAgent(ChatAgent):
        def handle_message_fallback(
            self, msg: str | ChatDocument
        ) -> str | DoneTool | None:
            """Return an LLM message's content marked as done.

            The content is wrapped in a `DoneTool` or prefixed with `DONE`,
            depending on `use_done_tool`. Any other message yields None.
            """
            if not isinstance(msg, ChatDocument):
                return None
            if msg.metadata.sender != Entity.LLM:
                return None

            llm_text = str(msg.content)
            if use_done_tool:
                return DoneTool(content=llm_text)
            return DONE + " " + llm_text

    agent = _TestChatAgent(_TestChatAgentConfig())
    agent.llm.reset_usage_cost()
    task = Task(agent, name="Test", interactive=False)

    # clones of the task are run on the inputs 0 .. num_tasks - 1
    num_tasks = 3
    questions = list(range(num_tasks))
    expected_answers = list(range(1, num_tasks + 1))

    answers = run_batch_tasks(
        task,
        questions,
        input_map=str,  # what to feed to each task
        output_map=lambda result: result,  # pass each task's result through
        sequential=sequential,
        batch_size=batch_size,
        turns=2,
        max_cost=0.005,
        max_tokens=100,
    )

    # Answers may be wordier than the bare expected number, so only check
    # that each expected number appears somewhere in some answer.
    for expected in expected_answers:
        needle = str(expected)
        assert any(needle in answer.content.lower() for answer in answers)
nothing calls this directly
no test coverage detected
searching dependent graphs…