diff --git a/examples/.test_scripts/run-with-patched-model.py b/examples/.test_scripts/run-with-patched-model.py index a3c7adb..253977a 100644 --- a/examples/.test_scripts/run-with-patched-model.py +++ b/examples/.test_scripts/run-with-patched-model.py @@ -21,7 +21,7 @@ import runpy import sys from collections.abc import Callable -from typing import Any, cast +from typing import Any, TypeVar, cast import ai from ai import models @@ -40,6 +40,8 @@ PROTOCOLS = ("chat", "messages", "responses") +ModelT = TypeVar("ModelT", bound=ai.Model) + def _protocol_factory( name: str | None, @@ -107,36 +109,69 @@ def selected_protocol_for_provider( return None def selected_protocol_for_model( - model: Any, + model: ai.Model, ) -> ai.ProviderProtocol[Any] | None: - provider = getattr(model, "provider", None) - if provider is None: - return None - return selected_protocol_for_provider(provider) + return selected_protocol_for_provider(model.provider) + + def with_selected_protocol(model: ModelT) -> ModelT: + protocol = selected_protocol_for_model(model) + if protocol is None: + return model + return model.with_protocol(protocol) + + class PatchedContext: + def __init__(self, context: Any) -> None: + self._context = context + self._model = with_selected_protocol(context.model) + + @property + def model(self) -> Any: + return self._model + + @property + def messages(self) -> Any: + return self._context.messages + + @property + def tools(self) -> Any: + return self._context.tools + + @property + def output_type(self) -> Any: + return self._context.output_type + + @property + def params(self) -> Any: + return self._context.params def patched_get_model(*_args: Any, **_kwargs: Any) -> ai.Model: model_id = args.model or ( _args[0] if _args else _kwargs.get("model_id") ) model = original_get_model(model_id) - model.protocol = selected_protocol_for_model(model) - return model - - def patched_stream(*args: Any, **kwargs: Any) -> Any: - model = ( - args[0] if args else getattr(kwargs.get("context"), "model", None) - ) - protocol = selected_protocol_for_model(model) - if protocol is not None: - kwargs["protocol"] = protocol - return original_stream(*args, **kwargs) + return with_selected_protocol(model) - async def patched_generate(*args: Any, **kwargs: Any) -> Any: - model = args[0] if args else kwargs.get("model") - protocol = selected_protocol_for_model(model) - if protocol is not None: - kwargs["protocol"] = protocol - return await original_generate(*args, **kwargs) + def patched_stream(*call_args: Any, **kwargs: Any) -> Any: + if call_args: + call_args = ( + with_selected_protocol(call_args[0]), + *call_args[1:], + ) + elif "model" in kwargs and kwargs["model"] is not None: + kwargs["model"] = with_selected_protocol(kwargs["model"]) + elif kwargs.get("context") is not None: + kwargs["context"] = PatchedContext(kwargs["context"]) + return original_stream(*call_args, **kwargs) + + async def patched_generate(*call_args: Any, **kwargs: Any) -> Any: + if call_args: + call_args = ( + with_selected_protocol(call_args[0]), + *call_args[1:], + ) + elif "model" in kwargs and kwargs["model"] is not None: + kwargs["model"] = with_selected_protocol(kwargs["model"]) + return await original_generate(*call_args, **kwargs) class PatchedModel(_model.Model): def __init__( diff --git a/examples/openai_chat_completions.py b/examples/openai_chat_completions.py index 770cffc..5a90972 100644 --- a/examples/openai_chat_completions.py +++ b/examples/openai_chat_completions.py @@ -19,14 +19,14 @@ async def main() -> None: print(f"[SKIP] {provider.name} provider is not configured") return - model = ai.Model("gpt-5.5", provider=provider) + model = ai.Model( + "gpt-5.5", + provider=provider, + protocol=OpenAIChatCompletionsProtocol(), + ) try: - async with ai.stream( - model, - messages, - protocol=OpenAIChatCompletionsProtocol(), - ) as stream: + async with ai.stream(model, messages) as stream: async for event in stream: if isinstance(event, ai.events.TextDelta): print(event.chunk, end="", flush=True) diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 94c23c7..c95e906 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -26,7 +26,6 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterator, Sequence - from ...providers import base as provider_base from . import model as model_ from . import params as params_ @@ -43,7 +42,6 @@ class StreamRequest: tools: Sequence[types.tools.Tool] | None = None output_type: type[pydantic.BaseModel] | None = None params: Any = None - protocol: provider_base.ProviderProtocol[Any] | None = None @dataclasses.dataclass(frozen=True) @@ -51,7 +49,6 @@ class GenerateRequest: model: model_.Model messages: list[types.messages.Message] params: params_.GenerateParams - protocol: provider_base.ProviderProtocol[Any] | None = None @runtime_checkable @@ -82,7 +79,6 @@ async def _do_stream( tools=request.tools, output_type=request.output_type, params=request.params, - protocol=request.protocol, ): yield ev @@ -93,7 +89,6 @@ async def _do_generate( request.model, request.messages, request.params, - protocol=request.protocol, ) @@ -382,7 +377,6 @@ def stream( *, context: StreamContext, params: Any = None, - protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[str]]: ... @overload @@ -391,7 +385,6 @@ def stream[T: pydantic.BaseModel]( context: StreamContext, output_type: type[T], params: Any = None, - protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[T]]: ... @overload @@ -401,7 +394,6 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, params: Any = None, - protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[str]]: ... @overload @@ -412,7 +404,6 @@ def stream[T: pydantic.BaseModel]( tools: Sequence[types.tools.Tool] | None = None, output_type: type[T], params: Any = None, - protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[T]]: ... def stream( @@ -423,7 +414,6 @@ def stream( tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, - protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[Any]]: """Stream an LLM response. @@ -463,7 +453,6 @@ def stream( tools=tools, output_type=output_type, params=params, - protocol=protocol, executor=executor, ) @@ -476,7 +465,6 @@ async def _stream( tools: Sequence[types.tools.Tool] | None, output_type: type[pydantic.BaseModel] | None, params: Any, - protocol: provider_base.ProviderProtocol[Any] | None, executor: StreamExecutor, ) -> AsyncIterator[Stream[Any]]: if messages and messages[-1].replay: @@ -489,7 +477,11 @@ async def _stream( else: prepared = integrity.prepare_messages(messages) request = StreamRequest( - model, prepared, tools, output_type, params, protocol + model=model, + messages=prepared, + tools=tools, + output_type=output_type, + params=params, ) s = Stream( executor._do_stream(request), @@ -506,12 +498,11 @@ async def generate( messages: list[types.messages.Message], params: params_.GenerateParams, *, - protocol: provider_base.ProviderProtocol[Any] | None = None, executor: GenerateExecutor = _default_executor, ) -> types.messages.Message: """Generate a non-streaming response (images, video, etc.).""" messages = integrity.prepare_messages(messages) - request = GenerateRequest(model, messages, params, protocol) + request = GenerateRequest(model, messages, params) return await executor._do_generate(request) diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index 9ab4bf2..5ff6fd5 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -1,7 +1,7 @@ """Model metadata types.""" import os -from typing import Any +from typing import Any, Self from ... import _modelsdev from ...errors import ConfigurationError @@ -43,6 +43,13 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash((self.id, id(self.provider), id(self.protocol))) + def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self: + return self.__class__( + id=self.id, + provider=self.provider, + protocol=protocol, + ) + def get_model( model_id: str | None = None, diff --git a/src/ai/providers/ai_gateway/provider.py b/src/ai/providers/ai_gateway/provider.py index e61c0d4..71f5b25 100644 --- a/src/ai/providers/ai_gateway/provider.py +++ b/src/ai/providers/ai_gateway/provider.py @@ -85,7 +85,6 @@ def stream( tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, - protocol: base.ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events.Event]: """Stream via the AI Gateway v3 protocol.""" return super().stream( @@ -94,7 +93,6 @@ def stream( tools=tools, output_type=output_type, params=params, - protocol=protocol, ) async def generate( @@ -102,13 +100,9 @@ async def generate( model: model_.Model, messages: list[messages_.Message], params: params_.GenerateParams, - *, - protocol: base.ProviderProtocol[Any] | None = None, ) -> messages_.Message: """Generate media via the AI Gateway v3 protocol.""" - return await super().generate( - model, messages, params, protocol=protocol - ) + return await super().generate(model, messages, params) @classmethod def from_modelsdev_provider( diff --git a/src/ai/providers/anthropic/provider.py b/src/ai/providers/anthropic/provider.py index be7f5cb..87042a4 100644 --- a/src/ai/providers/anthropic/provider.py +++ b/src/ai/providers/anthropic/provider.py @@ -137,7 +137,6 @@ def stream( tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, - protocol: base.ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events.Event]: """Stream via the Anthropic messages protocol.""" return super().stream( @@ -146,7 +145,6 @@ def stream( tools=tools, output_type=output_type, params=params, - protocol=protocol, ) @classmethod diff --git a/src/ai/providers/base.py b/src/ai/providers/base.py index f59299c..91fd33b 100644 --- a/src/ai/providers/base.py +++ b/src/ai/providers/base.py @@ -213,10 +213,9 @@ def stream( tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, - protocol: ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events.Event]: """Stream a language-model response from this provider.""" - selected_protocol = protocol or model.protocol or self.protocol + selected_protocol = model.protocol or self.protocol return selected_protocol.stream( self.client, model, @@ -232,11 +231,9 @@ async def generate( model: model_.Model, messages: list[messages_.Message], params: params_.GenerateParams, - *, - protocol: ProviderProtocol[Any] | None = None, ) -> messages_.Message: """Generate a non-streaming response from this provider.""" - selected_protocol = protocol or model.protocol or self.protocol + selected_protocol = model.protocol or self.protocol return await selected_protocol.generate( self.client, model, diff --git a/src/ai/providers/openai/provider.py b/src/ai/providers/openai/provider.py index 51fb3c9..92b9304 100644 --- a/src/ai/providers/openai/provider.py +++ b/src/ai/providers/openai/provider.py @@ -135,7 +135,6 @@ def stream( tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, - protocol: base.ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events.Event]: """Stream via this provider's configured OpenAI-compatible protocol.""" return super().stream( @@ -144,7 +143,6 @@ def stream( tools=tools, output_type=output_type, params=params, - protocol=protocol, ) @classmethod diff --git a/tests/conftest.py b/tests/conftest.py index 7fc2ffd..92ef235 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,10 +46,9 @@ def stream( tools: Sequence[ai.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, - protocol: models.ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events_.Event]: - if protocol is not None: - return protocol.stream( + if model.protocol is not None: + return model.protocol.stream( None, model, messages, @@ -78,11 +77,9 @@ async def generate( model: models.Model, messages: list[messages_.Message], params: Any, - *, - protocol: models.ProviderProtocol[Any] | None = None, ) -> messages_.Message: - if protocol is not None: - return await protocol.generate( + if model.protocol is not None: + return await model.protocol.generate( None, model, messages, diff --git a/tests/models/core/test_api.py b/tests/models/core/test_api.py index 5858597..b4a0d97 100644 --- a/tests/models/core/test_api.py +++ b/tests/models/core/test_api.py @@ -247,7 +247,7 @@ async def test_stream_requires_model_messages_or_context() -> None: pass -async def test_stream_accepts_protocol_kwarg() -> None: +async def test_stream_uses_model_protocol() -> None: class OverrideProtocol(models.ProviderProtocol[Any]): def stream( self, @@ -272,9 +272,8 @@ async def _stream() -> AsyncGenerator[events_.Event]: return _stream() async with models.stream( - MOCK_MODEL, + MOCK_MODEL.with_protocol(OverrideProtocol()), [ai.user_message("Hi")], - protocol=OverrideProtocol(), ) as stream: async for _ in stream: pass @@ -315,7 +314,7 @@ async def _generate( assert result is sentinel -async def test_generate_accepts_protocol_kwarg() -> None: +async def test_generate_uses_model_protocol() -> None: sentinel = messages_.Message( role="assistant", parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], @@ -335,10 +334,9 @@ async def generate( return sentinel result = await models.generate( - MOCK_MODEL, + MOCK_MODEL.with_protocol(OverrideProtocol()), [ai.user_message("A cat")], models.ImageParams(n=1), - protocol=OverrideProtocol(), ) assert result is sentinel