| 525 | ) |
| 526 | |
| 527 | async def unified_call( |
| 528 | self, |
| 529 | system_message="", |
| 530 | user_message="", |
| 531 | messages: List[BaseMessage] | None = None, |
| 532 | response_callback: Callable[[str, str], Awaitable[str | None]] | None = None, |
| 533 | reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None, |
| 534 | tokens_callback: Callable[[str, int], Awaitable[None]] | None = None, |
| 535 | rate_limiter_callback: ( |
| 536 | Callable[[str, str, int, int], Awaitable[bool]] | None |
| 537 | ) = None, |
| 538 | explicit_caching: bool = False, |
| 539 | **kwargs: Any, |
| 540 | ) -> Tuple[str, str]: |
| 541 | |
| 542 | configure_litellm() |
| 543 | |
| 544 | if not messages: |
| 545 | messages = [] |
| 546 | # construct messages |
| 547 | if system_message: |
| 548 | messages.insert(0, SystemMessage(content=system_message)) |
| 549 | if user_message: |
| 550 | messages.append(HumanMessage(content=user_message)) |
| 551 | |
| 552 | # convert to litellm format |
| 553 | msgs_conv = self._convert_messages(messages, explicit_caching=explicit_caching) |
| 554 | |
| 555 | # Apply rate limiting if configured |
| 556 | limiter = await apply_rate_limiter( |
| 557 | self.a0_model_conf, str(msgs_conv), rate_limiter_callback |
| 558 | ) |
| 559 | |
| 560 | # Prepare call kwargs and retry config (strip A0-only params before calling LiteLLM) |
| 561 | call_kwargs: dict[str, Any] = _merge_litellm_call_kwargs( |
| 562 | self.kwargs, kwargs |
| 563 | ) |
| 564 | if explicit_caching: |
| 565 | call_kwargs["a0_explicit_prompt_caching"] = True |
| 566 | max_retries: int = int(call_kwargs.pop("a0_retry_attempts", 2)) |
| 567 | retry_delay_s: float = float(call_kwargs.pop("a0_retry_delay_seconds", 1.5)) |
| 568 | stream = reasoning_callback is not None or response_callback is not None or tokens_callback is not None |
| 569 | transport = LiteLLMTransport( |
| 570 | model=self.model_name, |
| 571 | messages=msgs_conv, |
| 572 | kwargs=call_kwargs, |
| 573 | ) |
| 574 | |
| 575 | # results |
| 576 | result = ChatGenerationResult() |
| 577 | |
| 578 | attempt = 0 |
| 579 | while True: |
| 580 | got_any_chunk = False |
| 581 | try: |
| 582 | if stream: |
| 583 | stop_response: str | None = None |
| 584 | async for parsed in transport.astream(): |