| 108 | else: self.logger.info("Log file is saved to", self.logger.log_path, "...", title="Log Path", color="light_cyan3") |
| 109 | |
def get_chat_completion(
    self,
    agent: Agent,
    history: List,
    context_variables: dict,
    model_override: str,
    stream: bool,
    debug: bool,
) -> Message:
    """Build a chat-completion request for ``agent`` and send it via ``completion``.

    The system message comes from ``agent.instructions``, which is called with
    the context variables when it is callable. Any ``agent.examples`` (also
    called when callable) are prepended to ``history``. ``agent.functions``
    are converted to tool schemas with the context-variables parameter
    removed, so the model never sees it.

    Args:
        agent: Agent whose ``instructions``, ``examples``, ``functions``,
            ``model``, ``tool_choice`` and ``parallel_tool_calls`` configure
            the request.
        history: Prior conversation messages. Neither the list nor the
            message dicts in it are mutated by this method.
        context_variables: Values passed to callable instructions/examples.
            They are wrapped in ``defaultdict(str)``, so missing keys read as "".
        model_override: Model name used instead of ``agent.model`` when truthy.
        stream: Forwarded to ``completion`` as ``stream``.
        debug: Currently unused (the debug print is disabled).

    Returns:
        Whatever ``completion(**create_params)`` returns. NOTE(review): it is
        annotated as ``Message``, but with ``stream=True`` this is presumably
        a stream object — confirm against ``completion``.
    """
    context_variables = defaultdict(str, context_variables)
    instructions = (
        agent.instructions(context_variables)
        if callable(agent.instructions)
        else agent.instructions
    )
    if agent.examples:
        examples = (
            agent.examples(context_variables)
            if callable(agent.examples)
            else agent.examples
        )
        # List concatenation builds a new list; the caller's list is untouched.
        history = examples + history

    messages = [{"role": "system", "content": instructions}] + history

    tools = [function_to_json(f) for f in agent.functions]
    # Hide the context-variables parameter from the model's tool schemas.
    for tool in tools:
        params = tool["function"]["parameters"]
        params["properties"].pop(__CTX_VARS_NAME__, None)
        if __CTX_VARS_NAME__ in params.get("required", []):
            params["required"].remove(__CTX_VARS_NAME__)

    model = model_override or agent.model

    if model.startswith("mistral"):
        # Mistral requests are sent without the "sender" key. Copy instead of
        # deleting in place: the dicts in ``messages`` are the very objects in
        # the caller's history (and agent.examples), and an in-place ``del``
        # would permanently strip "sender" from them.
        messages = [
            {key: value for key, value in message.items() if key != "sender"}
            if "sender" in message
            else message
            for message in messages
        ]

    create_params = {
        "model": model,
        "messages": messages,
        "tools": tools or None,
        "tool_choice": agent.tool_choice,
        "stream": stream,
        "base_url": API_BASE_URL,
    }

    # parallel_tool_calls is only sent for "gpt" models that have tools.
    if tools and model.startswith("gpt"):
        create_params["parallel_tool_calls"] = agent.parallel_tool_calls

    return completion(**create_params)
| 160 | |
| 161 | def handle_function_result(self, result, debug) -> Result: |
| 162 | match result: |