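Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. Args: url: The SSE endpoint URL. headers: Optional headers to include in requests. timeout: HTTP timeout for regular operations (in seconds). sse_read_timeout: Timeout for SSE read operations (in seconds). httpx_client_factory: Factory function for creating the HTTPX client. auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received.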
(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 300.0,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
on_session_created: Callable[[str], None] | None = None,
)
| 29 | |
| 30 | @asynccontextmanager |
| 31 | async def sse_client( |
| 32 | url: str, |
| 33 | headers: dict[str, Any] | None = None, |
| 34 | timeout: float = 5.0, |
| 35 | sse_read_timeout: float = 300.0, |
| 36 | httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, |
| 37 | auth: httpx.Auth | None = None, |
| 38 | on_session_created: Callable[[str], None] | None = None, |
| 39 | ): |
| 40 | """Client transport for SSE. |
| 41 | |
| 42 | `sse_read_timeout` determines how long (in seconds) the client will wait for a new |
| 43 | event before disconnecting. All other HTTP operations are controlled by `timeout`. |
| 44 | |
| 45 | Args: |
| 46 | url: The SSE endpoint URL. |
| 47 | headers: Optional headers to include in requests. |
| 48 | timeout: HTTP timeout for regular operations (in seconds). |
| 49 | sse_read_timeout: Timeout for SSE read operations (in seconds). |
| 50 | httpx_client_factory: Factory function for creating the HTTPX client. |
| 51 | auth: Optional HTTPX authentication handler. |
| 52 | on_session_created: Optional callback invoked with the session ID when received. |
| 53 | """ |
| 54 | logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") |
| 55 | async with httpx_client_factory( |
| 56 | headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) |
| 57 | ) as client: |
| 58 | async with aconnect_sse(client, "GET", url) as event_source: |
| 59 | event_source.response.raise_for_status() |
| 60 | logger.debug("SSE connection established") |
| 61 | |
| 62 | read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) |
| 63 | write_stream, write_stream_reader = create_context_streams[SessionMessage](0) |
| 64 | |
| 65 | async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): |
| 66 | try: |
| 67 | async for sse in event_source.aiter_sse(): # pragma: no branch |
| 68 | logger.debug(f"Received SSE event: {sse.event}") |
| 69 | match sse.event: |
| 70 | case "endpoint": |
| 71 | endpoint_url = urljoin(url, sse.data) |
| 72 | logger.debug(f"Received endpoint URL: {endpoint_url}") |
| 73 | |
| 74 | url_parsed = urlparse(url) |
| 75 | endpoint_parsed = urlparse(endpoint_url) |
| 76 | if ( # pragma: no cover |
| 77 | url_parsed.netloc != endpoint_parsed.netloc |
| 78 | or url_parsed.scheme != endpoint_parsed.scheme |
| 79 | ): |
| 80 | error_msg = ( # pragma: no cover |
| 81 | f"Endpoint origin does not match connection origin: {endpoint_url}" |
| 82 | ) |
| 83 | logger.error(error_msg) # pragma: no cover |
| 84 | raise ValueError(error_msg) # pragma: no cover |
| 85 | |
| 86 | if on_session_created: |
| 87 | session_id = _extract_session_id_from_endpoint(endpoint_url) |
| 88 | if session_id: |