| 122 | |
| 123 | @asynccontextmanager |
| 124 | async def connect_sse(self, scope: Scope, receive: Receive, send: Send): |
| 125 | if scope["type"] != "http": |
| 126 | logger.error("connect_sse received non-HTTP request") |
| 127 | raise ValueError("connect_sse can only handle HTTP requests") |
| 128 | |
| 129 | # Validate request headers for DNS rebinding protection |
| 130 | request = Request(scope, receive) |
| 131 | error_response = await self._security.validate_request(request, is_post=False) |
| 132 | if error_response: |
| 133 | await error_response(scope, receive, send) |
| 134 | raise ValueError("Request validation failed") |
| 135 | |
| 136 | logger.debug("Setting up SSE connection") |
| 137 | |
| 138 | read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) |
| 139 | write_stream, write_stream_reader = create_context_streams[SessionMessage](0) |
| 140 | |
| 141 | session_id = uuid4() |
| 142 | user = scope.get("user") |
| 143 | if isinstance(user, AuthenticatedUser): |
| 144 | self._session_owners[session_id] = authorization_context(user) |
| 145 | self._read_stream_writers[session_id] = read_stream_writer |
| 146 | logger.debug(f"Created new session with ID: {session_id}") |
| 147 | |
| 148 | # Determine the full path for the message endpoint to be sent to the client. |
| 149 | # scope['root_path'] is the prefix where the current Starlette app |
| 150 | # instance is mounted. |
| 151 | # e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix". |
| 152 | root_path = scope.get("root_path", "") |
| 153 | |
| 154 | # self._endpoint is the path *within* this app, e.g., "/messages". |
| 155 | # Concatenating them gives the full absolute path from the server root. |
| 156 | # e.g., "" + "/messages" -> "/messages" |
| 157 | # e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages" |
| 158 | full_message_path_for_client = root_path.rstrip("/") + self._endpoint |
| 159 | |
| 160 | # This is the URI (path + query) the client will use to POST messages. |
| 161 | client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" |
| 162 | |
| 163 | sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) |
| 164 | |
| 165 | async def sse_writer(): # Relay outbound traffic to the SSE stream: the endpoint event first, then every outgoing session message. |
| 166 | logger.debug("Starting SSE writer") |
| 167 | async with sse_stream_writer, write_stream_reader: # Both streams are closed when this block exits, normally or via exception. |
| 168 | await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data}) # Tell the client which URI (path + session_id query) to POST its messages to. |
| 169 | logger.debug(f"Sent endpoint event: {client_post_uri_data}") |
| 170 | |
| 171 | async for session_message in write_stream_reader: # Forward messages until write_stream_reader is exhausted. |
| 172 | logger.debug(f"Sending message via SSE: {session_message}") # NOTE(review): the f-string is formatted even when DEBUG logging is disabled. |
| 173 | await sse_stream_writer.send( |
| 174 | { |
| 175 | "event": "message", |
| 176 | "data": session_message.message.model_dump_json(by_alias=True, exclude_unset=True), # Aliased field names; fields never set are omitted. NOTE(review): fields explicitly set to None are still emitted as null — confirm this is intended vs exclude_none. |
| 177 | } |
| 178 | ) |
| 179 | |
| 180 | try: |
| 181 | async with anyio.create_task_group() as tg: |