Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ if __name__ == "__main__":

### Streaming

The xAI SDK supports streaming responses, allowing you to process model outputs in real-time, which is ideal for interactive applications like chatbots. The `stream` method returns a tuple containing `response` and `chunk`. The chunks contain the text deltas from the stream, while the `response` variable automatically accumulates the response as the stream progresses.
The xAI SDK supports streaming responses, allowing you to process model outputs in real-time, which is ideal for interactive applications like chatbots. The `stream` method yields tuples containing `response` and `chunk`. The chunks contain the text deltas from the stream, while the `response` variable automatically accumulates the response as the stream progresses.

```python
from xai_sdk import Client
Expand All @@ -152,6 +152,17 @@ while True:
chat.append(response)
```

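The object returned by the async client's `chat.stream()` is also a handle for the underlying request: call `stream.cancel()` (or `await stream.aclose()`) to stop the request when you no longer need more chunks.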

```python
stream = chat.stream()
async for response, chunk in stream:
print(chunk.content, end="", flush=True)
if should_stop_early(chunk):
stream.cancel()
break
```

### Image Understanding

You can easily interleave images and text together, making tasks like image understanding and analysis easy.
Expand Down Expand Up @@ -585,4 +596,4 @@ The xAI SDK is distributed under the Apache-2.0 License

## Contributing

See the [documentation](/CONTRIBUTING.md) on contributing to this project.
See the [documentation](/CONTRIBUTING.md) on contributing to this project.
139 changes: 116 additions & 23 deletions src/xai_sdk/aio/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ async def sample_batch(self, n: int) -> Sequence[Response]:
span.set_attributes(self._make_span_response_attributes(responses))
return responses

async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]:
def stream(self) -> "ChatStream":
"""Asynchronously streams a single chat completion response.

This method streams the model's response in chunks, yielding each chunk as it is
Expand All @@ -167,6 +167,8 @@ async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]:
- `Response`: The accumulating response object, updated with each chunk.
- `Chunk`: A `Chunk` object containing the content and metadata of the
current chunk.
The returned stream also exposes `cancel()` and `aclose()` helpers for
stopping the underlying transport before it is exhausted.

Example:
>>> chat = client.chat.create(model="grok-4.20-non-reasoning")
Expand All @@ -177,32 +179,22 @@ async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]:
>>> print(response.content)
"Once upon a time..." (full accumulated response)
"""
first_chunk_received = False
with tracer.start_as_current_span(
return ChatStream(self)

def _create_stream_state(self) -> tuple[Response, Optional[int]]:
    """Build the initial accumulator for a single streamed completion.

    Returns:
        A `(response, index)` pair. `index` is `None` when the chat uses
        server-side tools and `0` otherwise; `response` is a `Response`
        wrapping a completion proto with one empty output, created with that
        same index.
    """
    if self._uses_server_side_tools():
        index = None
    else:
        index = 0

    empty_completion = chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()])
    return Response(empty_completion, index), index

def _start_stream_span(self):
return tracer.start_as_current_span(
name=f"chat.stream {self._proto.model}",
kind=SpanKind.CLIENT,
attributes=self._make_span_request_attributes(),
) as span:
index = None if self._uses_server_side_tools() else 0
response = Response(chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()]), index)
stream = self._stub.GetCompletionChunk(self._make_request(1))

async for chunk in stream:
if not first_chunk_received:
span.set_attribute(
"gen_ai.completion.start_time", datetime.datetime.now(datetime.timezone.utc).isoformat()
)
first_chunk_received = True

# Auto-detect if server added tools implicitly
index = self._auto_detect_multi_output_mode(index, chunk.outputs)
response._index = index

response.process_chunk(chunk)
chunk_obj = Chunk(chunk, index)
yield response, chunk_obj
)

span.set_attributes(self._make_span_response_attributes([response]))
def _open_stream(self):
    """Start the chunked completion RPC for a single response and return the call object."""
    get_completion_chunk = self._stub.GetCompletionChunk
    request = self._make_request(1)
    return get_completion_chunk(request)

async def stream_batch(self, n: int) -> AsyncIterator[tuple[Sequence[Response], Sequence[Chunk]]]:
"""Asynchronously streams multiple chat completion responses.
Expand Down Expand Up @@ -429,3 +421,104 @@ async def defer_batch(

response = await self._defer(n, timeout, interval)
return [Response(response, i) for i in range(n)]


class ChatStream(AsyncIterator[tuple[Response, Chunk]]):
    """Cancelable async iterator returned by `Chat.stream()`.

    The RPC and its tracing span are opened lazily: on the first `__anext__` call or
    on `async with` entry. Every iteration yields `(response, chunk)`, where `response`
    is the single accumulating `Response` and `chunk` wraps the newest message from the
    stream. The span is exited exactly once, whether the stream is exhausted, fails, or
    is stopped through `cancel()` / `aclose()`.
    """

    def __init__(self, chat: Chat):
        """Create a stream handle for a chat request.

        Args:
            chat: The chat whose request will be streamed. Nothing is opened here;
                the span and the RPC are started on first use.
        """
        self._chat = chat
        self._response, self._index = chat._create_stream_state()
        self._stream = None  # RPC call object, set once the stream is opened
        self._stream_iter = None  # async iterator obtained from the RPC call
        self._span_cm = None  # entered span context manager (needed to exit it later)
        self._span = None  # span object produced by entering the context manager
        self._first_chunk_received = False
        self._closed = False  # no more items will be produced
        self._finished = False  # span bookkeeping has already run
        self._cancelled = False  # cancel() was requested by the caller

    def __aiter__(self) -> "ChatStream":
        """Return self; the handle is its own async iterator."""
        return self

    async def __aenter__(self) -> "ChatStream":
        """Open the stream (unless already opened or closed) and return the handle."""
        self._ensure_started()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the stream when leaving the `async with` block."""
        await self.aclose()

    async def __anext__(self) -> tuple[Response, Chunk]:
        """Return the accumulated response together with the next chunk.

        Raises:
            StopAsyncIteration: If the handle is closed or the stream is exhausted.
            RuntimeError: If the stream could not be initialized.
            BaseException: Whatever the underlying stream raises is re-raised after
                the span has been finished with that exception.
        """
        if self._closed:
            raise StopAsyncIteration

        self._ensure_started()
        stream_iter = self._stream_iter
        span = self._span
        if stream_iter is None or span is None:
            raise RuntimeError("Stream was not initialized.")

        try:
            chunk = await stream_iter.__anext__()
        except StopAsyncIteration:
            self._finish()
            raise
        except BaseException as exc:
            # Includes task cancellation: the span must still be exited before propagating.
            self._finish(type(exc), exc, exc.__traceback__)
            raise

        if not self._first_chunk_received:
            span.set_attribute("gen_ai.completion.start_time", datetime.datetime.now(datetime.timezone.utc).isoformat())
            self._first_chunk_received = True

        # The output index may change once the chunk's outputs have been inspected.
        self._index = self._chat._auto_detect_multi_output_mode(self._index, chunk.outputs)
        self._response._index = self._index
        self._response.process_chunk(chunk)
        return self._response, Chunk(chunk, self._index)

    @property
    def cancelled(self) -> bool:
        """Whether `cancel()` has been requested for this stream."""
        return self._cancelled

    def cancel(self) -> bool:
        """Cancel the underlying gRPC stream if it has been started.

        Returns:
            `False` if the handle was already closed. `True` if the stream had not
            been started yet (it will now never start). Otherwise the truthiness of
            the underlying stream's own `cancel()` result.
        """
        if self._closed:
            return False

        self._cancelled = True
        try:
            if self._stream is None:
                return True
            return bool(self._stream.cancel())
        finally:
            # Close the handle and exit the span even if the transport's cancel() raises.
            self._closed = True
            self._finish()

    async def aclose(self) -> None:
        """Close the iterator and cancel the transport if it is still active."""
        if not self._closed:
            self.cancel()

    def _ensure_started(self) -> None:
        """Enter the tracing span and open the RPC on first use.

        Does nothing when the stream is already open or the handle is closed, so a
        handle cancelled before use never opens a span or request that nothing could
        later clean up. If opening the RPC fails, the span that was just entered is
        exited with the error before the exception propagates.
        """
        if self._closed or self._stream is not None:
            return

        span_cm = self._chat._start_stream_span()
        span = span_cm.__enter__()
        # Record the context manager only once it is entered, so _finish never
        # exits a span that was not entered.
        self._span_cm = span_cm
        self._span = span
        try:
            self._stream = self._chat._open_stream()
            self._stream_iter = self._stream.__aiter__()
        except BaseException as exc:
            self._finish(type(exc), exc, exc.__traceback__)
            raise

    def _finish(self, exc_type=None, exc=None, tb=None) -> None:
        """Mark the stream closed and exit the span; only the first call has any effect.

        Response attributes are recorded on the span unless the stream was cancelled.
        The exception triple, if any, is forwarded to the span context manager's
        `__exit__`.
        """
        if self._finished:
            return

        self._finished = True
        self._closed = True
        try:
            if self._span is not None and not self._cancelled:
                self._span.set_attributes(self._chat._make_span_response_attributes([self._response]))
        finally:
            # Always exit the span, even if recording the attributes failed.
            if self._span_cm is not None:
                self._span_cm.__exit__(exc_type, exc, tb)
47 changes: 47 additions & 0 deletions tests/aio/chat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,53 @@ async def test_streaming(client):
assert last_response.content == "Hello, this is a test response!"


@pytest.mark.asyncio(loop_scope="session")
async def test_stream_cancel_cancels_underlying_rpc(client):
    """Cancelling a started stream cancels the RPC call and ends iteration."""

    class CancelableAsyncStream:
        """Stand-in for the streaming RPC call that records whether it was cancelled."""

        def __init__(self, chunks):
            self._chunks = list(chunks)
            self.cancelled = False

        def __aiter__(self):
            return self

        async def __anext__(self):
            exhausted = not self._chunks
            if self.cancelled or exhausted:
                raise StopAsyncIteration
            return self._chunks.pop(0)

        def cancel(self):
            self.cancelled = True
            return True

    chat = client.chat.create("grok-3-latest")
    chat.append(user("test message"))

    # A single assistant delta is enough: the stream is cancelled right after it.
    first_delta = chat_pb2.Delta(content="Hello, ", role=chat_pb2.ROLE_ASSISTANT)
    first_output = chat_pb2.CompletionOutputChunk(delta=first_delta, index=0)
    fake_rpc_stream = CancelableAsyncStream([chat_pb2.GetChatCompletionChunk(outputs=[first_output])])

    with mock.patch.object(chat._stub, "GetCompletionChunk", return_value=fake_rpc_stream):
        stream = chat.stream()
        _, chunk = await stream.__anext__()

        assert chunk.content == "Hello, "
        assert stream.cancel() is True
        assert stream.cancelled is True
        assert fake_rpc_stream.cancelled is True

        # Once cancelled, the handle must behave like an exhausted iterator.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()


@pytest.mark.asyncio(loop_scope="session")
async def test_streaming_batch(client):
chat = client.chat.create("grok-3-latest")
Expand Down