(system_prompt, initial_query, client, model, **kwargs)
| 90 | return APPROACHES[predicted_approach_index], confidence |
| 91 | |
| 92 | def run(system_prompt, initial_query, client, model, **kwargs): |
| 93 | try: |
| 94 | # Load the trained model |
| 95 | router_model, tokenizer, device = load_optillm_model() |
| 96 | |
| 97 | # Preprocess the input |
| 98 | input_ids, attention_mask = preprocess_input(tokenizer, system_prompt, initial_query) |
| 99 | |
| 100 | # Predict the best approach |
| 101 | predicted_approach, _ = predict_approach(router_model, input_ids, attention_mask, device) |
| 102 | |
| 103 | print(f"Router predicted approach: {predicted_approach}") |
| 104 | |
| 105 | # Route to the appropriate approach or use the model directly |
| 106 | if predicted_approach == "none": |
| 107 | # Use the model directly without routing |
| 108 | response = client.chat.completions.create( |
| 109 | model=model, |
| 110 | messages=[ |
| 111 | {"role": "system", "content": system_prompt}, |
| 112 | {"role": "user", "content": initial_query} |
| 113 | ] |
| 114 | ) |
| 115 | return response.choices[0].message.content, response.usage.completion_tokens |
| 116 | elif predicted_approach == "mcts": |
| 117 | return chat_with_mcts(system_prompt, initial_query, client, model, **kwargs) |
| 118 | elif predicted_approach == "bon": |
| 119 | return best_of_n_sampling(system_prompt, initial_query, client, model, **kwargs) |
| 120 | elif predicted_approach == "moa": |
| 121 | return mixture_of_agents(system_prompt, initial_query, client, model) |
| 122 | elif predicted_approach == "rto": |
| 123 | return round_trip_optimization(system_prompt, initial_query, client, model) |
| 124 | elif predicted_approach == "z3": |
| 125 | z3_solver = Z3SymPySolverSystem(system_prompt, client, model) |
| 126 | return z3_solver.process_query(initial_query) |
| 127 | elif predicted_approach == "self_consistency": |
| 128 | return advanced_self_consistency_approach(system_prompt, initial_query, client, model) |
| 129 | elif predicted_approach == "pvg": |
| 130 | return inference_time_pv_game(system_prompt, initial_query, client, model) |
| 131 | elif predicted_approach == "rstar": |
| 132 | rstar = RStar(system_prompt, client, model, **kwargs) |
| 133 | return rstar.solve(initial_query) |
| 134 | elif predicted_approach == "cot_reflection": |
| 135 | return cot_reflection(system_prompt, initial_query, client, model, **kwargs) |
| 136 | elif predicted_approach == "plansearch": |
| 137 | return plansearch(system_prompt, initial_query, client, model, **kwargs) |
| 138 | elif predicted_approach == "leap": |
| 139 | return leap(system_prompt, initial_query, client, model) |
| 140 | elif predicted_approach == "re2": |
| 141 | return re2_approach(system_prompt, initial_query, client, model, **kwargs) |
| 142 | else: |
| 143 | raise ValueError(f"Unknown approach: {predicted_approach}") |
| 144 | |
| 145 | except Exception as e: |
| 146 | # Log the error and fall back to using the model directly |
| 147 | print(f"Error in router plugin: {str(e)}. Falling back to direct model usage.") |
| 148 | response = client.chat.completions.create( |
| 149 | model=model, |
nothing calls this directly
no test coverage detected