Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 58 additions & 23 deletions examples/.test_scripts/run-with-patched-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import runpy
import sys
from collections.abc import Callable
from typing import Any, cast
from typing import Any, TypeVar, cast

import ai
from ai import models
Expand All @@ -40,6 +40,8 @@

PROTOCOLS = ("chat", "messages", "responses")

ModelT = TypeVar("ModelT", bound=ai.Model)


def _protocol_factory(
name: str | None,
Expand Down Expand Up @@ -107,36 +109,69 @@ def selected_protocol_for_provider(
return None

def selected_protocol_for_model(
model: Any,
model: ai.Model,
) -> ai.ProviderProtocol[Any] | None:
provider = getattr(model, "provider", None)
if provider is None:
return None
return selected_protocol_for_provider(provider)
return selected_protocol_for_provider(model.provider)

def with_selected_protocol(model: ModelT) -> ModelT:
protocol = selected_protocol_for_model(model)
if protocol is None:
return model
return model.with_protocol(protocol)

class PatchedContext:
def __init__(self, context: Any) -> None:
self._context = context
self._model = with_selected_protocol(context.model)

@property
def model(self) -> Any:
return self._model

@property
def messages(self) -> Any:
return self._context.messages

@property
def tools(self) -> Any:
return self._context.tools

@property
def output_type(self) -> Any:
return self._context.output_type

@property
def params(self) -> Any:
return self._context.params

def patched_get_model(*_args: Any, **_kwargs: Any) -> ai.Model:
model_id = args.model or (
_args[0] if _args else _kwargs.get("model_id")
)
model = original_get_model(model_id)
model.protocol = selected_protocol_for_model(model)
return model

def patched_stream(*args: Any, **kwargs: Any) -> Any:
model = (
args[0] if args else getattr(kwargs.get("context"), "model", None)
)
protocol = selected_protocol_for_model(model)
if protocol is not None:
kwargs["protocol"] = protocol
return original_stream(*args, **kwargs)
return with_selected_protocol(model)

async def patched_generate(*args: Any, **kwargs: Any) -> Any:
model = args[0] if args else kwargs.get("model")
protocol = selected_protocol_for_model(model)
if protocol is not None:
kwargs["protocol"] = protocol
return await original_generate(*args, **kwargs)
def patched_stream(*call_args: Any, **kwargs: Any) -> Any:
if call_args:
call_args = (
with_selected_protocol(call_args[0]),
*call_args[1:],
)
elif "model" in kwargs and kwargs["model"] is not None:
kwargs["model"] = with_selected_protocol(kwargs["model"])
elif kwargs.get("context") is not None:
kwargs["context"] = PatchedContext(kwargs["context"])
return original_stream(*call_args, **kwargs)

async def patched_generate(*call_args: Any, **kwargs: Any) -> Any:
if call_args:
call_args = (
with_selected_protocol(call_args[0]),
*call_args[1:],
)
elif "model" in kwargs and kwargs["model"] is not None:
kwargs["model"] = with_selected_protocol(kwargs["model"])
return await original_generate(*call_args, **kwargs)

class PatchedModel(_model.Model):
def __init__(
Expand Down
12 changes: 6 additions & 6 deletions examples/openai_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ async def main() -> None:
print(f"[SKIP] {provider.name} provider is not configured")
return

model = ai.Model("gpt-5.5", provider=provider)
model = ai.Model(
"gpt-5.5",
provider=provider,
protocol=OpenAIChatCompletionsProtocol(),
)

try:
async with ai.stream(
model,
messages,
protocol=OpenAIChatCompletionsProtocol(),
) as stream:
async with ai.stream(model, messages) as stream:
async for event in stream:
if isinstance(event, ai.events.TextDelta):
print(event.chunk, end="", flush=True)
Expand Down
21 changes: 6 additions & 15 deletions src/ai/models/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, AsyncIterator, Sequence

from ...providers import base as provider_base
from . import model as model_
from . import params as params_

Expand All @@ -43,15 +42,13 @@ class StreamRequest:
tools: Sequence[types.tools.Tool] | None = None
output_type: type[pydantic.BaseModel] | None = None
params: Any = None
protocol: provider_base.ProviderProtocol[Any] | None = None


@dataclasses.dataclass(frozen=True)
class GenerateRequest:
model: model_.Model
messages: list[types.messages.Message]
params: params_.GenerateParams
protocol: provider_base.ProviderProtocol[Any] | None = None


@runtime_checkable
Expand Down Expand Up @@ -82,7 +79,6 @@ async def _do_stream(
tools=request.tools,
output_type=request.output_type,
params=request.params,
protocol=request.protocol,
):
yield ev

Expand All @@ -93,7 +89,6 @@ async def _do_generate(
request.model,
request.messages,
request.params,
protocol=request.protocol,
)


Expand Down Expand Up @@ -382,7 +377,6 @@ def stream(
*,
context: StreamContext,
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[str]]: ...
@overload
Expand All @@ -391,7 +385,6 @@ def stream[T: pydantic.BaseModel](
context: StreamContext,
output_type: type[T],
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[T]]: ...
@overload
Expand All @@ -401,7 +394,6 @@ def stream(
*,
tools: Sequence[types.tools.Tool] | None = None,
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[str]]: ...
@overload
Expand All @@ -412,7 +404,6 @@ def stream[T: pydantic.BaseModel](
tools: Sequence[types.tools.Tool] | None = None,
output_type: type[T],
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[T]]: ...
def stream(
Expand All @@ -423,7 +414,6 @@ def stream(
tools: Sequence[types.tools.Tool] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[Any]]:
"""Stream an LLM response.
Expand Down Expand Up @@ -463,7 +453,6 @@ def stream(
tools=tools,
output_type=output_type,
params=params,
protocol=protocol,
executor=executor,
)

Expand All @@ -476,7 +465,6 @@ async def _stream(
tools: Sequence[types.tools.Tool] | None,
output_type: type[pydantic.BaseModel] | None,
params: Any,
protocol: provider_base.ProviderProtocol[Any] | None,
executor: StreamExecutor,
) -> AsyncIterator[Stream[Any]]:
if messages and messages[-1].replay:
Expand All @@ -489,7 +477,11 @@ async def _stream(
else:
prepared = integrity.prepare_messages(messages)
request = StreamRequest(
model, prepared, tools, output_type, params, protocol
model=model,
messages=prepared,
tools=tools,
output_type=output_type,
params=params,
)
s = Stream(
executor._do_stream(request),
Expand All @@ -506,12 +498,11 @@ async def generate(
messages: list[types.messages.Message],
params: params_.GenerateParams,
*,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: GenerateExecutor = _default_executor,
) -> types.messages.Message:
"""Generate a non-streaming response (images, video, etc.)."""
messages = integrity.prepare_messages(messages)
request = GenerateRequest(model, messages, params, protocol)
request = GenerateRequest(model, messages, params)
return await executor._do_generate(request)


Expand Down
9 changes: 8 additions & 1 deletion src/ai/models/core/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Model metadata types."""

import os
from typing import Any
from typing import Any, Self

from ... import _modelsdev
from ...errors import ConfigurationError
Expand Down Expand Up @@ -43,6 +43,13 @@ def __repr__(self) -> str:
def __hash__(self) -> int:
return hash((self.id, id(self.provider), id(self.protocol)))

def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self:
return self.__class__(
id=self.id,
provider=self.provider,
protocol=protocol,
)


def get_model(
model_id: str | None = None,
Expand Down
8 changes: 1 addition & 7 deletions src/ai/providers/ai_gateway/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def stream(
tools: Sequence[tools_.Tool] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
params: Any = None,
protocol: base.ProviderProtocol[Any] | None = None,
) -> AsyncGenerator[events.Event]:
"""Stream via the AI Gateway v3 protocol."""
return super().stream(
Expand All @@ -94,21 +93,16 @@ def stream(
tools=tools,
output_type=output_type,
params=params,
protocol=protocol,
)

async def generate(
self,
model: model_.Model,
messages: list[messages_.Message],
params: params_.GenerateParams,
*,
protocol: base.ProviderProtocol[Any] | None = None,
) -> messages_.Message:
"""Generate media via the AI Gateway v3 protocol."""
return await super().generate(
model, messages, params, protocol=protocol
)
return await super().generate(model, messages, params)

@classmethod
def from_modelsdev_provider(
Expand Down
2 changes: 0 additions & 2 deletions src/ai/providers/anthropic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def stream(
tools: Sequence[tools_.Tool] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
params: Any = None,
protocol: base.ProviderProtocol[Any] | None = None,
) -> AsyncGenerator[events.Event]:
"""Stream via the Anthropic messages protocol."""
return super().stream(
Expand All @@ -146,7 +145,6 @@ def stream(
tools=tools,
output_type=output_type,
params=params,
protocol=protocol,
)

@classmethod
Expand Down
7 changes: 2 additions & 5 deletions src/ai/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,9 @@ def stream(
tools: Sequence[tools_.Tool] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
params: Any = None,
protocol: ProviderProtocol[Any] | None = None,
) -> AsyncGenerator[events.Event]:
"""Stream a language-model response from this provider."""
selected_protocol = protocol or model.protocol or self.protocol
selected_protocol = model.protocol or self.protocol
return selected_protocol.stream(
self.client,
model,
Expand All @@ -232,11 +231,9 @@ async def generate(
model: model_.Model,
messages: list[messages_.Message],
params: params_.GenerateParams,
*,
protocol: ProviderProtocol[Any] | None = None,
) -> messages_.Message:
"""Generate a non-streaming response from this provider."""
selected_protocol = protocol or model.protocol or self.protocol
selected_protocol = model.protocol or self.protocol
return await selected_protocol.generate(
self.client,
model,
Expand Down
2 changes: 0 additions & 2 deletions src/ai/providers/openai/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def stream(
tools: Sequence[tools_.Tool] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
params: Any = None,
protocol: base.ProviderProtocol[Any] | None = None,
) -> AsyncGenerator[events.Event]:
"""Stream via this provider's configured OpenAI-compatible protocol."""
return super().stream(
Expand All @@ -144,7 +143,6 @@ def stream(
tools=tools,
output_type=output_type,
params=params,
protocol=protocol,
)

@classmethod
Expand Down
11 changes: 4 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,9 @@ def stream(
tools: Sequence[ai.tools.Tool] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
params: Any = None,
protocol: models.ProviderProtocol[Any] | None = None,
) -> AsyncGenerator[events_.Event]:
if protocol is not None:
return protocol.stream(
if model.protocol is not None:
return model.protocol.stream(
None,
model,
messages,
Expand Down Expand Up @@ -78,11 +77,9 @@ async def generate(
model: models.Model,
messages: list[messages_.Message],
params: Any,
*,
protocol: models.ProviderProtocol[Any] | None = None,
) -> messages_.Message:
if protocol is not None:
return await protocol.generate(
if model.protocol is not None:
return await model.protocol.generate(
None,
model,
messages,
Expand Down
Loading
Loading