From 438a2a834bcc5340e3fd74e939fcdac1e67662ba Mon Sep 17 00:00:00 2001 From: AnantKumar17 Date: Thu, 30 Apr 2026 22:42:21 +0530 Subject: [PATCH] fix(bedrock): sanitize tool names to meet Bedrock API constraints MCP servers can produce tool names containing dots, spaces, or other characters that Bedrock rejects. The Converse API requires tool names to match [a-zA-Z0-9_-]+ but the previous code passed names through unchanged, causing ValidationException errors at runtime. This change adds _sanitize_tool_name, which replaces any invalid character with an underscore and falls back to a generated name when the input is empty. A name_map shared across both _convert_tools_to_converse and _convert_content_to_converse_messages ensures the same original name always maps to the same sanitized name within a single request. A reverse_name_map is then used to restore original names in Bedrock responses before they are returned to the ADK framework, so the rest of the agent runtime sees the names it originally registered. A warning is logged whenever a name is sanitized to aid debugging. Fixes #1473 Signed-off-by: AnantKumar17 --- .../src/kagent/adk/models/_bedrock.py | 55 ++++- .../tests/unittests/models/test_bedrock.py | 205 +++++++++++++++++- 2 files changed, 251 insertions(+), 9 deletions(-) diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py index e1ebdbfb6..54ccb65a5 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py @@ -29,6 +29,9 @@ _BEDROCK_TOOL_ID_VALID = re.compile(r"^[a-zA-Z0-9_.:-]+$") _BEDROCK_TOOL_ID_INVALID = re.compile(r"[^a-zA-Z0-9_.:-]") +_BEDROCK_TOOL_NAME_VALID = re.compile(r"^[a-zA-Z0-9_-]+$") +_BEDROCK_TOOL_NAME_INVALID = re.compile(r"[^a-zA-Z0-9_-]") + def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int]) -> str: """Return a valid Bedrock toolUseId. @@ -54,6 +57,29 @@ def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int]) return sanitized +def _sanitize_tool_name(name: str, name_map: dict[str, str], counter: list[int]) -> str: + """Return a valid Bedrock tool name. + + Bedrock requires tool names to match [a-zA-Z0-9_-]+ and be non-empty. + name_map caches original->sanitized so the same tool name is consistently + mapped throughout a single request. counter is a single-element list used + as a mutable integer for generating unique fallback names. + + See https://github.com/kagent-dev/kagent/issues/1473 + """ + if name in name_map: + return name_map[name] + sanitized = _BEDROCK_TOOL_NAME_INVALID.sub("_", name) + if not sanitized or not _BEDROCK_TOOL_NAME_VALID.match(sanitized): + counter[0] += 1 + sanitized = f"unknown_tool_{counter[0]}" + return sanitized + if sanitized != name: + logger.warning("Sanitized Bedrock tool name %r -> %r", name, sanitized) + name_map[name] = sanitized + return sanitized + + def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None): region = os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION") or "us-east-1" kwargs: dict[str, Any] = {"region_name": region} @@ -63,7 +89,11 @@ def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None): return boto3.client("bedrock-runtime", **kwargs) -def _convert_content_to_converse_messages(contents: list[types.Content]) -> list[dict]: +def _convert_content_to_converse_messages( + contents: list[types.Content], + name_map: dict[str, str], + name_counter: list[int], +) -> list[dict]: id_map: dict[str, str] = {} counter = [0] @@ -80,7 +110,7 @@ def _convert_content_to_converse_messages(contents: list[types.Content]) -> list { "toolUse": { "toolUseId": _sanitize_tool_id(part.function_call.id or "", id_map, counter), - "name": part.function_call.name or "", + "name": _sanitize_tool_name(part.function_call.name or "", name_map, name_counter), "input": part.function_call.args or {}, } } @@ -149,7 +179,11 @@ def _normalize_schema(schema: dict) -> dict: return result -def _convert_tools_to_converse(tools: list[types.Tool]) -> list[dict]: +def _convert_tools_to_converse( + tools: list[types.Tool], + name_map: dict[str, str], + name_counter: list[int], +) -> list[dict]: converse_tools = [] for tool in tools: for func_decl in tool.function_declarations or []: @@ -164,7 +198,7 @@ def _convert_tools_to_converse(tools: list[types.Tool]) -> list[dict]: converse_tools.append( { "toolSpec": { - "name": func_decl.name or "", + "name": _sanitize_tool_name(func_decl.name or "", name_map, name_counter), "description": func_decl.description or "", "inputSchema": { "json": { @@ -212,7 +246,9 @@ async def generate_content_async( client = self._client model_id = llm_request.model or self.model - messages = _convert_content_to_converse_messages(llm_request.contents or []) + name_map: dict[str, str] = {} + name_counter: list[int] = [0] + messages = _convert_content_to_converse_messages(llm_request.contents or [], name_map, name_counter) kwargs: dict[str, Any] = {"modelId": model_id, "messages": messages} @@ -228,7 +264,7 @@ async def generate_content_async( if llm_request.config and llm_request.config.tools: genai_tools = [t for t in llm_request.config.tools if hasattr(t, "function_declarations")] if genai_tools: - converse_tools = _convert_tools_to_converse(genai_tools) + converse_tools = _convert_tools_to_converse(genai_tools, name_map, name_counter) if converse_tools: kwargs["toolConfig"] = {"tools": converse_tools} @@ -248,6 +284,8 @@ async def generate_content_async( if self.additional_model_request_fields: kwargs["additionalModelRequestFields"] = self.additional_model_request_fields + reverse_name_map = {v: k for k, v in name_map.items()} + def _run_converse_stream(**kw): resp = client.converse_stream(**kw) return list(resp.get("stream", [])) @@ -268,7 +306,7 @@ def _run_converse_stream(**kw): if "toolUse" in start: current_tool_id = start["toolUse"]["toolUseId"] tool_uses[current_tool_id] = { - "name": start["toolUse"]["name"], + "name": reverse_name_map.get(start["toolUse"]["name"], start["toolUse"]["name"]), "input_json": "", } @@ -325,7 +363,8 @@ def _run_converse_stream(**kw): parts.append(types.Part.from_text(text=block["text"])) elif "toolUse" in block: tool = block["toolUse"] - part = types.Part.from_function_call(name=tool["name"], args=tool.get("input", {})) + original_name = reverse_name_map.get(tool["name"], tool["name"]) + part = types.Part.from_function_call(name=original_name, args=tool.get("input", {})) if part.function_call: part.function_call.id = tool["toolUseId"] parts.append(part) diff --git a/python/packages/kagent-adk/tests/unittests/models/test_bedrock.py b/python/packages/kagent-adk/tests/unittests/models/test_bedrock.py index fab018520..d8cb3f56a 100644 --- a/python/packages/kagent-adk/tests/unittests/models/test_bedrock.py +++ b/python/packages/kagent-adk/tests/unittests/models/test_bedrock.py @@ -5,7 +5,12 @@ import pytest -from kagent.adk.models._bedrock import KAgentBedrockLlm, _get_bedrock_client +from kagent.adk.models._bedrock import ( + KAgentBedrockLlm, + _convert_tools_to_converse, + _get_bedrock_client, + _sanitize_tool_name, +) class TestGetBedrockClient: @@ -109,3 +114,201 @@ def test_create_llm_from_bedrock_model_config(self): result = _create_llm_from_model_config(config) assert isinstance(result, KAgentBedrockLlm) assert result.model == "meta.llama3-8b-instruct-v1:0" + + +class TestSanitizeToolName: + def test_valid_name_unchanged(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("get_pods", name_map, counter) == "get_pods" + assert name_map == {"get_pods": "get_pods"} + + def test_dot_replaced_with_underscore(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("kubernetes.get_pods", name_map, counter) == "kubernetes_get_pods" + assert name_map["kubernetes.get_pods"] == "kubernetes_get_pods" + + def test_space_replaced_with_underscore(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("get pods", name_map, counter) == "get_pods" + + def test_colon_replaced_with_underscore(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("ns:tool", name_map, counter) == "ns_tool" + + def test_empty_name_gets_fallback(self): + name_map: dict = {} + counter = [0] + result = _sanitize_tool_name("", name_map, counter) + assert result == "unknown_tool_1" + assert counter[0] == 1 + + def test_fully_invalid_name_becomes_underscores(self): + name_map: dict = {} + counter = [0] + result = _sanitize_tool_name("!@#$", name_map, counter) + # Characters are replaced with _ so the result is still valid per pattern + assert result == "____" + assert counter[0] == 0 + + def test_same_name_returns_cached_sanitized(self): + name_map: dict = {} + counter = [0] + first = _sanitize_tool_name("mcp.server.tool", name_map, counter) + second = _sanitize_tool_name("mcp.server.tool", name_map, counter) + assert first == second == "mcp_server_tool" + + def test_hyphen_preserved(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("my-tool", name_map, counter) == "my-tool" + + +class TestConvertToolsToConverse: + def test_tool_name_with_dot_is_sanitized(self): + from google.genai import types as genai_types + + func_decl = mock.MagicMock() + func_decl.name = "github_copilot.suggest" + func_decl.description = "Suggest code" + func_decl.parameters = None + + tool = mock.MagicMock() + tool.function_declarations = [func_decl] + + name_map: dict = {} + counter = [0] + result = _convert_tools_to_converse([tool], name_map, counter) + + assert result[0]["toolSpec"]["name"] == "github_copilot_suggest" + assert name_map["github_copilot.suggest"] == "github_copilot_suggest" + + def test_valid_tool_name_unchanged(self): + func_decl = mock.MagicMock() + func_decl.name = "list_namespaces" + func_decl.description = "List namespaces" + func_decl.parameters = None + + tool = mock.MagicMock() + tool.function_declarations = [func_decl] + + name_map: dict = {} + counter = [0] + result = _convert_tools_to_converse([tool], name_map, counter) + + assert result[0]["toolSpec"]["name"] == "list_namespaces" + + +class TestBedrockToolNameRoundTrip: + @pytest.mark.asyncio + async def test_dotted_tool_name_restored_in_non_streaming_response(self): + """Tool names with dots are sanitized outbound and restored from the Bedrock response.""" + llm = KAgentBedrockLlm(model="us.anthropic.claude-sonnet-4-20250514-v1:0") + + converse_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "call_abc", + "name": "github_copilot_suggest", + "input": {"prompt": "hello"}, + } + } + ], + } + }, + "stopReason": "tool_use", + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + } + mock_client = mock.MagicMock() + mock_client.converse.return_value = converse_response + + async def fake_to_thread(fn, **kwargs): + return fn(**kwargs) + + from google.genai import types as genai_types + + func_decl = mock.MagicMock() + func_decl.name = "github_copilot.suggest" + func_decl.description = "Suggest" + func_decl.parameters = None + + tool = mock.MagicMock(spec=genai_types.Tool) + tool.function_declarations = [func_decl] + + request = mock.MagicMock() + request.model = "us.anthropic.claude-sonnet-4-20250514-v1:0" + request.contents = [] + request.config = mock.MagicMock() + request.config.system_instruction = None + request.config.tools = [tool] + request.config.temperature = None + request.config.max_output_tokens = None + request.config.top_p = None + request.config.stop_sequences = None + + with ( + mock.patch("kagent.adk.models._bedrock._get_bedrock_client", return_value=mock_client), + mock.patch("kagent.adk.models._bedrock.asyncio.to_thread", side_effect=fake_to_thread), + ): + responses = [r async for r in llm.generate_content_async(request)] + + assert len(responses) == 1 + fc = responses[0].content.parts[0].function_call + assert fc.name == "github_copilot.suggest" + + @pytest.mark.asyncio + async def test_dotted_tool_name_restored_in_streaming_response(self): + """Tool names with dots are sanitized outbound and restored from streaming Bedrock response.""" + llm = KAgentBedrockLlm(model="us.anthropic.claude-sonnet-4-20250514-v1:0") + + stream_events = [ + { + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "call_xyz", "name": "github_copilot_suggest"}} + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"prompt": "hi"}'}}}}, + {"messageStop": {"stopReason": "tool_use"}}, + {"metadata": {"usage": {"inputTokens": 5, "outputTokens": 3, "totalTokens": 8}}}, + ] + mock_client = mock.MagicMock() + mock_client.converse_stream.return_value = {"stream": stream_events} + + async def fake_to_thread(fn, **kwargs): + return fn(**kwargs) + + func_decl = mock.MagicMock() + func_decl.name = "github_copilot.suggest" + func_decl.description = "Suggest" + func_decl.parameters = None + + tool = mock.MagicMock() + tool.function_declarations = [func_decl] + + request = mock.MagicMock() + request.model = "us.anthropic.claude-sonnet-4-20250514-v1:0" + request.contents = [] + request.config = mock.MagicMock() + request.config.system_instruction = None + request.config.tools = [tool] + request.config.temperature = None + request.config.max_output_tokens = None + request.config.top_p = None + request.config.stop_sequences = None + + with ( + mock.patch("kagent.adk.models._bedrock._get_bedrock_client", return_value=mock_client), + mock.patch("kagent.adk.models._bedrock.asyncio.to_thread", side_effect=fake_to_thread), + ): + responses = [r async for r in llm.generate_content_async(request, stream=True)] + + final = responses[-1] + fc = final.content.parts[0].function_call + assert fc.name == "github_copilot.suggest"