Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion go/adk/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,14 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
if modelName == "" {
modelName = DefaultGeminiModel
}
return adkgemini.NewModel(ctx, modelName, &genai.ClientConfig{APIKey: apiKey})
httpClient, err := models.BuildHTTPClient(transportConfigFromBase(m.BaseModel, nil))
if err != nil {
return nil, fmt.Errorf("failed to build HTTP client for Gemini: %w", err)
}
return adkgemini.NewModel(ctx, modelName, &genai.ClientConfig{
APIKey: apiKey,
HTTPClient: httpClient,
Comment thread
supreme-gg-gg marked this conversation as resolved.
})

case *adk.GeminiVertexAI:
project := os.Getenv("GOOGLE_CLOUD_PROJECT")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._anthropic import KAgentAnthropicLlm
from ._bedrock import KAgentBedrockLlm
from ._embedding import KAgentEmbedding
from ._gemini import KAgentGeminiLlm
from ._ollama import KAgentOllamaLlm
from ._openai import AzureOpenAI, OpenAI
from ._sap_ai_core import KAgentSAPAICoreLlm
Expand All @@ -10,6 +11,7 @@
"AzureOpenAI",
"KAgentAnthropicLlm",
"KAgentBedrockLlm",
"KAgentGeminiLlm",
"KAgentOllamaLlm",
"KAgentEmbedding",
"KAgentSAPAICoreLlm",
Expand Down
23 changes: 20 additions & 3 deletions python/packages/kagent-adk/src/kagent/adk/models/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from anthropic import AsyncAnthropic
from google.adk.models.anthropic_llm import AnthropicLlm

from ._ssl import KAgentTLSMixin

logger = logging.getLogger(__name__)


class KAgentAnthropicLlm(AnthropicLlm):
"""Anthropic model with api_key_passthrough, custom base_url, and header support."""
class KAgentAnthropicLlm(KAgentTLSMixin, AnthropicLlm):
"""Anthropic model with api_key_passthrough, custom base_url, header, and TLS support."""

api_key_passthrough: Optional[bool] = None

Expand All @@ -27,8 +29,17 @@ class KAgentAnthropicLlm(AnthropicLlm):
def set_passthrough_key(self, token: str) -> None:
    """Forward the Bearer token from the incoming A2A request as the Anthropic API key."""
    self._api_key = token
    # Drop any cached clients so the next access rebuilds them with the new key.
    # NOTE(review): "_http_client" is popped defensively; no cached_property of
    # that name is visible here — confirm it exists elsewhere on this class.
    for cached_name in ("_anthropic_client", "_http_client"):
        self.__dict__.pop(cached_name, None)

def _create_http_client(self):
    """Build the httpx async client carrying the kagent TLS configuration.

    Delegates to ``KAgentTLSMixin._httpx_async_client_if_tls``.

    Returns:
        An ``httpx.AsyncClient`` configured from the TLS fields, or ``None``
        when no TLS configuration is present so the Anthropic SDK falls back
        to its default transport.
    """
    return self._httpx_async_client_if_tls()

@cached_property
def _anthropic_client(self) -> AsyncAnthropic:
Expand All @@ -40,4 +51,10 @@ def _anthropic_client(self) -> AsyncAnthropic:
kwargs["base_url"] = self.base_url
if self.extra_headers:
kwargs["default_headers"] = self.extra_headers

# Use the httpx.AsyncClient with SSL configuration if present
http_client = self._create_http_client()
if http_client is not None:
kwargs["http_client"] = http_client

return AsyncAnthropic(**kwargs)
33 changes: 30 additions & 3 deletions python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from google.adk.models.llm_response import LlmResponse
from google.genai import types

from ._ssl import KAgentTLSMixin

if TYPE_CHECKING:
from google.adk.models.llm_request import LlmRequest

Expand Down Expand Up @@ -54,12 +56,31 @@ def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int])
return sanitized


def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None):
def _get_bedrock_client(
    extra_headers: Optional[dict[str, str]] = None,
    tls_disable_verify: Optional[bool] = None,
    tls_ca_cert_path: Optional[str] = None,
    tls_disable_system_cas: Optional[bool] = None,
):
    """Create a boto3 ``bedrock-runtime`` client for the configured region.

    Region resolution: AWS_DEFAULT_REGION, then AWS_REGION, then "us-east-1".

    Args:
        extra_headers: Not supported by boto3; logged and ignored when set.
        tls_disable_verify: Disable server certificate verification entirely
            (takes precedence over ``tls_ca_cert_path``).
        tls_ca_cert_path: Path to a custom CA bundle, passed as boto3's
            ``verify`` argument.
        tls_disable_system_cas: Not supported by boto3; a warning is logged
            whenever it is set, since system CAs may still be trusted.

    Returns:
        A configured ``bedrock-runtime`` boto3 client.
    """
    region = os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION") or "us-east-1"
    kwargs: dict[str, Any] = {"region_name": region}

    if extra_headers:
        # boto3 doesn't support custom headers natively; log and ignore
        logger.warning("extra_headers are not supported for Bedrock models and will be ignored.")

    # TLS/SSL configuration via boto3 verify parameter
    if tls_disable_verify:
        kwargs["verify"] = False
    elif tls_ca_cert_path:
        kwargs["verify"] = tls_ca_cert_path

    # boto3 cannot exclude the system trust store, so this option can never be
    # fully honored — warn whether or not a custom bundle was also supplied
    # (previously the bundle-less case was silently ignored).
    if tls_disable_system_cas:
        if tls_ca_cert_path:
            logger.warning(
                "disable_system_cas is not fully supported by boto3 for Bedrock; "
                "using custom CA bundle only. System CAs may still be trusted."
            )
        else:
            logger.warning(
                "disable_system_cas is not supported by boto3 for Bedrock and will be ignored."
            )

    return boto3.client("bedrock-runtime", **kwargs)


Expand Down Expand Up @@ -187,7 +208,7 @@ def _stop_reason_to_finish_reason(stop_reason: str) -> types.FinishReason:
return types.FinishReason.STOP


class KAgentBedrockLlm(BaseLlm):
class KAgentBedrockLlm(KAgentTLSMixin, BaseLlm):
"""Bedrock model via the Converse API.

Supports all Bedrock-compatible models (Anthropic, Meta, Mistral, Amazon, etc.).
Expand All @@ -196,11 +217,17 @@ class KAgentBedrockLlm(BaseLlm):

extra_headers: Optional[dict[str, str]] = None
additional_model_request_fields: Optional[dict[str, Any]] = None

model_config = {"arbitrary_types_allowed": True}

@cached_property
def _client(self):
    """Lazily-constructed bedrock-runtime client.

    Forwards the model's header and TLS fields to ``_get_bedrock_client``.
    (The stale pre-refactor ``return _get_bedrock_client(self.extra_headers)``
    line left over from the diff is removed here.)
    """
    return _get_bedrock_client(
        extra_headers=self.extra_headers,
        tls_disable_verify=self.tls_disable_verify,
        tls_ca_cert_path=self.tls_ca_cert_path,
        tls_disable_system_cas=self.tls_disable_system_cas,
    )

@classmethod
def supported_models(cls) -> list[str]:
Expand Down
59 changes: 59 additions & 0 deletions python/packages/kagent-adk/src/kagent/adk/models/_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Gemini model wrapper with kagent transport configuration."""

from __future__ import annotations

import os
from functools import cached_property
from typing import Optional

from google.adk.models.google_llm import Gemini as GeminiLLM
from google.adk.utils._google_client_headers import get_tracking_headers
from google.genai import Client, types

from ._ssl import KAgentTLSMixin


def _merge_headers(extra_headers: Optional[dict[str, str]]) -> dict[str, str]:
    """Return the ADK tracking headers merged with caller-supplied headers.

    Entries in *extra_headers* override tracking headers on key collisions.
    """
    headers = get_tracking_headers()
    headers.update(extra_headers or {})
    return headers


class KAgentGeminiLlm(KAgentTLSMixin, GeminiLLM):
    """Gemini API model that applies kagent TLS and header settings."""

    # Extra HTTP headers merged on top of the ADK tracking headers.
    extra_headers: Optional[dict[str, str]] = None
    # Declared for config parity with the other kagent wrappers; not consumed
    # in this class. TODO(review): confirm intended passthrough wiring.
    api_key_passthrough: Optional[bool] = None

    model_config = {"arbitrary_types_allowed": True}

    def _http_options(self, *, api_version: str | None = None) -> types.HttpOptions:
        """Build HttpOptions carrying merged headers and optional TLS settings.

        When no kagent TLS config is present, no client args are set so the
        genai SDK uses its default transport.
        """
        verify = self._tls_verify()
        kwargs: dict[str, object] = {}
        if verify is not None:
            # NOTE(review): the async client args carry both "verify" and
            # "ssl" — presumably to cover the SDK's sync/async transports;
            # confirm both keys are required by google-genai.
            kwargs = {
                "client_args": {"verify": verify},
                "async_client_args": {"verify": verify, "ssl": verify},
            }
        return types.HttpOptions(
            headers=_merge_headers(self.extra_headers),
            retry_options=self.retry_options,
            base_url=self.base_url,
            api_version=api_version,
            **kwargs,
        )

    @cached_property
    def api_client(self) -> Client:
        """Gemini API client built with the kagent HTTP options.

        API key resolution: GOOGLE_API_KEY, then GEMINI_API_KEY.
        """
        return Client(
            api_key=os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"),
            http_options=self._http_options(),
        )

    @cached_property
    def _live_api_client(self) -> Client:
        """Live API client, pinned to the Live API version."""
        return Client(
            api_key=os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"),
            http_options=self._http_options(api_version=self._live_api_version),
        )
22 changes: 20 additions & 2 deletions python/packages/kagent-adk/src/kagent/adk/models/_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from ollama import AsyncClient
from ollama import Message as OllamaMessage

from ._ssl import KAgentTLSMixin

if TYPE_CHECKING:
from google.adk.models.llm_request import LlmRequest

Expand Down Expand Up @@ -131,7 +133,7 @@ def _convert_tools_to_ollama(tools: list[types.Tool]) -> list[ollama_sdk.Tool]:
return ollama_tools


class KAgentOllamaLlm(BaseLlm):
class KAgentOllamaLlm(KAgentTLSMixin, BaseLlm):
"""Ollama model via the native Ollama SDK.

All Ollama options (temperature, top_p, top_k, num_ctx, etc.) are forwarded
Expand All @@ -148,11 +150,19 @@ class KAgentOllamaLlm(BaseLlm):
type: Literal["ollama"] = "ollama"
ollama_options: Optional[dict[str, object]] = None
default_headers: Optional[dict[str, str]] = None
api_key_passthrough: Optional[bool] = None

@cached_property
def _client(self) -> AsyncClient:
    """Async Ollama client pointed at OLLAMA_API_BASE.

    Applies the configured default headers and, when present, the kagent TLS
    ``verify`` kwarg. (The stale pre-refactor
    ``return AsyncClient(host=host, headers=...)`` line left over from the
    diff is removed here.)
    """
    host = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434")
    kwargs: dict[str, object] = {
        "host": host,
        "headers": self.default_headers or {},
    }

    # Adds {"verify": ...} only when TLS config is present.
    kwargs.update(self._tls_httpx_kwargs())

    return AsyncClient(**kwargs)

@classmethod
def supported_models(cls) -> list[str]:
Expand Down Expand Up @@ -261,6 +271,10 @@ def create_ollama_llm(
model: str,
options: dict[str, object] | None,
extra_headers: dict[str, str],
tls_disable_verify: Optional[bool] = None,
tls_ca_cert_path: Optional[str] = None,
tls_disable_system_cas: Optional[bool] = None,
api_key_passthrough: Optional[bool] = None,
) -> KAgentOllamaLlm:
"""Build a KAgentOllamaLlm from Ollama options.

Expand All @@ -272,4 +286,8 @@ def create_ollama_llm(
model=model,
ollama_options=options or None,
default_headers=extra_headers or {},
tls_disable_verify=tls_disable_verify,
tls_ca_cert_path=tls_ca_cert_path,
tls_disable_system_cas=tls_disable_system_cas,
api_key_passthrough=api_key_passthrough,
)
40 changes: 3 additions & 37 deletions python/packages/kagent-adk/src/kagent/adk/models/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from openai.types.shared_params import FunctionDefinition, FunctionParameters
from pydantic import Field

from ._ssl import create_ssl_context
from ._ssl import KAgentTLSMixin
from ._token_source import GDCHTokenSource

if TYPE_CHECKING:
Expand Down Expand Up @@ -365,7 +365,7 @@ def _convert_openai_response_to_llm_response(response: ChatCompletion) -> LlmRes
return LlmResponse(content=content, usage_metadata=usage_metadata, finish_reason=finish_reason)


class BaseOpenAI(BaseLlm):
class BaseOpenAI(KAgentTLSMixin, BaseLlm):
"""Base class for OpenAI-compatible models."""

model: str
Expand All @@ -382,11 +382,6 @@ class BaseOpenAI(BaseLlm):
timeout: Optional[int] = None
top_p: Optional[float] = None

# TLS/SSL configuration fields
tls_disable_verify: Optional[bool] = None
tls_ca_cert_path: Optional[str] = None
tls_disable_system_cas: Optional[bool] = None

# API key passthrough: forward the Bearer token from incoming requests as the LLM API key
api_key_passthrough: Optional[bool] = None

Expand All @@ -403,20 +398,6 @@ def supported_models(cls) -> list[str]:
"""Returns a list of supported models in regex for LlmRegistry."""
return [r"gpt-.*", r"o1-.*"]

def _get_tls_config(self) -> tuple[bool, Optional[str], bool]:
"""Read TLS configuration from instance fields.

Returns:
Tuple of (disable_verify, ca_cert_path, disable_system_cas)
"""
# Read from instance fields only (config-based approach)
# Environment variables are no longer supported for TLS configuration
disable_verify = self.tls_disable_verify or False
ca_cert_path = self.tls_ca_cert_path
disable_system_cas = self.tls_disable_system_cas or False

return disable_verify, ca_cert_path, disable_system_cas

def _create_http_client(self) -> Optional[httpx.AsyncClient]:
"""Create HTTP client with custom SSL context using OpenAI SDK defaults.

Expand All @@ -427,22 +408,7 @@ def _create_http_client(self) -> Optional[httpx.AsyncClient]:
Returns:
DefaultAsyncHttpxClient with SSL configuration, or None if no TLS config
"""
disable_verify, ca_cert_path, disable_system_cas = self._get_tls_config()

# Only create custom http client if TLS configuration is present
if disable_verify or ca_cert_path or disable_system_cas:
ssl_context = create_ssl_context(
disable_verify=disable_verify,
ca_cert_path=ca_cert_path,
disable_system_cas=disable_system_cas,
)

# ssl_context is either False (verification disabled) or SSLContext
# Use DefaultAsyncHttpxClient to preserve OpenAI's defaults
return DefaultAsyncHttpxClient(verify=ssl_context)

# No TLS configuration, return None to use OpenAI SDK default
return None
return self._httpx_async_client_if_tls(DefaultAsyncHttpxClient)

@cached_property
def _client(self) -> AsyncOpenAI:
Expand Down
42 changes: 42 additions & 0 deletions python/packages/kagent-adk/src/kagent/adk/models/_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import ssl
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import httpx

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -243,3 +246,42 @@ def create_ssl_context(
) from e

return ctx


class KAgentTLSMixin:
    """Mixin for model wrappers that accept kagent TLS configuration.

    Exposes the three optional TLS fields from the kagent transport config
    and helpers that translate them into ``ssl`` / ``httpx`` settings.
    """

    # Disable server certificate verification entirely (insecure).
    tls_disable_verify: Optional[bool] = None
    # Path to a custom CA bundle to trust.
    tls_ca_cert_path: Optional[str] = None
    # Exclude the system CA store when building the SSL context.
    tls_disable_system_cas: Optional[bool] = None

    def _has_tls_config(self) -> bool:
        """Return True if the model has any TLS config."""
        return bool(self.tls_disable_verify or self.tls_ca_cert_path or self.tls_disable_system_cas)

    def _tls_verify(self) -> ssl.SSLContext | bool | None:
        """Return the ``verify`` value to hand to an HTTP client.

        Returns:
            ``None`` when no TLS config is present (callers should fall back
            to SDK defaults); otherwise the result of ``create_ssl_context``
            — an ``SSLContext``, or ``False`` when verification is disabled.
            (Fixed: the annotation previously omitted ``None`` even though
            the no-config path returns it.)
        """
        if not self._has_tls_config():
            return None
        return create_ssl_context(
            disable_verify=self.tls_disable_verify or False,
            ca_cert_path=self.tls_ca_cert_path,
            disable_system_cas=self.tls_disable_system_cas or False,
        )

    def _tls_httpx_kwargs(self) -> dict[str, object]:
        """Return httpx constructor kwargs: ``{"verify": ...}`` or ``{}``."""
        verify = self._tls_verify()
        if verify is None:
            return {}
        return {"verify": verify}

    def _httpx_async_client_if_tls(self, client_cls=httpx.AsyncClient, **kwargs) -> httpx.AsyncClient | None:
        """Build an async HTTP client carrying the TLS configuration.

        Args:
            client_cls: Client class to instantiate (defaults to
                ``httpx.AsyncClient``; callers may pass an SDK-flavored
                subclass such as OpenAI's ``DefaultAsyncHttpxClient``).
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A configured client, or ``None`` when no TLS config is present so
            the caller can use the SDK's default transport.
        """
        tls_kwargs = self._tls_httpx_kwargs()
        if not tls_kwargs:
            return None
        return client_cls(**tls_kwargs, **kwargs)
Loading
Loading