diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py index e1ebdbfb6..54ccb65a5 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py @@ -29,6 +29,9 @@ _BEDROCK_TOOL_ID_VALID = re.compile(r"^[a-zA-Z0-9_.:-]+$") _BEDROCK_TOOL_ID_INVALID = re.compile(r"[^a-zA-Z0-9_.:-]") +_BEDROCK_TOOL_NAME_VALID = re.compile(r"^[a-zA-Z0-9_-]+$") +_BEDROCK_TOOL_NAME_INVALID = re.compile(r"[^a-zA-Z0-9_-]") + def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int]) -> str: """Return a valid Bedrock toolUseId. @@ -54,6 +57,29 @@ def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int]) return sanitized +def _sanitize_tool_name(name: str, name_map: dict[str, str], counter: list[int]) -> str: + """Return a valid Bedrock tool name. + + Bedrock requires tool names to match [a-zA-Z0-9_-]+ and be non-empty. + name_map caches original->sanitized so the same tool name is consistently + mapped throughout a single request. counter is a single-element list used + as a mutable integer for generating unique fallback names. 
+ + See https://github.com/kagent-dev/kagent/issues/1473 + """ + if name in name_map: + return name_map[name] + sanitized = _BEDROCK_TOOL_NAME_INVALID.sub("_", name) + if not sanitized or not _BEDROCK_TOOL_NAME_VALID.match(sanitized): + counter[0] += 1 + sanitized = f"unknown_tool_{counter[0]}" + return sanitized + if sanitized != name: + logger.warning("Sanitized Bedrock tool name %r -> %r", name, sanitized) + name_map[name] = sanitized + return sanitized + + def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None): region = os.environ.get("AWS_DEFAULT_REGION") or os.environ.get("AWS_REGION") or "us-east-1" kwargs: dict[str, Any] = {"region_name": region} @@ -63,7 +89,11 @@ def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None): return boto3.client("bedrock-runtime", **kwargs) -def _convert_content_to_converse_messages(contents: list[types.Content]) -> list[dict]: +def _convert_content_to_converse_messages( + contents: list[types.Content], + name_map: dict[str, str], + name_counter: list[int], +) -> list[dict]: id_map: dict[str, str] = {} counter = [0] @@ -80,7 +110,7 @@ def _convert_content_to_converse_messages(contents: list[types.Content]) -> list { "toolUse": { "toolUseId": _sanitize_tool_id(part.function_call.id or "", id_map, counter), - "name": part.function_call.name or "", + "name": _sanitize_tool_name(part.function_call.name or "", name_map, name_counter), "input": part.function_call.args or {}, } } @@ -149,7 +179,11 @@ def _normalize_schema(schema: dict) -> dict: return result -def _convert_tools_to_converse(tools: list[types.Tool]) -> list[dict]: +def _convert_tools_to_converse( + tools: list[types.Tool], + name_map: dict[str, str], + name_counter: list[int], +) -> list[dict]: converse_tools = [] for tool in tools: for func_decl in tool.function_declarations or []: @@ -164,7 +198,7 @@ def _convert_tools_to_converse(tools: list[types.Tool]) -> list[dict]: converse_tools.append( { "toolSpec": { - "name": func_decl.name or 
"", + "name": _sanitize_tool_name(func_decl.name or "", name_map, name_counter), "description": func_decl.description or "", "inputSchema": { "json": { @@ -212,7 +246,9 @@ async def generate_content_async( client = self._client model_id = llm_request.model or self.model - messages = _convert_content_to_converse_messages(llm_request.contents or []) + name_map: dict[str, str] = {} + name_counter: list[int] = [0] + messages = _convert_content_to_converse_messages(llm_request.contents or [], name_map, name_counter) kwargs: dict[str, Any] = {"modelId": model_id, "messages": messages} @@ -228,7 +264,7 @@ async def generate_content_async( if llm_request.config and llm_request.config.tools: genai_tools = [t for t in llm_request.config.tools if hasattr(t, "function_declarations")] if genai_tools: - converse_tools = _convert_tools_to_converse(genai_tools) + converse_tools = _convert_tools_to_converse(genai_tools, name_map, name_counter) if converse_tools: kwargs["toolConfig"] = {"tools": converse_tools} @@ -248,6 +284,8 @@ async def generate_content_async( if self.additional_model_request_fields: kwargs["additionalModelRequestFields"] = self.additional_model_request_fields + reverse_name_map = {v: k for k, v in name_map.items()} + def _run_converse_stream(**kw): resp = client.converse_stream(**kw) return list(resp.get("stream", [])) @@ -268,7 +306,7 @@ def _run_converse_stream(**kw): if "toolUse" in start: current_tool_id = start["toolUse"]["toolUseId"] tool_uses[current_tool_id] = { - "name": start["toolUse"]["name"], + "name": reverse_name_map.get(start["toolUse"]["name"], start["toolUse"]["name"]), "input_json": "", } @@ -325,7 +363,8 @@ def _run_converse_stream(**kw): parts.append(types.Part.from_text(text=block["text"])) elif "toolUse" in block: tool = block["toolUse"] - part = types.Part.from_function_call(name=tool["name"], args=tool.get("input", {})) + original_name = reverse_name_map.get(tool["name"], tool["name"]) + part = 
_ROUND_TRIP_MODEL = "us.anthropic.claude-sonnet-4-20250514-v1:0"


def _make_tool(name, description, tool_spec=None):
    """Build a mock tool exposing a single parameterless function declaration.

    ``tool_spec`` is forwarded to ``MagicMock(spec=...)``; ``None`` yields an
    unconstrained mock.
    """
    func_decl = mock.MagicMock()
    func_decl.name = name
    func_decl.description = description
    func_decl.parameters = None

    tool = mock.MagicMock(spec=tool_spec)
    tool.function_declarations = [func_decl]
    return tool


def _make_request(tool):
    """Build a mock LLM request with no contents and ``tool`` as its only tool."""
    request = mock.MagicMock()
    request.model = _ROUND_TRIP_MODEL
    request.contents = []
    request.config = mock.MagicMock()
    request.config.system_instruction = None
    request.config.tools = [tool]
    request.config.temperature = None
    request.config.max_output_tokens = None
    request.config.top_p = None
    request.config.stop_sequences = None
    return request


async def _fake_to_thread(fn, **kwargs):
    """Stand-in for ``asyncio.to_thread`` that runs ``fn`` inline."""
    return fn(**kwargs)


async def _collect_responses(llm, request, mock_client, **generate_kwargs):
    """Run ``generate_content_async`` against ``mock_client`` and return all responses."""
    with (
        mock.patch("kagent.adk.models._bedrock._get_bedrock_client", return_value=mock_client),
        mock.patch("kagent.adk.models._bedrock.asyncio.to_thread", side_effect=_fake_to_thread),
    ):
        return [r async for r in llm.generate_content_async(request, **generate_kwargs)]


class TestSanitizeToolName:
    """Unit tests for ``_sanitize_tool_name``."""

    def test_valid_name_unchanged(self):
        name_map: dict = {}
        counter = [0]
        assert _sanitize_tool_name("get_pods", name_map, counter) == "get_pods"
        assert name_map == {"get_pods": "get_pods"}

    def test_dot_replaced_with_underscore(self):
        name_map: dict = {}
        counter = [0]
        assert _sanitize_tool_name("kubernetes.get_pods", name_map, counter) == "kubernetes_get_pods"
        assert name_map["kubernetes.get_pods"] == "kubernetes_get_pods"

    def test_space_replaced_with_underscore(self):
        name_map: dict = {}
        counter = [0]
        assert _sanitize_tool_name("get pods", name_map, counter) == "get_pods"

    def test_colon_replaced_with_underscore(self):
        name_map: dict = {}
        counter = [0]
        assert _sanitize_tool_name("ns:tool", name_map, counter) == "ns_tool"

    def test_empty_name_gets_fallback(self):
        name_map: dict = {}
        counter = [0]
        result = _sanitize_tool_name("", name_map, counter)
        assert result == "unknown_tool_1"
        assert counter[0] == 1

    def test_fully_invalid_name_becomes_underscores(self):
        name_map: dict = {}
        counter = [0]
        result = _sanitize_tool_name("!@#$", name_map, counter)
        # Characters are replaced with _ so the result is still valid per pattern
        assert result == "____"
        assert counter[0] == 0

    def test_same_name_returns_cached_sanitized(self):
        name_map: dict = {}
        counter = [0]
        first = _sanitize_tool_name("mcp.server.tool", name_map, counter)
        second = _sanitize_tool_name("mcp.server.tool", name_map, counter)
        assert first == second == "mcp_server_tool"

    def test_hyphen_preserved(self):
        name_map: dict = {}
        counter = [0]
        assert _sanitize_tool_name("my-tool", name_map, counter) == "my-tool"


class TestConvertToolsToConverse:
    """Tool declarations are converted with sanitized names."""

    def test_tool_name_with_dot_is_sanitized(self):
        tool = _make_tool("github_copilot.suggest", "Suggest code")

        name_map: dict = {}
        counter = [0]
        result = _convert_tools_to_converse([tool], name_map, counter)

        assert result[0]["toolSpec"]["name"] == "github_copilot_suggest"
        assert name_map["github_copilot.suggest"] == "github_copilot_suggest"

    def test_valid_tool_name_unchanged(self):
        tool = _make_tool("list_namespaces", "List namespaces")

        name_map: dict = {}
        counter = [0]
        result = _convert_tools_to_converse([tool], name_map, counter)

        assert result[0]["toolSpec"]["name"] == "list_namespaces"


class TestBedrockToolNameRoundTrip:
    """Sanitized names sent to Bedrock are mapped back to the originals."""

    @pytest.mark.asyncio
    async def test_dotted_tool_name_restored_in_non_streaming_response(self):
        """Tool names with dots are sanitized outbound and restored from the Bedrock response."""
        from google.genai import types as genai_types

        llm = KAgentBedrockLlm(model=_ROUND_TRIP_MODEL)

        mock_client = mock.MagicMock()
        mock_client.converse.return_value = {
            "output": {
                "message": {
                    "role": "assistant",
                    "content": [
                        {
                            "toolUse": {
                                "toolUseId": "call_abc",
                                "name": "github_copilot_suggest",
                                "input": {"prompt": "hello"},
                            }
                        }
                    ],
                }
            },
            "stopReason": "tool_use",
            "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
        }

        tool = _make_tool("github_copilot.suggest", "Suggest", tool_spec=genai_types.Tool)
        responses = await _collect_responses(llm, _make_request(tool), mock_client)

        assert len(responses) == 1
        fc = responses[0].content.parts[0].function_call
        assert fc.name == "github_copilot.suggest"

    @pytest.mark.asyncio
    async def test_dotted_tool_name_restored_in_streaming_response(self):
        """Tool names with dots are sanitized outbound and restored from streaming Bedrock response."""
        llm = KAgentBedrockLlm(model=_ROUND_TRIP_MODEL)

        stream_events = [
            {
                "contentBlockStart": {
                    "start": {"toolUse": {"toolUseId": "call_xyz", "name": "github_copilot_suggest"}}
                }
            },
            {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"prompt": "hi"}'}}}},
            {"messageStop": {"stopReason": "tool_use"}},
            {"metadata": {"usage": {"inputTokens": 5, "outputTokens": 3, "totalTokens": 8}}},
        ]
        mock_client = mock.MagicMock()
        mock_client.converse_stream.return_value = {"stream": stream_events}

        tool = _make_tool("github_copilot.suggest", "Suggest")
        responses = await _collect_responses(llm, _make_request(tool), mock_client, stream=True)

        final = responses[-1]
        fc = final.content.parts[0].function_call
        assert fc.name == "github_copilot.suggest"