From 7ed05b5f9e0247af49f1aa24abfd77ec30b4b597 Mon Sep 17 00:00:00 2001 From: mukunda katta Date: Fri, 15 May 2026 13:34:43 -0700 Subject: [PATCH] feat: expose async stream cancel handle --- README.md | 15 ++++- src/xai_sdk/aio/chat.py | 139 +++++++++++++++++++++++++++++++++------- tests/aio/chat_test.py | 47 ++++++++++++++ 3 files changed, 176 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 7130a8e..2711a4d 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ if __name__ == "__main__": ### Streaming -The xAI SDK supports streaming responses, allowing you to process model outputs in real-time, which is ideal for interactive applications like chatbots. The `stream` method returns a tuple containing `response` and `chunk`. The chunks contain the text deltas from the stream, while the `response` variable automatically accumulates the response as the stream progresses. +The xAI SDK supports streaming responses, allowing you to process model outputs in real-time, which is ideal for interactive applications like chatbots. The `stream` method yields tuples containing `response` and `chunk`. The chunks contain the text deltas from the stream, while the `response` variable automatically accumulates the response as the stream progresses. ```python from xai_sdk import Client @@ -152,6 +152,17 @@ while True: chat.append(response) ``` +Async streams expose a handle that can stop the underlying request when you no longer need more chunks. + +```python +stream = chat.stream() +async for response, chunk in stream: + print(chunk.content, end="", flush=True) + if should_stop_early(chunk): + stream.cancel() + break +``` + ### Image Understanding You can easily interleave images and text together, making tasks like image understanding and analysis easy. @@ -585,4 +596,4 @@ The xAI SDK is distributed under the Apache-2.0 License ## Contributing -See the [documentation](/CONTRIBUTING.md) on contributing to this project. 
\ No newline at end of file +See the [documentation](/CONTRIBUTING.md) on contributing to this project. diff --git a/src/xai_sdk/aio/chat.py b/src/xai_sdk/aio/chat.py index 292fe3c..d46442b 100644 --- a/src/xai_sdk/aio/chat.py +++ b/src/xai_sdk/aio/chat.py @@ -154,7 +154,7 @@ async def sample_batch(self, n: int) -> Sequence[Response]: span.set_attributes(self._make_span_response_attributes(responses)) return responses - async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]: + def stream(self) -> "ChatStream": """Asynchronously streams a single chat completion response. This method streams the model's response in chunks, yielding each chunk as it is @@ -167,6 +167,8 @@ async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]: - `Response`: The accumulating response object, updated with each chunk. - `Chunk`: A `Chunk` object containing the content and metadata of the current chunk. + The returned stream also exposes `cancel()` and `aclose()` helpers for + stopping the underlying transport before it is exhausted. Example: >>> chat = client.chat.create(model="grok-4.20-non-reasoning") @@ -177,32 +179,22 @@ async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]: >>> print(response.content) "Once upon a time..." 
def _create_stream_state(self) -> tuple[Response, Optional[int]]:
    """Build the initial accumulator state for one streamed completion.

    Returns:
        A `(response, index)` pair. `response` is a `Response` wrapping a proto
        with a single empty output; `index` is `None` when
        `_uses_server_side_tools()` is true and `0` otherwise.
    """
    if self._uses_server_side_tools():
        output_index = None
    else:
        output_index = 0

    empty_proto = chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()])
    return Response(empty_proto, output_index), output_index


def _start_stream_span(self):
    """Return the tracer's context manager for a `chat.stream` client span.

    The span is only created here; entering and exiting the returned context
    manager is left to the caller.
    """
    make_span = tracer.start_as_current_span
    span_name = f"chat.stream {self._proto.model}"
    return make_span(
        name=span_name,
        kind=SpanKind.CLIENT,
        attributes=self._make_span_request_attributes(),
    )


def _open_stream(self):
    """Call `GetCompletionChunk` on the stub with a request for one completion."""
    get_chunks = self._stub.GetCompletionChunk
    request = self._make_request(1)
    return get_chunks(request)
class ChatStream(AsyncIterator[tuple[Response, Chunk]]):
    """Cancelable async iterator returned by `Chat.stream()`.

    Each iteration yields `(response, chunk)`, where `response` is the single
    accumulating `Response` for the whole stream and `chunk` wraps the chunk
    that was just received.

    The request and its tracing span are opened lazily, on the first
    `__anext__()` or on `__aenter__()`. The span is closed only when the stream
    is exhausted, fails, or is cancelled/closed, so callers that stop iterating
    early should call `cancel()` / `aclose()` or use `async with`.

    NOTE(review): the span context manager is entered in `_ensure_started()`
    and exited in `_finish()`, which may run from a different call frame or
    task; confirm the tracer tolerates that.
    """

    def __init__(self, chat: Chat):
        """Create a stream handle for a chat request without starting it."""
        self._chat = chat
        self._response, self._index = chat._create_stream_state()
        # Transport objects; populated by `_ensure_started()`.
        self._stream = None
        self._stream_iter = None
        # Tracing span and the context manager that owns it.
        self._span_cm = None
        self._span = None
        self._first_chunk_received = False
        # `_closed`: no more chunks will be yielded.
        # `_finished`: the span has been finalized (guards double exit).
        # `_cancelled`: `cancel()` was requested by the caller.
        self._closed = False
        self._finished = False
        self._cancelled = False

    def __aiter__(self) -> "ChatStream":
        """Return this handle; it is its own async iterator."""
        return self

    async def __aenter__(self) -> "ChatStream":
        """Start the stream (unless already closed) and return this handle."""
        self._ensure_started()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the stream when leaving the `async with` block."""
        await self.aclose()

    async def __anext__(self) -> tuple[Response, Chunk]:
        """Receive the next chunk and fold it into the accumulated response.

        Returns:
            The accumulating `Response` and a `Chunk` for the received chunk.

        Raises:
            StopAsyncIteration: If the stream is exhausted, cancelled or closed.
            RuntimeError: If the stream could not be initialized.
        """
        if self._closed:
            raise StopAsyncIteration

        self._ensure_started()
        stream_iter = self._stream_iter
        span = self._span
        if stream_iter is None or span is None:
            raise RuntimeError("Stream was not initialized.")

        try:
            chunk = await stream_iter.__anext__()
        except StopAsyncIteration:
            self._finish()
            raise
        except BaseException as exc:
            self._finish(type(exc), exc, exc.__traceback__)
            raise

        # A failure while processing the chunk must also end the span,
        # otherwise it stays open for the lifetime of this handle.
        try:
            if not self._first_chunk_received:
                span.set_attribute(
                    "gen_ai.completion.start_time",
                    datetime.datetime.now(datetime.timezone.utc).isoformat(),
                )
                self._first_chunk_received = True

            # Auto-detect if server added tools implicitly.
            self._index = self._chat._auto_detect_multi_output_mode(self._index, chunk.outputs)
            self._response._index = self._index
            self._response.process_chunk(chunk)
            return self._response, Chunk(chunk, self._index)
        except BaseException as exc:
            self._finish(type(exc), exc, exc.__traceback__)
            raise

    @property
    def cancelled(self) -> bool:
        """Whether `cancel()` has been requested for this stream."""
        return self._cancelled

    def cancel(self) -> bool:
        """Cancel the underlying gRPC stream if it has been started.

        Returns:
            `False` if the stream was already closed. `True` if it had not
            been started yet (it is then marked closed so it never starts).
            Otherwise the truthiness of the transport's own `cancel()` result.
        """
        if self._closed:
            return False

        self._cancelled = True
        if self._stream is None:
            self._closed = True
            self._finish()
            return True

        try:
            cancelled = self._stream.cancel()
        finally:
            # Close the handle and the span even if the transport's cancel()
            # raises, so the failure cannot leave the span open.
            self._closed = True
            self._finish()
        return bool(cancelled)

    async def aclose(self) -> None:
        """Close the iterator and cancel the transport if it is still active."""
        if not self._closed:
            self.cancel()

    def _ensure_started(self) -> None:
        """Open the span and the RPC once; no-op if started or already closed."""
        # Never start a closed handle: `_finish()` has already run, so a span
        # or RPC opened now could not be closed or cancelled afterwards.
        if self._closed or self._stream is not None:
            return

        span_cm = self._chat._start_stream_span()
        span = span_cm.__enter__()
        try:
            stream = self._chat._open_stream()
            stream_iter = stream.__aiter__()
        except BaseException as exc:
            # Opening failed after the span was entered: exit it here, since
            # nothing else holds a reference that could.
            span_cm.__exit__(type(exc), exc, exc.__traceback__)
            raise

        # Commit state only after everything succeeded so a failed start
        # leaves the handle exactly as it was.
        self._span_cm = span_cm
        self._span = span
        self._stream = stream
        self._stream_iter = stream_iter

    def _finish(self, exc_type=None, exc=None, tb=None) -> None:
        """Mark the stream closed and finalize the span exactly once.

        Response attributes are recorded on the span unless the stream was
        cancelled. The exception triple, if given, is passed to the span
        context manager's `__exit__`.
        """
        if self._finished:
            return

        self._finished = True
        self._closed = True
        try:
            if self._span is not None and not self._cancelled:
                self._span.set_attributes(self._chat._make_span_response_attributes([self._response]))
        finally:
            # Always exit the span, even if recording the attributes failed.
            if self._span_cm is not None:
                self._span_cm.__exit__(exc_type, exc, tb)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_cancel_cancels_underlying_rpc(client):
    """Cancelling a started stream cancels the fake RPC and ends iteration."""

    class FakeRpcCall:
        """Stand-in for the RPC call object: async-iterable and cancelable."""

        def __init__(self, chunks):
            self._pending = list(chunks)
            self.cancelled = False

        def __aiter__(self):
            return self

        async def __anext__(self):
            if not self.cancelled and self._pending:
                return self._pending.pop(0)
            raise StopAsyncIteration

        def cancel(self):
            self.cancelled = True
            return True

    chat = client.chat.create("grok-3-latest")
    chat.append(user("test message"))

    first_delta = chat_pb2.Delta(content="Hello, ", role=chat_pb2.ROLE_ASSISTANT)
    only_chunk = chat_pb2.GetChatCompletionChunk(
        outputs=[chat_pb2.CompletionOutputChunk(delta=first_delta, index=0)]
    )
    fake_rpc_stream = FakeRpcCall([only_chunk])

    with mock.patch.object(chat._stub, "GetCompletionChunk", return_value=fake_rpc_stream):
        stream = chat.stream()
        received = await stream.__anext__()
        chunk = received[1]

        assert chunk.content == "Hello, "
        assert stream.cancel() is True
        assert stream.cancelled is True
        assert fake_rpc_stream.cancelled is True

        # Once cancelled, the handle must report exhaustion.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()