Extract tool calls from ``text`` using an mlx-lm tool module. Ports the ``process_tool_calls`` logic from ``mlx_vlm/server.py`` (v0.10 onwards). ``tool_module`` must expose ``tool_call_start``, ``tool_call_end`` and ``parse_tool_call``. Returns ``(calls, remaining_text)`` where ``calls`` is a list of dicts and ``remaining_text`` is the free-form text with the tool call blocks removed.
(text, tool_module, tools)
| 34 | |
| 35 | |
| 36 | def parse_tool_calls(text, tool_module, tools): |
| 37 | """Extract tool calls from ``text`` using a mlx-lm tool module. |
| 38 | |
| 39 | Ports the ``process_tool_calls`` logic from |
| 40 | ``mlx_vlm/server.py`` (v0.10 onwards). ``tool_module`` must expose |
| 41 | ``tool_call_start``, ``tool_call_end`` and ``parse_tool_call``. |
| 42 | |
| 43 | Returns ``(calls, remaining_text)`` where ``calls`` is a list of dicts: |
| 44 | |
| 45 | [{"index": int, "id": str, "name": str, "arguments": str (JSON)}] |
| 46 | |
| 47 | and ``remaining_text`` is the free-form text with the tool call blocks |
| 48 | removed. ``(calls, text)`` is returned unchanged if ``tool_module`` is |
| 49 | ``None`` or the start delimiter isn't present. |
| 50 | """ |
| 51 | if tool_module is None or not text: |
| 52 | return [], text |
| 53 | start = getattr(tool_module, "tool_call_start", None) |
| 54 | end = getattr(tool_module, "tool_call_end", None) |
| 55 | parse_fn = getattr(tool_module, "parse_tool_call", None) |
| 56 | if not start or parse_fn is None or start not in text: |
| 57 | return [], text |
| 58 | |
| 59 | if end == "" or end is None: |
| 60 | pattern = re.compile( |
| 61 | re.escape(start) + r".*?(?:\n|$)", |
| 62 | re.DOTALL, |
| 63 | ) |
| 64 | else: |
| 65 | pattern = re.compile( |
| 66 | re.escape(start) + r".*?" + re.escape(end), |
| 67 | re.DOTALL, |
| 68 | ) |
| 69 | |
| 70 | matches = pattern.findall(text) |
| 71 | if not matches: |
| 72 | return [], text |
| 73 | |
| 74 | remaining = pattern.sub(" ", text).strip() |
| 75 | calls = [] |
| 76 | for match in matches: |
| 77 | call_body = match.strip().removeprefix(start) |
| 78 | if end: |
| 79 | call_body = call_body.removesuffix(end) |
| 80 | call_body = call_body.strip() |
| 81 | try: |
| 82 | parsed = parse_fn(call_body, tools) |
| 83 | except Exception as e: |
| 84 | print( |
| 85 | f"[mlx_utils] Invalid tool call: {call_body!r} ({e})", |
| 86 | file=sys.stderr, |
| 87 | ) |
| 88 | continue |
| 89 | if not isinstance(parsed, list): |
| 90 | parsed = [parsed] |
| 91 | for tc in parsed: |
| 92 | calls.append( |
| 93 | { |