diff --git a/go/adk/pkg/models/bedrock.go b/go/adk/pkg/models/bedrock.go index 69596e0a1..d9db5a842 100644 --- a/go/adk/pkg/models/bedrock.go +++ b/go/adk/pkg/models/bedrock.go @@ -20,11 +20,34 @@ import ( ) // bedrockToolIDValid matches Bedrock's toolUseId constraint: [a-zA-Z0-9_.:-]+ +// bedrockToolNameInvalid matches characters not allowed in Bedrock tool names: [a-zA-Z0-9_-]+ var ( - bedrockToolIDValid = regexp.MustCompile(`^[a-zA-Z0-9_.:-]+$`) - bedrockToolIDInvalid = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`) + bedrockToolIDValid = regexp.MustCompile(`^[a-zA-Z0-9_.:-]+$`) + bedrockToolIDInvalid = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`) + bedrockToolNameInvalid = regexp.MustCompile(`[^a-zA-Z0-9_-]`) ) +// sanitizeBedrockToolName returns a valid Bedrock tool name. +// Bedrock requires tool names to match [a-zA-Z0-9_-]+ and be non-empty. +// nameMap caches original->sanitized so repeated lookups for the same name are +// consistent. counter is incremented only when a synthetic name is needed. +func sanitizeBedrockToolName(name string, nameMap map[string]string, counter *int) string { + if name == "" { + *counter++ + return fmt.Sprintf("tool_fn_%d", *counter) + } + if sanitized, ok := nameMap[name]; ok { + return sanitized + } + sanitized := bedrockToolNameInvalid.ReplaceAllString(name, "_") + if sanitized == "" { + *counter++ + sanitized = fmt.Sprintf("tool_fn_%d", *counter) + } + nameMap[name] = sanitized + return sanitized +} + // sanitizeBedrockToolID returns a valid Bedrock toolUseId. // Bedrock requires toolUseId to match [a-zA-Z0-9_.:-]+ and be non-empty. // idMap caches original→sanitized so FunctionCall and FunctionResponse @@ -121,8 +144,32 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques modelName = req.Model } - // Convert content to Bedrock messages - messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents) + // Build tool configuration first so nameMap is available for message conversion. + // convertGenaiToolsToBedrock sanitizes tool names and returns the + // original->sanitized mapping so the same mapping can be applied to + // conversation history and reversed when restoring names from responses. + var toolConfig *types.ToolConfiguration + nameMap := make(map[string]string) + if req.Config != nil && len(req.Config.Tools) > 0 { + tools, nm := convertGenaiToolsToBedrock(req.Config.Tools) + nameMap = nm + if len(tools) > 0 { + toolConfig = &types.ToolConfiguration{ + Tools: tools, + } + } + } + + // Build reverse map for restoring original tool names from Bedrock responses. + reverseNameMap := make(map[string]string, len(nameMap)) + for orig, sanitized := range nameMap { + reverseNameMap[sanitized] = orig + } + + // Convert content to Bedrock messages. + // nameMap is passed so that any tool call recorded in conversation history + // is written with the sanitized name Bedrock already knows about. + messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents, nameMap) // Build inference config var inferenceConfig *types.InferenceConfiguration @@ -147,27 +194,15 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques }) } - // Build tool configuration - var toolConfig *types.ToolConfiguration - if req.Config != nil && len(req.Config.Tools) > 0 { - tools := convertGenaiToolsToBedrock(req.Config.Tools) - if len(tools) > 0 { - toolConfig = &types.ToolConfiguration{ - Tools: tools, - } - } - } - - // Build model-specific additional fields (Claude top_k, thinking, etc.) additionalFields := m.buildAdditionalModelRequestFields() // Set telemetry attributes telemetry.SetLLMRequestAttributes(ctx, modelName, req) if stream { - m.generateStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, yield) + m.generateStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, reverseNameMap, yield) } else { - m.generateNonStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, yield) + m.generateNonStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, reverseNameMap, yield) } } } @@ -185,7 +220,8 @@ func (m *BedrockModel) buildAdditionalModelRequestFields() document.Interface { // generateStreaming handles streaming responses from Bedrock ConverseStream. // It properly handles both text and tool use content blocks during streaming. -func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, yield func(*model.LLMResponse, error) bool) { +// reverseNameMap maps sanitized Bedrock tool names back to their original names. +func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, reverseNameMap map[string]string, yield func(*model.LLMResponse, error) bool) { output, err := m.Client.ConverseStream(ctx, &bedrockruntime.ConverseStreamInput{ ModelId: aws.String(modelId), Messages: messages, @@ -266,11 +302,17 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me if stop, ok := event.(*types.ConverseStreamOutputMemberContentBlockStop); ok { blockIdx := aws.ToInt32(stop.Value.ContentBlockIndex) if tc, ok := toolCalls[blockIdx]; ok { - // Tool use block completed - parse the accumulated JSON and create FunctionCall + // Tool use block completed - parse the accumulated JSON and create FunctionCall. + // Restore the original tool name from the reverse map so the ADK framework + // can dispatch to the correctly registered tool. + originalName := tc.Name + if orig, found := reverseNameMap[tc.Name]; found { + originalName = orig + } args := tc.parseArgs() functionCall := &genai.FunctionCall{ ID: tc.ID, - Name: tc.Name, + Name: originalName, Args: args, } completedToolCalls = append(completedToolCalls, &genai.Part{FunctionCall: functionCall}) @@ -338,7 +380,8 @@ func (tc *streamingToolCall) parseArgs() map[string]any { } // generateNonStreaming handles non-streaming responses from Bedrock Converse. -func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, yield func(*model.LLMResponse, error) bool) { +// reverseNameMap maps sanitized Bedrock tool names back to their original names. +func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, reverseNameMap map[string]string, yield func(*model.LLMResponse, error) bool) { output, err := m.Client.Converse(ctx, &bedrockruntime.ConverseInput{ ModelId: aws.String(modelId), Messages: messages, @@ -366,9 +409,15 @@ func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, } // Handle tool use content if toolUseBlock, ok := block.(*types.ContentBlockMemberToolUse); ok { + // Restore the original tool name so the ADK framework can dispatch + // to the correctly registered tool. + toolName := aws.ToString(toolUseBlock.Value.Name) + if orig, found := reverseNameMap[toolName]; found { + toolName = orig + } functionCall := &genai.FunctionCall{ ID: aws.ToString(toolUseBlock.Value.ToolUseId), - Name: aws.ToString(toolUseBlock.Value.Name), + Name: toolName, } // Convert document.Interface to map using the String() method and JSON parsing // The document type in AWS SDK implements String() that returns JSON @@ -425,7 +474,10 @@ func documentToMap(doc document.Interface) map[string]any { } // convertGenaiContentsToBedrockMessages converts genai.Content to Bedrock Converse API message format. -func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.Message, string) { +// nameMap is the original->sanitized tool name map produced by convertGenaiToolsToBedrock. +// Any FunctionCall found in the conversation history is written with the sanitized name so +// that Bedrock can correlate it with the tool spec it already received. A nil nameMap is safe. +func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap map[string]string) ([]types.Message, string) { var messages []types.Message var systemInstruction string @@ -465,11 +517,17 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.M continue } - // Handle function call (tool use in Bedrock terminology) + // Handle function call (tool use in Bedrock terminology). + // Use the sanitized name from nameMap so Bedrock can correlate the + // tool call with the tool spec sent in the same request. if part.FunctionCall != nil { + callName := part.FunctionCall.Name + if sanitized, ok := nameMap[callName]; ok { + callName = sanitized + } toolUse := types.ToolUseBlock{ ToolUseId: aws.String(sanitizeBedrockToolID(part.FunctionCall.ID, idMap, &idCounter)), - Name: aws.String(part.FunctionCall.Name), + Name: aws.String(callName), Input: document.NewLazyDocument(part.FunctionCall.Args), } contentBlocks = append(contentBlocks, &types.ContentBlockMemberToolUse{ @@ -507,11 +565,16 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.M } // convertGenaiToolsToBedrock converts genai.Tool to Bedrock Tool format. -func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool { +// It sanitizes tool names to satisfy Bedrock's [a-zA-Z0-9_-]+ constraint and +// returns the original->sanitized name mapping so callers can apply it to +// conversation history and reverse it when restoring names from responses. +func convertGenaiToolsToBedrock(tools []*genai.Tool) ([]types.Tool, map[string]string) { if len(tools) == 0 { - return nil + return nil, nil } + nameMap := make(map[string]string) + nameCounter := 0 var bedrockTools []types.Tool for _, tool := range tools { @@ -525,7 +588,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool { } // Build input schema as JSON document. - // MCP tools and built-in local toolsset ParametersJsonSchema + // MCP tools and built-in local toolsets set ParametersJsonSchema. var schema map[string]any if decl.ParametersJsonSchema != nil { schema = parametersJsonSchemaToMap(decl.ParametersJsonSchema) @@ -536,7 +599,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool { // then lowercase all type values to match JSON Schema standard. schema = genaiSchemaToMap(decl.Parameters) } - // Fallback to empty object if no schema is found + // Fallback to empty object if no schema is found. if schema == nil { schema = map[string]any{"type": "object", "properties": map[string]any{}} } @@ -545,8 +608,12 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool { Value: document.NewLazyDocument(schema), } + // Sanitize the tool name: MCP tool names often contain dots, colons, + // or spaces (e.g. "fetch.get_url") that Bedrock rejects. + sanitizedName := sanitizeBedrockToolName(decl.Name, nameMap, &nameCounter) + toolSpec := types.ToolSpecification{ - Name: aws.String(decl.Name), + Name: aws.String(sanitizedName), Description: aws.String(decl.Description), InputSchema: inputSchema, } @@ -558,7 +625,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool { } } - return bedrockTools + return bedrockTools, nameMap } // bedrockStopReasonToGenai maps Bedrock stop reason to genai.FinishReason. diff --git a/go/adk/pkg/models/bedrock_test.go b/go/adk/pkg/models/bedrock_test.go index 0b8fe8100..de2d1c3ca 100644 --- a/go/adk/pkg/models/bedrock_test.go +++ b/go/adk/pkg/models/bedrock_test.go @@ -106,7 +106,7 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - msgs, systemText := convertGenaiContentsToBedrockMessages(tt.contents) + msgs, systemText := convertGenaiContentsToBedrockMessages(tt.contents, nil) if len(msgs) != tt.wantMsgCount { t.Errorf("expected %d messages, got %d", tt.wantMsgCount, len(msgs)) } @@ -124,7 +124,7 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) { // sources: genai.Schema (declaration-based), map[string]any (MCP), and // *jsonschema.Schema (functiontool.New). func TestConvertGenaiToolsToBedrock(t *testing.T) { - extractSchema := func(t *testing.T, tools []types.Tool) map[string]any { + extractSchema := func(t *testing.T, tools []types.Tool, _ map[string]string) map[string]any { t.Helper() if len(tools) != 1 { t.Fatalf("expected 1 tool, got %d", len(tools)) @@ -162,7 +162,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) { }, }}}} - schema := extractSchema(t, convertGenaiToolsToBedrock(tools)) + bt1, nm1 := convertGenaiToolsToBedrock(tools) + schema := extractSchema(t, bt1, nm1) props := schema["properties"].(map[string]any) for prop, want := range map[string]string{"location": "string", "count": "integer", "detailed": "boolean"} { @@ -189,7 +190,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) { }, }}}} - schema := extractSchema(t, convertGenaiToolsToBedrock(tools)) + bt2, nm2 := convertGenaiToolsToBedrock(tools) + schema := extractSchema(t, bt2, nm2) props, ok := schema["properties"].(map[string]any) if !ok || len(props) == 0 { t.Fatalf("expected non-empty properties, got %v", schema["properties"]) @@ -209,7 +211,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) { ParametersJsonSchema: s, }}}} - schema := extractSchema(t, convertGenaiToolsToBedrock(tools)) + bt3, nm3 := convertGenaiToolsToBedrock(tools) + schema := extractSchema(t, bt3, nm3) props, ok := schema["properties"].(map[string]any) if !ok || len(props) == 0 { t.Fatalf("expected non-empty properties (means *jsonschema.Schema was not converted): %v", schema["properties"]) @@ -310,6 +313,88 @@ func TestSanitizeBedrockToolID(t *testing.T) { }) } +func TestSanitizeBedrockToolName(t *testing.T) { + tests := []struct { + name string + tool string + want string + }{ + {name: "valid name unchanged", tool: "get_weather", want: "get_weather"}, + {name: "valid name with hyphen", tool: "fetch-data", want: "fetch-data"}, + {name: "dot replaced", tool: "fetch.get_url", want: "fetch_get_url"}, + {name: "colon replaced", tool: "filesystem:read_file", want: "filesystem_read_file"}, + {name: "space replaced", tool: "my tool", want: "my_tool"}, + {name: "multiple invalid chars", tool: "a.b:c d", want: "a_b_c_d"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nameMap := make(map[string]string) + counter := 0 + if got := sanitizeBedrockToolName(tt.tool, nameMap, &counter); got != tt.want { + t.Errorf("sanitizeBedrockToolName(%q) = %q, want %q", tt.tool, got, tt.want) + } + }) + } + + t.Run("empty name gets synthetic", func(t *testing.T) { + nameMap, counter := make(map[string]string), 0 + got := sanitizeBedrockToolName("", nameMap, &counter) + if got != "tool_fn_1" { + t.Errorf("expected tool_fn_1, got %q", got) + } + if counter != 1 { + t.Errorf("expected counter=1, got %d", counter) + } + }) + + t.Run("caching returns same sanitized name", func(t *testing.T) { + nameMap, counter := make(map[string]string), 0 + first := sanitizeBedrockToolName("fetch.get_url", nameMap, &counter) + second := sanitizeBedrockToolName("fetch.get_url", nameMap, &counter) + if first != second { + t.Errorf("expected same cached result, got %q and %q", first, second) + } + if counter != 0 { + t.Errorf("expected counter unchanged, got %d", counter) + } + }) +} + +func TestConvertGenaiToolsToBedrockSanitizesNames(t *testing.T) { + tools := []*genai.Tool{{FunctionDeclarations: []*genai.FunctionDeclaration{ + {Name: "fetch.get_url", Description: "Fetch a URL"}, + {Name: "filesystem:read_file", Description: "Read a file"}, + }}} + + bedrockTools, nameMap := convertGenaiToolsToBedrock(tools) + if len(bedrockTools) != 2 { + t.Fatalf("expected 2 tools, got %d", len(bedrockTools)) + } + + // Verify sanitized names in the Bedrock tool specs. + for i, want := range []string{"fetch_get_url", "filesystem_read_file"} { + tm, ok := bedrockTools[i].(*types.ToolMemberToolSpec) + if !ok { + t.Fatalf("tool %d: expected *types.ToolMemberToolSpec", i) + } + got := "" + if tm.Value.Name != nil { + got = *tm.Value.Name + } + if got != want { + t.Errorf("tool %d: expected name %q, got %q", i, want, got) + } + } + + // Verify nameMap contains the mappings. + if nameMap["fetch.get_url"] != "fetch_get_url" { + t.Errorf("nameMap[fetch.get_url] = %q, want fetch_get_url", nameMap["fetch.get_url"]) + } + if nameMap["filesystem:read_file"] != "filesystem_read_file" { + t.Errorf("nameMap[filesystem:read_file] = %q, want filesystem_read_file", nameMap["filesystem:read_file"]) + } +} + func TestStreamingToolCallParseArgs(t *testing.T) { tests := []struct { name string diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py index e1ebdbfb6..1e1df89dd 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py @@ -29,6 +29,30 @@ _BEDROCK_TOOL_ID_VALID = re.compile(r"^[a-zA-Z0-9_.:-]+$") _BEDROCK_TOOL_ID_INVALID = re.compile(r"[^a-zA-Z0-9_.:-]") +# Bedrock tool names allow only letters, digits, underscores, and hyphens. +# Dots, colons, spaces, and other chars (common in MCP server tool names) are invalid. +_BEDROCK_TOOL_NAME_VALID = re.compile(r"^[a-zA-Z0-9_-]+$") +_BEDROCK_TOOL_NAME_INVALID = re.compile(r"[^a-zA-Z0-9_-]") + + +def _sanitize_tool_name(name: str, name_map: dict[str, str], counter: list[int]) -> str: + """Return a valid Bedrock tool name. + + Bedrock requires tool names to match [a-zA-Z0-9_-]+. + Dots, colons, spaces, and other chars common in MCP server tool names are invalid. + name_map caches original→sanitized for consistency across a single request. + counter is a single-element list used as a mutable integer for unique fallback names. + """ + if name and name in name_map: + return name_map[name] + sanitized = _BEDROCK_TOOL_NAME_INVALID.sub("_", name) + if not sanitized or not _BEDROCK_TOOL_NAME_VALID.match(sanitized): + counter[0] += 1 + sanitized = f"tool_fn_{counter[0]}" + if name: + name_map[name] = sanitized + return sanitized + def _sanitize_tool_id(tool_id: str, id_map: dict[str, str], counter: list[int]) -> str: """Return a valid Bedrock toolUseId. @@ -63,7 +87,10 @@ def _get_bedrock_client(extra_headers: Optional[dict[str, str]] = None): return boto3.client("bedrock-runtime", **kwargs) -def _convert_content_to_converse_messages(contents: list[types.Content]) -> list[dict]: +def _convert_content_to_converse_messages( + contents: list[types.Content], + tool_name_map: Optional[dict[str, str]] = None, +) -> list[dict]: id_map: dict[str, str] = {} counter = [0] @@ -76,11 +103,13 @@ def _convert_content_to_converse_messages(contents: list[types.Content]) -> list if part.text: blocks.append({"text": part.text}) elif part.function_call: + raw_name = part.function_call.name or "" + sanitized_name = tool_name_map.get(raw_name, raw_name) if tool_name_map else raw_name blocks.append( { "toolUse": { "toolUseId": _sanitize_tool_id(part.function_call.id or "", id_map, counter), - "name": part.function_call.name or "", + "name": sanitized_name, "input": part.function_call.args or {}, } } @@ -149,7 +178,11 @@ def _normalize_schema(schema: dict) -> dict: return result -def _convert_tools_to_converse(tools: list[types.Tool]) -> list[dict]: +def _convert_tools_to_converse( + tools: list[types.Tool], + name_map: dict[str, str], + counter: list[int], +) -> list[dict]: converse_tools = [] for tool in tools: for func_decl in tool.function_declarations or []: @@ -161,10 +194,11 @@ def _convert_tools_to_converse(tools: list[types.Tool]) -> list[dict]: properties[prop_name] = _normalize_schema(raw) required = func_decl.parameters.required or [] + sanitized_name = _sanitize_tool_name(func_decl.name or "", name_map, counter) converse_tools.append( { "toolSpec": { - "name": func_decl.name or "", + "name": sanitized_name, "description": func_decl.description or "", "inputSchema": { "json": { @@ -212,9 +246,12 @@ async def generate_content_async( client = self._client model_id = llm_request.model or self.model - messages = _convert_content_to_converse_messages(llm_request.contents or []) + # Build the tool name map first so that message history and tool specs + # use the same sanitized names throughout the request. + tool_name_map: dict[str, str] = {} + tool_name_counter = [0] - kwargs: dict[str, Any] = {"modelId": model_id, "messages": messages} + kwargs: dict[str, Any] = {"modelId": model_id} if llm_request.config and llm_request.config.system_instruction: si = llm_request.config.system_instruction @@ -228,10 +265,16 @@ async def generate_content_async( if llm_request.config and llm_request.config.tools: genai_tools = [t for t in llm_request.config.tools if hasattr(t, "function_declarations")] if genai_tools: - converse_tools = _convert_tools_to_converse(genai_tools) + converse_tools = _convert_tools_to_converse(genai_tools, tool_name_map, tool_name_counter) if converse_tools: kwargs["toolConfig"] = {"tools": converse_tools} + # Reverse map lets us restore original tool names from sanitized names in Bedrock responses. + reverse_name_map: dict[str, str] = {v: k for k, v in tool_name_map.items()} + + messages = _convert_content_to_converse_messages(llm_request.contents or [], tool_name_map) + kwargs["messages"] = messages + inference_config: dict[str, Any] = {} if llm_request.config: if llm_request.config.temperature is not None: @@ -267,8 +310,9 @@ def _run_converse_stream(**kw): start = event["contentBlockStart"].get("start", {}) if "toolUse" in start: current_tool_id = start["toolUse"]["toolUseId"] + sanitized = start["toolUse"]["name"] tool_uses[current_tool_id] = { - "name": start["toolUse"]["name"], + "name": reverse_name_map.get(sanitized, sanitized), "input_json": "", } @@ -325,7 +369,9 @@ def _run_converse_stream(**kw): parts.append(types.Part.from_text(text=block["text"])) elif "toolUse" in block: tool = block["toolUse"] - part = types.Part.from_function_call(name=tool["name"], args=tool.get("input", {})) + sanitized = tool["name"] + original_name = reverse_name_map.get(sanitized, sanitized) + part = types.Part.from_function_call(name=original_name, args=tool.get("input", {})) if part.function_call: part.function_call.id = tool["toolUseId"] parts.append(part) diff --git a/python/packages/kagent-adk/tests/unittests/models/test_bedrock.py b/python/packages/kagent-adk/tests/unittests/models/test_bedrock.py index fab018520..bb067238e 100644 --- a/python/packages/kagent-adk/tests/unittests/models/test_bedrock.py +++ b/python/packages/kagent-adk/tests/unittests/models/test_bedrock.py @@ -1,11 +1,152 @@ """Tests for KAgentBedrockLlm.""" -import asyncio from unittest import mock +from unittest.mock import MagicMock import pytest -from kagent.adk.models._bedrock import KAgentBedrockLlm, _get_bedrock_client +from kagent.adk.models._bedrock import ( + KAgentBedrockLlm, + _convert_content_to_converse_messages, + _convert_tools_to_converse, + _get_bedrock_client, + _sanitize_tool_name, +) + + +class TestSanitizeToolName: + def test_valid_name_unchanged(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("get_weather", name_map, counter) == "get_weather" + assert name_map["get_weather"] == "get_weather" + + def test_dot_replaced_with_underscore(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("fetch.get_url", name_map, counter) == "fetch_get_url" + + def test_colon_replaced_with_underscore(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("filesystem:read", name_map, counter) == "filesystem_read" + + def test_space_replaced_with_underscore(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("read file", name_map, counter) == "read_file" + + def test_hyphen_kept(self): + name_map: dict = {} + counter = [0] + assert _sanitize_tool_name("get-weather", name_map, counter) == "get-weather" + + def test_same_original_gives_same_sanitized(self): + name_map: dict = {} + counter = [0] + first = _sanitize_tool_name("fetch.get_url", name_map, counter) + second = _sanitize_tool_name("fetch.get_url", name_map, counter) + assert first == second == "fetch_get_url" + assert counter[0] == 0 + + def test_empty_name_gets_synthetic(self): + name_map: dict = {} + counter = [0] + result = _sanitize_tool_name("", name_map, counter) + assert result == "tool_fn_1" + assert "" not in name_map + + def test_multiple_distinct_names(self): + name_map: dict = {} + counter = [0] + a = _sanitize_tool_name("server.tool_a", name_map, counter) + b = _sanitize_tool_name("server.tool_b", name_map, counter) + assert a == "server_tool_a" + assert b == "server_tool_b" + assert a != b + + def test_mixed_invalid_chars(self): + name_map: dict = {} + counter = [0] + result = _sanitize_tool_name("my.server:some tool", name_map, counter) + assert result == "my_server_some_tool" + + +class TestConvertToolsToConverse: + def _make_tool(self, name: str, description: str = "a tool"): + tool = MagicMock() + decl = MagicMock() + decl.name = name + decl.description = description + decl.parameters = None + tool.function_declarations = [decl] + return tool + + def test_plain_name_unchanged(self): + name_map: dict = {} + counter = [0] + tools = self._make_tool("get_weather") + result = _convert_tools_to_converse([tools], name_map, counter) + assert result[0]["toolSpec"]["name"] == "get_weather" + assert name_map == {"get_weather": "get_weather"} + + def test_dot_in_name_sanitized(self): + name_map: dict = {} + counter = [0] + tools = self._make_tool("fetch.get_url") + result = _convert_tools_to_converse([tools], name_map, counter) + assert result[0]["toolSpec"]["name"] == "fetch_get_url" + assert name_map["fetch.get_url"] == "fetch_get_url" + + def test_colon_in_name_sanitized(self): + name_map: dict = {} + counter = [0] + tools = self._make_tool("filesystem:read_file") + result = _convert_tools_to_converse([tools], name_map, counter) + assert result[0]["toolSpec"]["name"] == "filesystem_read_file" + + def test_multiple_tools_all_sanitized(self): + name_map: dict = {} + counter = [0] + t1 = self._make_tool("server.alpha") + t2 = self._make_tool("server.beta") + result = _convert_tools_to_converse([t1, t2], name_map, counter) + names = [r["toolSpec"]["name"] for r in result] + assert names == ["server_alpha", "server_beta"] + + +class TestConvertContentWithNameMap: + def test_function_call_name_sanitized_via_map(self): + from google.genai import types + + name_map = {"fetch.get_url": "fetch_get_url"} + part = types.Part.from_function_call(name="fetch.get_url", args={"url": "https://example.com"}) + if part.function_call: + part.function_call.id = "call-1" + content = types.Content(role="model", parts=[part]) + messages = _convert_content_to_converse_messages([content], tool_name_map=name_map) + assert messages[0]["content"][0]["toolUse"]["name"] == "fetch_get_url" + + def test_function_call_name_unchanged_without_map(self): + from google.genai import types + + part = types.Part.from_function_call(name="fetch.get_url", args={}) + if part.function_call: + part.function_call.id = "call-2" + content = types.Content(role="model", parts=[part]) + messages = _convert_content_to_converse_messages([content], tool_name_map=None) + assert messages[0]["content"][0]["toolUse"]["name"] == "fetch.get_url" + + def test_unknown_name_falls_back_to_original(self): + from google.genai import types + + name_map = {"other.tool": "other_tool"} + part = types.Part.from_function_call(name="unknown.tool", args={}) + if part.function_call: + part.function_call.id = "call-3" + content = types.Content(role="model", parts=[part]) + messages = _convert_content_to_converse_messages([content], tool_name_map=name_map) + assert messages[0]["content"][0]["toolUse"]["name"] == "unknown.tool" class TestGetBedrockClient: @@ -101,8 +242,70 @@ async def fake_to_thread(fn, **kwargs): assert final.usage_metadata.candidates_token_count == 5 assert final.usage_metadata.total_token_count == 15 + @pytest.mark.asyncio + async def test_dot_tool_name_sanitized_to_bedrock_and_remapped_in_response(self): + from google.genai import types + + llm = KAgentBedrockLlm(model="us.anthropic.claude-sonnet-4-20250514-v1:0") + + converse_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "call-abc", + "name": "fetch_get_url", + "input": {"url": "https://example.com"}, + } + } + ], + } + }, + "stopReason": "tool_use", + "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, + } + mock_client = MagicMock() + mock_client.converse.return_value = converse_response + + async def fake_to_thread(fn, **kwargs): + return fn(**kwargs) + + func_decl = MagicMock() + func_decl.name = "fetch.get_url" + func_decl.description = "Fetch a URL" + func_decl.parameters = None + tool = MagicMock() + tool.function_declarations = [func_decl] + + request = MagicMock() + request.model = "us.anthropic.claude-sonnet-4-20250514-v1:0" + request.contents = [] + request.config = MagicMock() + request.config.system_instruction = None + request.config.tools = [tool] + request.config.temperature = None + request.config.max_output_tokens = None + request.config.top_p = None + request.config.stop_sequences = None + + with ( + mock.patch("kagent.adk.models._bedrock._get_bedrock_client", return_value=mock_client), + mock.patch("kagent.adk.models._bedrock.asyncio.to_thread", side_effect=fake_to_thread), + ): + responses = [r async for r in llm.generate_content_async(request)] + + assert len(responses) == 1 + fc = responses[0].content.parts[0].function_call + assert fc.name == "fetch.get_url" + assert fc.id == "call-abc" + + call_kwargs = mock_client.converse.call_args.kwargs + tool_names = [t["toolSpec"]["name"] for t in call_kwargs["toolConfig"]["tools"]] + assert tool_names == ["fetch_get_url"] + def test_create_llm_from_bedrock_model_config(self): - """Integration: _create_llm_from_model_config returns KAgentBedrockLlm for bedrock type.""" from kagent.adk.types import Bedrock, _create_llm_from_model_config config = Bedrock(type="bedrock", model="meta.llama3-8b-instruct-v1:0")