| 16 | return text |
| 17 | |
| 18 | def round_trip_optimization(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str: |
| 19 | rto_completion_tokens = 0 |
| 20 | |
| 21 | # Extract max_tokens from request_config with default |
| 22 | max_tokens = 4096 |
| 23 | if request_config: |
| 24 | max_tokens = request_config.get('max_tokens', max_tokens) |
| 25 | |
| 26 | messages = [{"role": "system", "content": system_prompt}, |
| 27 | {"role": "user", "content": initial_query}] |
| 28 | |
| 29 | # Generate initial code (C1) |
| 30 | provider_request = { |
| 31 | "model": model, |
| 32 | "messages": messages, |
| 33 | "max_tokens": max_tokens, |
| 34 | "n": 1, |
| 35 | "temperature": 0.1 |
| 36 | } |
| 37 | response_c1 = client.chat.completions.create(**provider_request) |
| 38 | |
| 39 | # Log provider call |
| 40 | if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: |
| 41 | response_dict = response_c1.model_dump() if hasattr(response_c1, 'model_dump') else response_c1 |
| 42 | optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) |
| 43 | |
| 44 | if (response_c1 is None or |
| 45 | not response_c1.choices or |
| 46 | response_c1.choices[0].message.content is None or |
| 47 | response_c1.choices[0].finish_reason == "length"): |
| 48 | raise Exception("RTO: provider returned an empty, None, or truncated response for the initial code (C1)") |
| 49 | c1 = response_c1.choices[0].message.content |
| 50 | rto_completion_tokens += response_c1.usage.completion_tokens |
| 51 | |
| 52 | # Generate description of the code (Q2) |
| 53 | messages.append({"role": "assistant", "content": c1}) |
| 54 | messages.append({"role": "user", "content": "Summarize or describe the code you just created. The summary should be in form of an instruction such that, given the instruction you can create the code yourself."}) |
| 55 | provider_request = { |
| 56 | "model": model, |
| 57 | "messages": messages, |
| 58 | "max_tokens": 1024, |
| 59 | "n": 1, |
| 60 | "temperature": 0.1 |
| 61 | } |
| 62 | response_q2 = client.chat.completions.create(**provider_request) |
| 63 | |
| 64 | # Log provider call |
| 65 | if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and request_id: |
| 66 | response_dict = response_q2.model_dump() if hasattr(response_q2, 'model_dump') else response_q2 |
| 67 | optillm.conversation_logger.log_provider_call(request_id, provider_request, response_dict) |
| 68 | |
| 69 | if (response_q2 is None or |
| 70 | not response_q2.choices or |
| 71 | response_q2.choices[0].message.content is None or |
| 72 | response_q2.choices[0].finish_reason == "length"): |
| 73 | raise Exception("RTO: provider returned an empty, None, or truncated response for the description (Q2)") |
| 74 | q2 = response_q2.choices[0].message.content |
| 75 | rto_completion_tokens += response_q2.usage.completion_tokens |