From 5616026efce098f1f66b3b3e93da87aed17d4626 Mon Sep 17 00:00:00 2001 From: Jet Chiang Date: Mon, 4 May 2026 18:57:39 -0400 Subject: [PATCH 1/3] go changes Signed-off-by: Jet Chiang --- go/adk/pkg/agent/agent.go | 20 ++++++++++++++------ go/adk/pkg/models/sapaicore.go | 10 +++++++++- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index 1a7953f99..74eff615a 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -216,7 +216,14 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM, if modelName == "" { modelName = DefaultGeminiModel } - return adkgemini.NewModel(ctx, modelName, &genai.ClientConfig{APIKey: apiKey}) + httpClient, err := models.BuildHTTPClient(transportConfigFromBase(m.BaseModel, nil)) + if err != nil { + return nil, fmt.Errorf("failed to build HTTP client for Gemini: %w", err) + } + return adkgemini.NewModel(ctx, modelName, &genai.ClientConfig{ + APIKey: apiKey, + HTTPClient: httpClient, + }) case *adk.GeminiVertexAI: project := os.Getenv("GOOGLE_CLOUD_PROJECT") @@ -315,11 +322,12 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM, case *adk.SAPAICore: cfg := models.SAPAICoreConfig{ - Model: m.Model, - BaseUrl: m.BaseUrl, - ResourceGroup: m.ResourceGroup, - AuthUrl: m.AuthUrl, - Headers: extractHeaders(m.Headers), + TransportConfig: transportConfigFromBase(m.BaseModel, nil), + Model: m.Model, + BaseUrl: m.BaseUrl, + ResourceGroup: m.ResourceGroup, + AuthUrl: m.AuthUrl, + Headers: extractHeaders(m.Headers), } return models.NewSAPAICoreModelWithLogger(cfg, log) diff --git a/go/adk/pkg/models/sapaicore.go b/go/adk/pkg/models/sapaicore.go index 380e4f6e7..35290a163 100644 --- a/go/adk/pkg/models/sapaicore.go +++ b/go/adk/pkg/models/sapaicore.go @@ -15,6 +15,7 @@ import ( ) type SAPAICoreConfig struct { + TransportConfig Model string BaseUrl string ResourceGroup string @@ -41,10 +42,17 @@ func NewSAPAICoreModelWithLogger(config SAPAICoreConfig, logger logr.Logger) (*S if config.ResourceGroup == "" { config.ResourceGroup = "default" } + + // Build HTTP client with TLS, custom headers, and timeout support + httpClient, err := BuildHTTPClient(config.TransportConfig) + if err != nil { + return nil, fmt.Errorf("failed to create SAP AI Core HTTP client: %w", err) + } + return &SAPAICoreModel{ Config: config, Logger: logger, - httpClient: &http.Client{Timeout: 5 * time.Minute}, + httpClient: httpClient, }, nil } From cf1745af04c6548260694044ac4165ab7d1b1332 Mon Sep 17 00:00:00 2001 From: Jet Chiang Date: Mon, 4 May 2026 19:01:07 -0400 Subject: [PATCH 2/3] python model changes Signed-off-by: Jet Chiang --- .../src/kagent/adk/models/__init__.py | 2 + .../src/kagent/adk/models/_anthropic.py | 23 +++++- .../src/kagent/adk/models/_bedrock.py | 33 +++++++- .../src/kagent/adk/models/_gemini.py | 59 +++++++++++++++ .../src/kagent/adk/models/_ollama.py | 19 ++++- .../src/kagent/adk/models/_openai.py | 40 +--------- .../src/kagent/adk/models/_sap_ai_core.py | 75 ++++++++----------- .../kagent-adk/src/kagent/adk/models/_ssl.py | 42 +++++++++++ .../kagent-adk/src/kagent/adk/types.py | 46 ++++++++---- .../tests/unittests/models/test_openai.py | 10 +-- .../unittests/models/test_tls_integration.py | 24 +++++- 11 files changed, 262 insertions(+), 111 deletions(-) create mode 100644 python/packages/kagent-adk/src/kagent/adk/models/_gemini.py diff --git a/python/packages/kagent-adk/src/kagent/adk/models/__init__.py b/python/packages/kagent-adk/src/kagent/adk/models/__init__.py index 27e1dc246..3176704ce 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/__init__.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/__init__.py @@ -1,6 +1,7 @@ from ._anthropic import KAgentAnthropicLlm from ._bedrock import KAgentBedrockLlm from ._embedding import KAgentEmbedding +from ._gemini import KAgentGeminiLlm from ._ollama import KAgentOllamaLlm from ._openai import AzureOpenAI, OpenAI from ._sap_ai_core import KAgentSAPAICoreLlm @@ -10,6 +11,7 @@ "AzureOpenAI", "KAgentAnthropicLlm", "KAgentBedrockLlm", + "KAgentGeminiLlm", "KAgentOllamaLlm", "KAgentEmbedding", "KAgentSAPAICoreLlm", diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_anthropic.py b/python/packages/kagent-adk/src/kagent/adk/models/_anthropic.py index 4a5f51196..b8e9e68cd 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_anthropic.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_anthropic.py @@ -10,11 +10,13 @@ from anthropic import AsyncAnthropic from google.adk.models.anthropic_llm import AnthropicLlm +from ._ssl import KAgentTLSMixin + logger = logging.getLogger(__name__) -class KAgentAnthropicLlm(AnthropicLlm): - """Anthropic model with api_key_passthrough, custom base_url, and header support.""" +class KAgentAnthropicLlm(KAgentTLSMixin, AnthropicLlm): + """Anthropic model with api_key_passthrough, custom base_url, header, and TLS support.""" api_key_passthrough: Optional[bool] = None @@ -27,8 +29,17 @@ class KAgentAnthropicLlm(AnthropicLlm): def set_passthrough_key(self, token: str) -> None: """Forward the Bearer token from the incoming A2A request as the Anthropic API key.""" self._api_key = token - # Invalidate cached client so it's recreated with the new key + # Invalidate cached clients so they're recreated with the new key self.__dict__.pop("_anthropic_client", None) + self.__dict__.pop("_http_client", None) + + def _create_http_client(self): + """Create HTTP client with custom SSL context using Anthropic SDK defaults. + + Returns: + httpx.AsyncClient with SSL configuration, or None if no TLS config + """ + return self._httpx_async_client_if_tls() @cached_property def _anthropic_client(self) -> AsyncAnthropic: @@ -40,4 +51,10 @@ def _anthropic_client(self) -> AsyncAnthropic: kwargs["base_url"] = self.base_url if self.extra_headers: kwargs["default_headers"] = self.extra_headers + + # Use the httpx.AsyncClient with SSL configuration if present + http_client = self._create_http_client() + if http_client is not None: + kwargs["http_client"] = http_client + return AsyncAnthropic(**kwargs) diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py index e1ebdbfb6..f7615b5fd 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py @@ -20,6 +20,8 @@ from google.adk.models.llm_response import LlmResponse from google.genai import types +from ._ssl import KAgentTLSMixin + if TYPE_CHECKING: from google.adk.models.llm_request import LlmRequest @@ -54,12 +56,31 @@ def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int]) return sanitized -def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None): +def _get_bedrock_client( + extra_headers: Optional[dict[str, str]] = None, + tls_disable_verify: Optional[bool] = None, + tls_ca_cert_path: Optional[str] = None, + tls_disable_system_cas: Optional[bool] = None, +): region = os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION") or "us-east-1" kwargs: dict[str, Any] = {"region_name": region} + if extra_headers: # boto3 doesn't support custom headers natively; log and ignore logger.warning("extra_headers are not supported for Bedrock models and will be ignored.") + + # TLS/SSL configuration via boto3 verify parameter + if tls_disable_verify: + kwargs["verify"] = False + elif tls_ca_cert_path: + kwargs["verify"] = tls_ca_cert_path + + if tls_disable_system_cas and tls_ca_cert_path: + logger.warning( + "disable_system_cas is not fully supported by boto3 for Bedrock; " + "using custom CA bundle only. System CAs may still be trusted." + ) + return boto3.client("bedrock-runtime", **kwargs) @@ -187,7 +208,7 @@ def _stop_reason_to_finish_reason(stop_reason: str) -> types.FinishReason: return types.FinishReason.STOP -class KAgentBedrockLlm(BaseLlm): +class KAgentBedrockLlm(KAgentTLSMixin, BaseLlm): """Bedrock model via the Converse API. Supports all Bedrock-compatible models (Anthropic, Meta, Mistral, Amazon, etc.). @@ -196,11 +217,17 @@ class KAgentBedrockLlm(BaseLlm): extra_headers: Optional[dict[str, str]] = None additional_model_request_fields: Optional[dict[str, Any]] = None + model_config = {"arbitrary_types_allowed": True} @cached_property def _client(self): - return _get_bedrock_client(self.extra_headers) + return _get_bedrock_client( + extra_headers=self.extra_headers, + tls_disable_verify=self.tls_disable_verify, + tls_ca_cert_path=self.tls_ca_cert_path, + tls_disable_system_cas=self.tls_disable_system_cas, + ) @classmethod def supported_models(cls) -> list[str]: diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_gemini.py b/python/packages/kagent-adk/src/kagent/adk/models/_gemini.py new file mode 100644 index 000000000..8b5189e7e --- /dev/null +++ b/python/packages/kagent-adk/src/kagent/adk/models/_gemini.py @@ -0,0 +1,59 @@ +"""Gemini model wrapper with kagent transport configuration.""" + +from __future__ import annotations + +import os +from functools import cached_property +from typing import Optional + +from google.adk.models.google_llm import Gemini as GeminiLLM +from google.adk.utils._google_client_headers import get_tracking_headers +from google.genai import Client, types + +from ._ssl import KAgentTLSMixin + + +def _merge_headers(extra_headers: Optional[dict[str, str]]) -> dict[str, str]: + headers = get_tracking_headers() + if extra_headers: + headers.update(extra_headers) + return headers + + +class KAgentGeminiLlm(KAgentTLSMixin, GeminiLLM): + """Gemini API model that applies kagent TLS and header settings.""" + + extra_headers: Optional[dict[str, str]] = None + api_key_passthrough: Optional[bool] = None + + model_config = {"arbitrary_types_allowed": True} + + def _http_options(self, *, api_version: str | None = None) -> types.HttpOptions: + verify = self._tls_verify() + kwargs = {} + if verify is not None: + kwargs = { + "client_args": {"verify": verify}, + "async_client_args": {"verify": verify, "ssl": verify}, + } + return types.HttpOptions( + headers=_merge_headers(self.extra_headers), + retry_options=self.retry_options, + base_url=self.base_url, + api_version=api_version, + **kwargs, + ) + + @cached_property + def api_client(self) -> Client: + return Client( + api_key=os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"), + http_options=self._http_options(), + ) + + @cached_property + def _live_api_client(self) -> Client: + return Client( + api_key=os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY"), + http_options=self._http_options(api_version=self._live_api_version), + ) diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py b/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py index 9a5554146..cfaf5f701 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py @@ -15,6 +15,8 @@ from ollama import AsyncClient from ollama import Message as OllamaMessage +from ._ssl import KAgentTLSMixin + if TYPE_CHECKING: from google.adk.models.llm_request import LlmRequest @@ -131,7 +133,7 @@ def _convert_tools_to_ollama(tools: list[types.Tool]) -> list[ollama_sdk.Tool]: return ollama_tools -class KAgentOllamaLlm(BaseLlm): +class KAgentOllamaLlm(KAgentTLSMixin, BaseLlm): """Ollama model via the native Ollama SDK. All Ollama options (temperature, top_p, top_k, num_ctx, etc.) are forwarded @@ -152,7 +154,14 @@ class KAgentOllamaLlm(BaseLlm): @cached_property def _client(self) -> AsyncClient: host = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434") - return AsyncClient(host=host, headers=self.default_headers or {}) + kwargs: dict[str, object] = { + "host": host, + "headers": self.default_headers or {}, + } + + kwargs.update(self._tls_httpx_kwargs()) + + return AsyncClient(**kwargs) @classmethod def supported_models(cls) -> list[str]: @@ -261,6 +270,9 @@ def create_ollama_llm( model: str, options: dict[str, object] | None, extra_headers: dict[str, str], + tls_disable_verify: Optional[bool] = None, + tls_ca_cert_path: Optional[str] = None, + tls_disable_system_cas: Optional[bool] = None, ) -> KAgentOllamaLlm: """Build a KAgentOllamaLlm from Ollama options. @@ -272,4 +284,7 @@ def create_ollama_llm( model=model, ollama_options=options or None, default_headers=extra_headers or {}, + tls_disable_verify=tls_disable_verify, + tls_ca_cert_path=tls_ca_cert_path, + tls_disable_system_cas=tls_disable_system_cas, ) diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_openai.py b/python/packages/kagent-adk/src/kagent/adk/models/_openai.py index bc0e370ae..ae6fd6df4 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_openai.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_openai.py @@ -32,7 +32,7 @@ from openai.types.shared_params import FunctionDefinition, FunctionParameters from pydantic import Field -from ._ssl import create_ssl_context +from ._ssl import KAgentTLSMixin from ._token_source import GDCHTokenSource if TYPE_CHECKING: @@ -365,7 +365,7 @@ def _convert_openai_response_to_llm_response(response: ChatCompletion) -> LlmRes return LlmResponse(content=content, usage_metadata=usage_metadata, finish_reason=finish_reason) -class BaseOpenAI(BaseLlm): +class BaseOpenAI(KAgentTLSMixin, BaseLlm): """Base class for OpenAI-compatible models.""" model: str @@ -382,11 +382,6 @@ class BaseOpenAI(BaseLlm): timeout: Optional[int] = None top_p: Optional[float] = None - # TLS/SSL configuration fields - tls_disable_verify: Optional[bool] = None - tls_ca_cert_path: Optional[str] = None - tls_disable_system_cas: Optional[bool] = None - # API key passthrough: forward the Bearer token from incoming requests as the LLM API key api_key_passthrough: Optional[bool] = None @@ -403,20 +398,6 @@ def supported_models(cls) -> list[str]: """Returns a list of supported models in regex for LlmRegistry.""" return [r"gpt-.*", r"o1-.*"] - def _get_tls_config(self) -> tuple[bool, Optional[str], bool]: - """Read TLS configuration from instance fields. - - Returns: - Tuple of (disable_verify, ca_cert_path, disable_system_cas) - """ - # Read from instance fields only (config-based approach) - # Environment variables are no longer supported for TLS configuration - disable_verify = self.tls_disable_verify or False - ca_cert_path = self.tls_ca_cert_path - disable_system_cas = self.tls_disable_system_cas or False - - return disable_verify, ca_cert_path, disable_system_cas - def _create_http_client(self) -> Optional[httpx.AsyncClient]: """Create HTTP client with custom SSL context using OpenAI SDK defaults. @@ -427,22 +408,7 @@ def _create_http_client(self) -> Optional[httpx.AsyncClient]: Returns: DefaultAsyncHttpxClient with SSL configuration, or None if no TLS config """ - disable_verify, ca_cert_path, disable_system_cas = self._get_tls_config() - - # Only create custom http client if TLS configuration is present - if disable_verify or ca_cert_path or disable_system_cas: - ssl_context = create_ssl_context( - disable_verify=disable_verify, - ca_cert_path=ca_cert_path, - disable_system_cas=disable_system_cas, - ) - - # ssl_context is either False (verification disabled) or SSLContext - # Use DefaultAsyncHttpxClient to preserve OpenAI's defaults - return DefaultAsyncHttpxClient(verify=ssl_context) - - # No TLS configuration, return None to use OpenAI SDK default - return None + return self._httpx_async_client_if_tls(DefaultAsyncHttpxClient) @cached_property def _client(self) -> AsyncOpenAI: diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py b/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py index b3da12401..96960b16f 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py @@ -2,11 +2,9 @@ from __future__ import annotations -import asyncio import json import logging import os -import ssl import time from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional @@ -16,6 +14,7 @@ from google.genai import types from ._openai import _convert_tools_to_openai +from ._ssl import KAgentTLSMixin if TYPE_CHECKING: from google.adk.models.llm_request import LlmRequest @@ -23,30 +22,32 @@ logger = logging.getLogger(__name__) -async def _fetch_oauth_token(auth_url: str, client_id: str, client_secret: str) -> tuple[str, float]: +async def _fetch_oauth_token( + auth_url: str, + client_id: str, + client_secret: str, + client: httpx.AsyncClient, +) -> tuple[str, float]: """Fetch a new OAuth2 token from the auth server. No caching — callers manage expiry.""" token_url = auth_url.rstrip("/") if not token_url.endswith("/oauth/token"): token_url += "/oauth/token" - def _sync_fetch() -> tuple[str, float]: - resp = httpx.post( - token_url, - data={ - "grant_type": "client_credentials", - "client_id": client_id, - "client_secret": client_secret, - }, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - timeout=30, - ) - resp.raise_for_status() - data = resp.json() - token = data["access_token"] - expires_at = time.time() + data.get("expires_in", 43200) - return token, expires_at - - return await asyncio.to_thread(_sync_fetch) + resp = await client.post( + token_url, + data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30, + ) + resp.raise_for_status() + data = resp.json() + token = data["access_token"] + expires_at = time.time() + data.get("expires_in", 43200) + return token, expires_at def _build_orchestration_template( @@ -145,7 +146,7 @@ def _parse_orchestration_chunk(event_data: dict[str, Any]) -> Optional[dict[str, _RETRYABLE_STATUS_CODES = {401, 403, 404, 502, 503, 504} -class KAgentSAPAICoreLlm(BaseLlm): +class KAgentSAPAICoreLlm(KAgentTLSMixin, BaseLlm): """SAP AI Core LLM via Orchestration Service. Supports all model families (OpenAI, Anthropic, Gemini, etc.) through @@ -157,10 +158,6 @@ class KAgentSAPAICoreLlm(BaseLlm): auth_url: Optional[str] = None api_key_passthrough: Optional[bool] = None - tls_disable_verify: bool = False - tls_ca_cert_path: Optional[str] = None - tls_disable_system_cas: bool = False - _passthrough_key: Optional[str] = None _http_client: Optional[httpx.AsyncClient] = None _token: Optional[str] = None @@ -178,28 +175,11 @@ def set_passthrough_key(self, token: str) -> None: self._passthrough_key = token self._http_client = None - def _create_ssl_context(self) -> Optional[ssl.SSLContext]: - if not self.tls_disable_verify and not self.tls_ca_cert_path and not self.tls_disable_system_cas: - return None - ctx = ssl.create_default_context() - if self.tls_disable_verify: - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE - elif self.tls_disable_system_cas: - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - if self.tls_ca_cert_path: - ctx.load_verify_locations(self.tls_ca_cert_path) - elif self.tls_ca_cert_path: - ctx.load_verify_locations(self.tls_ca_cert_path) - return ctx - def _get_http_client(self) -> httpx.AsyncClient: if self._http_client is not None: return self._http_client - ssl_ctx = self._create_ssl_context() kwargs: dict[str, Any] = {"timeout": 300} - if ssl_ctx is not None: - kwargs["verify"] = ssl_ctx + kwargs.update(self._tls_httpx_kwargs()) self._http_client = httpx.AsyncClient(**kwargs) return self._http_client @@ -223,7 +203,12 @@ async def _ensure_token(self) -> str: client_secret = os.environ.get("SAP_AI_CORE_CLIENT_SECRET", "") if self.auth_url and client_id and client_secret: - token, expires_at = await _fetch_oauth_token(self.auth_url, client_id, client_secret) + token, expires_at = await _fetch_oauth_token( + self.auth_url, + client_id, + client_secret, + self._get_http_client(), + ) self._token = token self._token_expires_at = expires_at return token diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_ssl.py b/python/packages/kagent-adk/src/kagent/adk/models/_ssl.py index 7e0a10428..ffd0b831e 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_ssl.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_ssl.py @@ -4,6 +4,9 @@ import ssl from datetime import datetime, timezone from pathlib import Path +from typing import Optional + +import httpx logger = logging.getLogger(__name__) @@ -243,3 +246,42 @@ def create_ssl_context( ) from e return ctx + + +class KAgentTLSMixin: + """Mixin for model wrappers that accept kagent TLS configuration.""" + + tls_disable_verify: Optional[bool] = None + tls_ca_cert_path: Optional[str] = None + tls_disable_system_cas: Optional[bool] = None + + def _has_tls_config(self) -> bool: + """Return True if the model has any TLS config.""" + return bool(self.tls_disable_verify or self.tls_ca_cert_path or self.tls_disable_system_cas) + + def _tls_verify(self) -> ssl.SSLContext | bool: + """Return the SSL context for the model.""" + if not self._has_tls_config(): + return None + return create_ssl_context( + disable_verify=self.tls_disable_verify or False, + ca_cert_path=self.tls_ca_cert_path, + disable_system_cas=self.tls_disable_system_cas or False, + ) + + def _tls_httpx_kwargs(self) -> dict[str, object]: + """Return the HTTPX kwargs for the model.""" + verify = self._tls_verify() + if verify is None: + return {} + return {"verify": verify} + + def _httpx_async_client_if_tls(self, client_cls=httpx.AsyncClient, **kwargs) -> httpx.AsyncClient | None: + """ + Return the HTTPX client for the model. If no TLS config is present, return None. + If client_cls is provided, use it to create the client. Otherwise, use httpx.AsyncClient. + """ + tls_kwargs = self._tls_httpx_kwargs() + if not tls_kwargs: + return None + return client_cls(**tls_kwargs, **kwargs) diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index 635fca9c9..96c9ee6f0 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -18,6 +18,7 @@ from kagent.adk._remote_a2a_tool import KAgentRemoteA2AToolset from kagent.adk.models._anthropic import KAgentAnthropicLlm from kagent.adk.models._bedrock import KAgentBedrockLlm +from kagent.adk.models._gemini import KAgentGeminiLlm from kagent.adk.models._ollama import create_ollama_llm from kagent.adk.models._openai import AzureOpenAI as OpenAIAzure from kagent.adk.models._openai import OpenAI as OpenAINative @@ -492,6 +493,26 @@ async def auto_save_session_to_memory_callback(callback_context: CallbackContext logger.error("Failed to inject memory configuration: %s", e) +def _transport_kwargs(model_config: BaseLLM) -> dict[str, Any]: + """Extract TLS/transport kwargs shared by most model types. + + Returns a dict with api_key_passthrough and TLS fields so callers + can spread them with ``**_transport_kwargs(model_config)`` instead of + repeating the same four lines in every branch of + ``_create_llm_from_model_config``. + """ + kwargs: dict[str, Any] = {} + if model_config.api_key_passthrough is not None: + kwargs["api_key_passthrough"] = model_config.api_key_passthrough + if model_config.tls_disable_verify is not None: + kwargs["tls_disable_verify"] = model_config.tls_disable_verify + if model_config.tls_ca_cert_path is not None: + kwargs["tls_ca_cert_path"] = model_config.tls_ca_cert_path + if model_config.tls_disable_system_cas is not None: + kwargs["tls_disable_system_cas"] = model_config.tls_disable_system_cas + return kwargs + + def _create_llm_from_model_config(model_config: ModelUnion): extra_headers = model_config.headers or {} base_url = getattr(model_config, "base_url", None) @@ -532,18 +553,15 @@ def _create_llm_from_model_config(model_config: ModelUnion): temperature=model_config.temperature, timeout=model_config.timeout, top_p=model_config.top_p, - tls_disable_verify=model_config.tls_disable_verify, - tls_ca_cert_path=model_config.tls_ca_cert_path, - tls_disable_system_cas=model_config.tls_disable_system_cas, - api_key_passthrough=model_config.api_key_passthrough, token_exchange=token_exchange, + **_transport_kwargs(model_config), ) if model_config.type == "anthropic": return KAgentAnthropicLlm( model=model_config.model, base_url=base_url, extra_headers=extra_headers, - api_key_passthrough=model_config.api_key_passthrough, + **_transport_kwargs(model_config), ) if model_config.type == "gemini_vertex_ai": return GeminiLLM(model=model_config.model) @@ -556,24 +574,27 @@ def _create_llm_from_model_config(model_config: ModelUnion): model=model_config.model, options=ollama_options, extra_headers=extra_headers, + **_transport_kwargs(model_config), ) if model_config.type == "azure_openai": return OpenAIAzure( model=model_config.model, type="azure_openai", default_headers=extra_headers, - tls_disable_verify=model_config.tls_disable_verify, - tls_ca_cert_path=model_config.tls_ca_cert_path, - tls_disable_system_cas=model_config.tls_disable_system_cas, - api_key_passthrough=model_config.api_key_passthrough, + **_transport_kwargs(model_config), ) if model_config.type == "gemini": - return model_config.model + return KAgentGeminiLlm( + model=model_config.model, + extra_headers=extra_headers, + **_transport_kwargs(model_config), + ) if model_config.type == "bedrock": return KAgentBedrockLlm( model=model_config.model, extra_headers=extra_headers, additional_model_request_fields=model_config.additional_model_request_fields, + **_transport_kwargs(model_config), ) if model_config.type == "sap_ai_core": from .models._sap_ai_core import KAgentSAPAICoreLlm @@ -583,10 +604,7 @@ def _create_llm_from_model_config(model_config: ModelUnion): base_url=base_url, resource_group=model_config.resource_group, auth_url=model_config.auth_url, - api_key_passthrough=model_config.api_key_passthrough, - tls_disable_verify=model_config.tls_disable_verify or False, - tls_ca_cert_path=model_config.tls_ca_cert_path, - tls_disable_system_cas=model_config.tls_disable_system_cas or False, + **_transport_kwargs(model_config), ) raise ValueError(f"Invalid model type: {model_config.type}") diff --git a/python/packages/kagent-adk/tests/unittests/models/test_openai.py b/python/packages/kagent-adk/tests/unittests/models/test_openai.py index bcc24d40c..392c2d8a4 100644 --- a/python/packages/kagent-adk/tests/unittests/models/test_openai.py +++ b/python/packages/kagent-adk/tests/unittests/models/test_openai.py @@ -544,7 +544,7 @@ def test_openai_client_without_tls_config(): def test_openai_client_with_tls_verification_disabled(): """Test OpenAI client with TLS verification disabled.""" - with mock.patch("kagent.adk.models._openai.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._ssl.create_ssl_context") as mock_create_ssl: with mock.patch("kagent.adk.models._openai.DefaultAsyncHttpxClient") as mock_httpx: with mock.patch("kagent.adk.models._openai.AsyncOpenAI") as mock_openai: # create_ssl_context returns False when verification is disabled @@ -584,7 +584,7 @@ def test_openai_client_with_custom_ca_certificate(): """Test OpenAI client with custom CA certificate.""" import ssl - with mock.patch("kagent.adk.models._openai.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._ssl.create_ssl_context") as mock_create_ssl: with mock.patch("kagent.adk.models._openai.DefaultAsyncHttpxClient") as mock_httpx: with mock.patch("kagent.adk.models._openai.AsyncOpenAI"): # create_ssl_context returns SSLContext for custom CA @@ -621,7 +621,7 @@ def test_openai_client_with_custom_ca_only(): """Test OpenAI client with custom CA only (no system CAs).""" import ssl - with mock.patch("kagent.adk.models._openai.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._ssl.create_ssl_context") as mock_create_ssl: with mock.patch("kagent.adk.models._openai.DefaultAsyncHttpxClient") as mock_httpx: with mock.patch("kagent.adk.models._openai.AsyncOpenAI"): mock_ssl_context = mock.MagicMock(spec=ssl.SSLContext) @@ -677,7 +677,7 @@ def test_azure_openai_client_with_tls(): from kagent.adk.models import AzureOpenAI - with mock.patch("kagent.adk.models._openai.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._ssl.create_ssl_context") as mock_create_ssl: with mock.patch("kagent.adk.models._openai.DefaultAsyncHttpxClient") as mock_httpx: with mock.patch("kagent.adk.models._openai.AsyncAzureOpenAI") as mock_azure_openai: mock_ssl_context = mock.MagicMock(spec=ssl.SSLContext) @@ -719,7 +719,7 @@ def test_openai_client_with_base_url_and_tls(): """Test OpenAI client with base_url (LiteLLM gateway) and TLS configuration.""" import ssl - with mock.patch("kagent.adk.models._openai.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._ssl.create_ssl_context") as mock_create_ssl: with mock.patch("kagent.adk.models._openai.DefaultAsyncHttpxClient") as mock_httpx: with mock.patch("kagent.adk.models._openai.AsyncOpenAI"): mock_ssl_context = mock.MagicMock(spec=ssl.SSLContext) diff --git a/python/packages/kagent-adk/tests/unittests/models/test_tls_integration.py b/python/packages/kagent-adk/tests/unittests/models/test_tls_integration.py index bf944fb5a..9bc5d1a1e 100644 --- a/python/packages/kagent-adk/tests/unittests/models/test_tls_integration.py +++ b/python/packages/kagent-adk/tests/unittests/models/test_tls_integration.py @@ -13,6 +13,7 @@ import pytest +from kagent.adk.models._gemini import KAgentGeminiLlm from kagent.adk.models._openai import OpenAI from kagent.adk.models._ssl import create_ssl_context, get_ssl_troubleshooting_message, validate_certificate @@ -218,7 +219,7 @@ def test_e2e_ssl_error_troubleshooting_message(temp_cert_file): def test_e2e_openai_client_reads_config_based_tls(temp_cert_file): """Test OpenAI client reads TLS config from instance fields (agent config).""" - with mock.patch("kagent.adk.models._openai.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._ssl.create_ssl_context") as mock_create_ssl: with mock.patch("httpx.AsyncClient") as mock_httpx: with mock.patch("kagent.adk.models._openai.AsyncOpenAI"): mock_create_ssl.return_value = mock.MagicMock(spec=ssl.SSLContext) @@ -245,6 +246,25 @@ def test_e2e_openai_client_reads_config_based_tls(temp_cert_file): assert call_kwargs["disable_system_cas"] is False +def test_e2e_gemini_sets_httpx_and_aiohttp_tls_options(temp_cert_file): + """Test Gemini config passes TLS context for both GenAI async transports.""" + with mock.patch("kagent.adk.models._ssl.create_ssl_context") as mock_create_ssl: + ssl_context = mock.MagicMock(spec=ssl.SSLContext) + mock_create_ssl.return_value = ssl_context + + gemini_llm = KAgentGeminiLlm( + model="gemini-2.5-flash", + type="gemini", + tls_ca_cert_path=temp_cert_file, + tls_disable_system_cas=False, + ) + + http_options = gemini_llm._http_options() + + assert http_options.client_args == {"verify": ssl_context} + assert http_options.async_client_args == {"verify": ssl_context, "ssl": ssl_context} + + def test_e2e_certificate_validation_expiry_warnings(caplog): """Test certificate validation logs expiry warnings but doesn't block.""" # This test requires the cryptography library to be installed @@ -316,7 +336,7 @@ def test_e2e_structured_logging_at_startup(temp_cert_file, caplog): def test_e2e_litellm_with_tls(temp_cert_file): """Test complete flow: LiteLLM base URL + TLS configuration.""" - with mock.patch("kagent.adk.models._openai.create_ssl_context") as mock_create_ssl: + with mock.patch("kagent.adk.models._ssl.create_ssl_context") as mock_create_ssl: with mock.patch("kagent.adk.models._openai.DefaultAsyncHttpxClient") as mock_httpx: with mock.patch("kagent.adk.models._openai.AsyncOpenAI") as mock_openai: mock_ssl_context = mock.MagicMock(spec=ssl.SSLContext) From 3efdaf15a4b25dfcc33049bc8b73ff4dfb4f45e3 Mon Sep 17 00:00:00 2001 From: Jet Chiang Date: Tue, 5 May 2026 11:48:19 -0400 Subject: [PATCH 3/3] cleanup and review feedback Signed-off-by: Jet Chiang --- go/adk/pkg/agent/agent.go | 11 ++- go/adk/pkg/models/sapaicore.go | 10 +-- .../src/kagent/adk/models/_ollama.py | 3 + .../src/kagent/adk/models/_sap_ai_core.py | 75 +++++++++++-------- 4 files changed, 54 insertions(+), 45 deletions(-) diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index 74eff615a..7b07e903e 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -322,12 +322,11 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM, case *adk.SAPAICore: cfg := models.SAPAICoreConfig{ - TransportConfig: transportConfigFromBase(m.BaseModel, nil), - Model: m.Model, - BaseUrl: m.BaseUrl, - ResourceGroup: m.ResourceGroup, - AuthUrl: m.AuthUrl, - Headers: extractHeaders(m.Headers), + Model: m.Model, + BaseUrl: m.BaseUrl, + ResourceGroup: m.ResourceGroup, + AuthUrl: m.AuthUrl, + Headers: extractHeaders(m.Headers), } return models.NewSAPAICoreModelWithLogger(cfg, log) diff --git a/go/adk/pkg/models/sapaicore.go b/go/adk/pkg/models/sapaicore.go index 35290a163..380e4f6e7 100644 --- a/go/adk/pkg/models/sapaicore.go +++ b/go/adk/pkg/models/sapaicore.go @@ -15,7 +15,6 @@ import ( ) type SAPAICoreConfig struct { - TransportConfig Model string BaseUrl string ResourceGroup string @@ -42,17 +41,10 @@ func NewSAPAICoreModelWithLogger(config SAPAICoreConfig, logger logr.Logger) (*S if config.ResourceGroup == "" { config.ResourceGroup = "default" } - - // Build HTTP client with TLS, custom headers, and timeout support - httpClient, err := BuildHTTPClient(config.TransportConfig) - if err != nil { - return nil, fmt.Errorf("failed to create SAP AI Core HTTP client: %w", err) - } - return &SAPAICoreModel{ Config: config, Logger: logger, - httpClient: httpClient, + httpClient: &http.Client{Timeout: 5 * time.Minute}, }, nil } diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py b/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py index cfaf5f701..03c268af4 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_ollama.py @@ -150,6 +150,7 @@ class KAgentOllamaLlm(KAgentTLSMixin, BaseLlm): type: Literal["ollama"] = "ollama" ollama_options: Optional[dict[str, object]] = None default_headers: Optional[dict[str, str]] = None + api_key_passthrough: Optional[bool] = None @cached_property def _client(self) -> AsyncClient: @@ -273,6 +274,7 @@ def create_ollama_llm( tls_disable_verify: Optional[bool] = None, tls_ca_cert_path: Optional[str] = None, tls_disable_system_cas: Optional[bool] = None, + api_key_passthrough: Optional[bool] = None, ) -> KAgentOllamaLlm: """Build a KAgentOllamaLlm from Ollama options. @@ -287,4 +289,5 @@ def create_ollama_llm( tls_disable_verify=tls_disable_verify, tls_ca_cert_path=tls_ca_cert_path, tls_disable_system_cas=tls_disable_system_cas, + api_key_passthrough=api_key_passthrough, ) diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py b/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py index 96960b16f..b3da12401 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py @@ -2,9 +2,11 @@ from __future__ import annotations +import asyncio import json import logging import os +import ssl import time from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional @@ -14,7 +16,6 @@ from google.genai import types from ._openai import _convert_tools_to_openai -from ._ssl import KAgentTLSMixin if TYPE_CHECKING: from google.adk.models.llm_request import LlmRequest @@ -22,32 +23,30 @@ logger = logging.getLogger(__name__) -async def _fetch_oauth_token( - auth_url: str, - client_id: str, - client_secret: str, - client: httpx.AsyncClient, -) -> tuple[str, float]: +async def _fetch_oauth_token(auth_url: str, client_id: str, client_secret: str) -> tuple[str, float]: """Fetch a new OAuth2 token from the auth server. No caching — callers manage expiry.""" token_url = auth_url.rstrip("/") if not token_url.endswith("/oauth/token"): token_url += "/oauth/token" - resp = await client.post( - token_url, - data={ - "grant_type": "client_credentials", - "client_id": client_id, - "client_secret": client_secret, - }, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - timeout=30, - ) - resp.raise_for_status() - data = resp.json() - token = data["access_token"] - expires_at = time.time() + data.get("expires_in", 43200) - return token, expires_at + def _sync_fetch() -> tuple[str, float]: + resp = httpx.post( + token_url, + data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30, + ) + resp.raise_for_status() + data = resp.json() + token = data["access_token"] + expires_at = time.time() + data.get("expires_in", 43200) + return token, expires_at + + return await asyncio.to_thread(_sync_fetch) def _build_orchestration_template( @@ -146,7 +145,7 @@ def _parse_orchestration_chunk(event_data: dict[str, Any]) -> Optional[dict[str, _RETRYABLE_STATUS_CODES = {401, 403, 404, 502, 503, 504} -class KAgentSAPAICoreLlm(KAgentTLSMixin, BaseLlm): +class KAgentSAPAICoreLlm(BaseLlm): """SAP AI Core LLM via Orchestration Service. Supports all model families (OpenAI, Anthropic, Gemini, etc.) through @@ -158,6 +157,10 @@ class KAgentSAPAICoreLlm(KAgentTLSMixin, BaseLlm): auth_url: Optional[str] = None api_key_passthrough: Optional[bool] = None + tls_disable_verify: bool = False + tls_ca_cert_path: Optional[str] = None + tls_disable_system_cas: bool = False + _passthrough_key: Optional[str] = None _http_client: Optional[httpx.AsyncClient] = None _token: Optional[str] = None @@ -175,11 +178,28 @@ def set_passthrough_key(self, token: str) -> None: self._passthrough_key = token self._http_client = None + def _create_ssl_context(self) -> Optional[ssl.SSLContext]: + if not self.tls_disable_verify and not self.tls_ca_cert_path and not self.tls_disable_system_cas: + return None + ctx = ssl.create_default_context() + if self.tls_disable_verify: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + elif self.tls_disable_system_cas: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if self.tls_ca_cert_path: + ctx.load_verify_locations(self.tls_ca_cert_path) + elif self.tls_ca_cert_path: + ctx.load_verify_locations(self.tls_ca_cert_path) + return ctx + def _get_http_client(self) -> httpx.AsyncClient: if self._http_client is not None: return self._http_client + ssl_ctx = self._create_ssl_context() kwargs: dict[str, Any] = {"timeout": 300} - kwargs.update(self._tls_httpx_kwargs()) + if ssl_ctx is not None: + kwargs["verify"] = ssl_ctx self._http_client = httpx.AsyncClient(**kwargs) return self._http_client @@ -203,12 +223,7 @@ async def _ensure_token(self) -> str: client_secret = os.environ.get("SAP_AI_CORE_CLIENT_SECRET", "") if self.auth_url and client_id and client_secret: - token, expires_at = await _fetch_oauth_token( - self.auth_url, - client_id, - client_secret, - self._get_http_client(), - ) + token, expires_at = await _fetch_oauth_token(self.auth_url, client_id, client_secret) self._token = token self._token_expires_at = expires_at return token