Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
_BEDROCK_TOOL_ID_VALID = re.compile(r"^[a-zA-Z0-9_.:-]+$")
_BEDROCK_TOOL_ID_INVALID = re.compile(r"[^a-zA-Z0-9_.:-]")

_BEDROCK_TOOL_NAME_VALID = re.compile(r"^[a-zA-Z0-9_-]+$")
_BEDROCK_TOOL_NAME_INVALID = re.compile(r"[^a-zA-Z0-9_-]")


def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int]) -> str:
"""Return a valid Bedrock toolUseId.
Expand All @@ -54,6 +57,29 @@ def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int])
return sanitized


def _sanitize_tool_name(name: str, name_map: dict[str, str], counter: list[int]) -> str:
"""Return a valid Bedrock tool name.

Bedrock requires tool names to match [a-zA-Z0-9_-]+ and be non-empty.
name_map caches original->sanitized so the same tool name is consistently
mapped throughout a single request. counter is a single-element list used
as a mutable integer for generating unique fallback names.

See https://github.com/kagent-dev/kagent/issues/1473
"""
if name in name_map:
return name_map[name]
sanitized = _BEDROCK_TOOL_NAME_INVALID.sub("_", name)
if not sanitized or not _BEDROCK_TOOL_NAME_VALID.match(sanitized):
counter[0] += 1
sanitized = f"unknown_tool_{counter[0]}"
return sanitized
Comment on lines +60 to +76
if sanitized != name:
logger.warning("Sanitized Bedrock tool name %r -> %r", name, sanitized)
name_map[name] = sanitized
return sanitized


def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None):
region = os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION") or "us-east-1"
kwargs: dict[str, Any] = {"region_name": region}
Expand All @@ -63,7 +89,11 @@ def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None):
return boto3.client("bedrock-runtime", **kwargs)


def _convert_content_to_converse_messages(contents: list[types.Content]) -> list[dict]:
def _convert_content_to_converse_messages(
contents: list[types.Content],
name_map: dict[str, str],
name_counter: list[int],
) -> list[dict]:
id_map: dict[str, str] = {}
counter = [0]

Expand All @@ -80,7 +110,7 @@ def _convert_content_to_converse_messages(contents: list[types.Content]) -> list
{
"toolUse": {
"toolUseId": _sanitize_tool_id(part.function_call.id or "", id_map, counter),
"name": part.function_call.name or "",
"name": _sanitize_tool_name(part.function_call.name or "", name_map, name_counter),
"input": part.function_call.args or {},
}
}
Expand Down Expand Up @@ -149,7 +179,11 @@ def _normalize_schema(schema: dict) -> dict:
return result


def _convert_tools_to_converse(tools: list[types.Tool]) -> list[dict]:
def _convert_tools_to_converse(
tools: list[types.Tool],
name_map: dict[str, str],
name_counter: list[int],
) -> list[dict]:
converse_tools = []
for tool in tools:
for func_decl in tool.function_declarations or []:
Expand All @@ -164,7 +198,7 @@ def _convert_tools_to_converse(tools: list[types.Tool]) -> list[dict]:
converse_tools.append(
{
"toolSpec": {
"name": func_decl.name or "",
"name": _sanitize_tool_name(func_decl.name or "", name_map, name_counter),
"description": func_decl.description or "",
"inputSchema": {
"json": {
Expand Down Expand Up @@ -212,7 +246,9 @@ async def generate_content_async(
client = self._client
model_id = llm_request.model or self.model

messages = _convert_content_to_converse_messages(llm_request.contents or [])
name_map: dict[str, str] = {}
name_counter: list[int] = [0]
messages = _convert_content_to_converse_messages(llm_request.contents or [], name_map, name_counter)

kwargs: dict[str, Any] = {"modelId": model_id, "messages": messages}

Expand All @@ -228,7 +264,7 @@ async def generate_content_async(
if llm_request.config and llm_request.config.tools:
genai_tools = [t for t in llm_request.config.tools if hasattr(t, "function_declarations")]
if genai_tools:
converse_tools = _convert_tools_to_converse(genai_tools)
converse_tools = _convert_tools_to_converse(genai_tools, name_map, name_counter)
if converse_tools:
kwargs["toolConfig"] = {"tools": converse_tools}

Expand All @@ -248,6 +284,8 @@ async def generate_content_async(
if self.additional_model_request_fields:
kwargs["additionalModelRequestFields"] = self.additional_model_request_fields

reverse_name_map = {v: k for k, v in name_map.items()}

def _run_converse_stream(**kw):
resp = client.converse_stream(**kw)
return list(resp.get("stream", []))
Expand All @@ -268,7 +306,7 @@ def _run_converse_stream(**kw):
if "toolUse" in start:
current_tool_id = start["toolUse"]["toolUseId"]
tool_uses[current_tool_id] = {
"name": start["toolUse"]["name"],
"name": reverse_name_map.get(start["toolUse"]["name"], start["toolUse"]["name"]),
"input_json": "",
}

Expand Down Expand Up @@ -325,7 +363,8 @@ def _run_converse_stream(**kw):
parts.append(types.Part.from_text(text=block["text"]))
elif "toolUse" in block:
tool = block["toolUse"]
part = types.Part.from_function_call(name=tool["name"], args=tool.get("input", {}))
original_name = reverse_name_map.get(tool["name"], tool["name"])
part = types.Part.from_function_call(name=original_name, args=tool.get("input", {}))
if part.function_call:
part.function_call.id = tool["toolUseId"]
parts.append(part)
Expand Down
205 changes: 204 additions & 1 deletion python/packages/kagent-adk/tests/unittests/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

import pytest

from kagent.adk.models._bedrock import KAgentBedrockLlm, _get_bedrock_client
from kagent.adk.models._bedrock import (
KAgentBedrockLlm,
_convert_tools_to_converse,
_get_bedrock_client,
_sanitize_tool_name,
)


class TestGetBedrockClient:
Expand Down Expand Up @@ -109,3 +114,201 @@ def test_create_llm_from_bedrock_model_config(self):
result = _create_llm_from_model_config(config)
assert isinstance(result, KAgentBedrockLlm)
assert result.model == "meta.llama3-8b-instruct-v1:0"


class TestSanitizeToolName:
def test_valid_name_unchanged(self):
name_map: dict = {}
counter = [0]
assert _sanitize_tool_name("get_pods", name_map, counter) == "get_pods"
assert name_map == {"get_pods": "get_pods"}

def test_dot_replaced_with_underscore(self):
name_map: dict = {}
counter = [0]
assert _sanitize_tool_name("kubernetes.get_pods", name_map, counter) == "kubernetes_get_pods"
assert name_map["kubernetes.get_pods"] == "kubernetes_get_pods"

def test_space_replaced_with_underscore(self):
name_map: dict = {}
counter = [0]
assert _sanitize_tool_name("get pods", name_map, counter) == "get_pods"

def test_colon_replaced_with_underscore(self):
name_map: dict = {}
counter = [0]
assert _sanitize_tool_name("ns:tool", name_map, counter) == "ns_tool"

def test_empty_name_gets_fallback(self):
name_map: dict = {}
counter = [0]
result = _sanitize_tool_name("", name_map, counter)
assert result == "unknown_tool_1"
assert counter[0] == 1

Comment on lines +145 to +148
def test_fully_invalid_name_becomes_underscores(self):
name_map: dict = {}
counter = [0]
result = _sanitize_tool_name("!@#$", name_map, counter)
# Characters are replaced with _ so the result is still valid per pattern
assert result == "____"
assert counter[0] == 0

def test_same_name_returns_cached_sanitized(self):
name_map: dict = {}
counter = [0]
first = _sanitize_tool_name("mcp.server.tool", name_map, counter)
second = _sanitize_tool_name("mcp.server.tool", name_map, counter)
assert first == second == "mcp_server_tool"

def test_hyphen_preserved(self):
name_map: dict = {}
counter = [0]
assert _sanitize_tool_name("my-tool", name_map, counter) == "my-tool"


class TestConvertToolsToConverse:
def test_tool_name_with_dot_is_sanitized(self):
from google.genai import types as genai_types

func_decl = mock.MagicMock()
func_decl.name = "github_copilot.suggest"
func_decl.description = "Suggest code"
func_decl.parameters = None

tool = mock.MagicMock()
tool.function_declarations = [func_decl]

name_map: dict = {}
counter = [0]
result = _convert_tools_to_converse([tool], name_map, counter)

assert result[0]["toolSpec"]["name"] == "github_copilot_suggest"
assert name_map["github_copilot.suggest"] == "github_copilot_suggest"

def test_valid_tool_name_unchanged(self):
func_decl = mock.MagicMock()
func_decl.name = "list_namespaces"
func_decl.description = "List namespaces"
func_decl.parameters = None

tool = mock.MagicMock()
tool.function_declarations = [func_decl]

name_map: dict = {}
counter = [0]
result = _convert_tools_to_converse([tool], name_map, counter)

assert result[0]["toolSpec"]["name"] == "list_namespaces"


class TestBedrockToolNameRoundTrip:
@pytest.mark.asyncio
async def test_dotted_tool_name_restored_in_non_streaming_response(self):
"""Tool names with dots are sanitized outbound and restored from the Bedrock response."""
llm = KAgentBedrockLlm(model="us.anthropic.claude-sonnet-4-20250514-v1:0")

converse_response = {
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "call_abc",
"name": "github_copilot_suggest",
"input": {"prompt": "hello"},
}
}
],
}
},
"stopReason": "tool_use",
"usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
}
mock_client = mock.MagicMock()
mock_client.converse.return_value = converse_response

async def fake_to_thread(fn, **kwargs):
return fn(**kwargs)

from google.genai import types as genai_types

func_decl = mock.MagicMock()
func_decl.name = "github_copilot.suggest"
func_decl.description = "Suggest"
func_decl.parameters = None

tool = mock.MagicMock(spec=genai_types.Tool)
tool.function_declarations = [func_decl]

request = mock.MagicMock()
request.model = "us.anthropic.claude-sonnet-4-20250514-v1:0"
request.contents = []
request.config = mock.MagicMock()
request.config.system_instruction = None
request.config.tools = [tool]
request.config.temperature = None
request.config.max_output_tokens = None
request.config.top_p = None
request.config.stop_sequences = None

with (
mock.patch("kagent.adk.models._bedrock._get_bedrock_client", return_value=mock_client),
mock.patch("kagent.adk.models._bedrock.asyncio.to_thread", side_effect=fake_to_thread),
):
responses = [r async for r in llm.generate_content_async(request)]

assert len(responses) == 1
fc = responses[0].content.parts[0].function_call
assert fc.name == "github_copilot.suggest"

@pytest.mark.asyncio
async def test_dotted_tool_name_restored_in_streaming_response(self):
"""Tool names with dots are sanitized outbound and restored from streaming Bedrock response."""
llm = KAgentBedrockLlm(model="us.anthropic.claude-sonnet-4-20250514-v1:0")

stream_events = [
{
"contentBlockStart": {
"start": {"toolUse": {"toolUseId": "call_xyz", "name": "github_copilot_suggest"}}
}
},
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"prompt": "hi"}'}}}},
{"messageStop": {"stopReason": "tool_use"}},
{"metadata": {"usage": {"inputTokens": 5, "outputTokens": 3, "totalTokens": 8}}},
]
mock_client = mock.MagicMock()
mock_client.converse_stream.return_value = {"stream": stream_events}

async def fake_to_thread(fn, **kwargs):
return fn(**kwargs)

func_decl = mock.MagicMock()
func_decl.name = "github_copilot.suggest"
func_decl.description = "Suggest"
func_decl.parameters = None

tool = mock.MagicMock()
tool.function_declarations = [func_decl]

request = mock.MagicMock()
request.model = "us.anthropic.claude-sonnet-4-20250514-v1:0"
request.contents = []
request.config = mock.MagicMock()
request.config.system_instruction = None
request.config.tools = [tool]
request.config.temperature = None
request.config.max_output_tokens = None
request.config.top_p = None
request.config.stop_sequences = None

with (
mock.patch("kagent.adk.models._bedrock._get_bedrock_client", return_value=mock_client),
mock.patch("kagent.adk.models._bedrock.asyncio.to_thread", side_effect=fake_to_thread),
):
responses = [r async for r in llm.generate_content_async(request, stream=True)]

final = responses[-1]
fc = final.content.parts[0].function_call
assert fc.name == "github_copilot.suggest"
Loading