(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None)
| 5 | logger = logging.getLogger(__name__) |
| 6 | |
| 7 | def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str: |
| 8 | logger.info(f"Starting mixture_of_agents function with model: {model}") |
| 9 | moa_completion_tokens = 0 |
| 10 | |
| 11 | # Extract max_tokens from request_config with default |
| 12 | max_tokens = 4096 |
| 13 | if request_config: |
| 14 | max_tokens = request_config.get('max_tokens', max_tokens) |
| 15 | |
| 16 | completions = [] |
| 17 | |
| 18 | logger.debug(f"Generating initial completions for query: {initial_query}") |
| 19 | |
| 20 | try: |
| 21 | # Try to generate 3 completions in a single API call using n parameter |
| 22 | provider_request = { |
| 23 | "model": model, |
| 24 | "messages": [ |
| 25 | {"role": "system", "content": system_prompt}, |
| 26 | {"role": "user", "content": initial_query} |
| 27 | ], |
| 28 | "max_tokens": max_tokens, |
| 29 | "n": 3, |
| 30 | "temperature": 1 |
| 31 | } |
| 32 | |
| 33 | response = client.chat.completions.create(**provider_request) |
| 34 | |
| 35 | # Convert response to dict for logging |
| 36 | response_dict = response.model_dump() if hasattr(response, 'model_dump') else response |
| 37 | |
| 38 | # Log provider call if conversation logging is enabled |
| 39 | if request_id: |
| 40 | conversation_logger.log_provider_call(request_id, provider_request, response_dict) |
| 41 | |
| 42 | # Check for valid response with None-checking |
| 43 | if response is None or not response.choices: |
| 44 | raise Exception("Response is None or has no choices") |
| 45 | |
| 46 | completions = [choice.message.content for choice in response.choices if choice.message.content is not None] |
| 47 | moa_completion_tokens += response.usage.completion_tokens |
| 48 | logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}") |
| 49 | |
| 50 | # Check if any valid completions were generated |
| 51 | if not completions: |
| 52 | raise Exception("No valid completions generated (all were None)") |
| 53 | |
| 54 | except Exception as e: |
| 55 | logger.warning(f"n parameter not supported by provider: {str(e)}") |
| 56 | logger.info("Falling back to generating 3 completions one by one") |
| 57 | |
| 58 | # Fallback: Generate 3 completions one by one in a loop |
| 59 | completions = [] |
| 60 | for i in range(3): |
| 61 | try: |
| 62 | provider_request = { |
| 63 | "model": model, |
| 64 | "messages": [ |
No test coverage detected.