Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 99 additions & 32 deletions go/adk/pkg/models/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,34 @@ import (
)

// bedrockToolIDValid matches Bedrock's toolUseId constraint: [a-zA-Z0-9_.:-]+
// bedrockToolNameInvalid matches characters not allowed in Bedrock tool names: [a-zA-Z0-9_-]+
var (
bedrockToolIDValid = regexp.MustCompile(`^[a-zA-Z0-9_.:-]+$`)
bedrockToolIDInvalid = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)
bedrockToolIDValid = regexp.MustCompile(`^[a-zA-Z0-9_.:-]+$`)
bedrockToolIDInvalid = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)
bedrockToolNameInvalid = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
)

// sanitizeBedrockToolName returns a valid Bedrock tool name.
// Bedrock requires tool names to match [a-zA-Z0-9_-]+ and be non-empty.
// nameMap caches original->sanitized so repeated lookups for the same name are
// consistent. counter is incremented only when a synthetic name is needed.
func sanitizeBedrockToolName(name string, nameMap map[string]string, counter *int) string {
if name == "" {
*counter++
return fmt.Sprintf("tool_fn_%d", *counter)
}
if sanitized, ok := nameMap[name]; ok {
return sanitized
}
sanitized := bedrockToolNameInvalid.ReplaceAllString(name, "_")
if sanitized == "" {
*counter++
sanitized = fmt.Sprintf("tool_fn_%d", *counter)
}
nameMap[name] = sanitized
return sanitized
}

// sanitizeBedrockToolID returns a valid Bedrock toolUseId.
// Bedrock requires toolUseId to match [a-zA-Z0-9_.:-]+ and be non-empty.
// idMap caches original→sanitized so FunctionCall and FunctionResponse
Expand Down Expand Up @@ -121,8 +144,32 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques
modelName = req.Model
}

// Convert content to Bedrock messages
messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents)
// Build tool configuration first so nameMap is available for message conversion.
// convertGenaiToolsToBedrock sanitizes tool names and returns the
// original->sanitized mapping so the same mapping can be applied to
// conversation history and reversed when restoring names from responses.
var toolConfig *types.ToolConfiguration
nameMap := make(map[string]string)
if req.Config != nil && len(req.Config.Tools) > 0 {
tools, nm := convertGenaiToolsToBedrock(req.Config.Tools)
nameMap = nm
if len(tools) > 0 {
toolConfig = &types.ToolConfiguration{
Tools: tools,
}
}
}

// Build reverse map for restoring original tool names from Bedrock responses.
reverseNameMap := make(map[string]string, len(nameMap))
for orig, sanitized := range nameMap {
reverseNameMap[sanitized] = orig
}

// Convert content to Bedrock messages.
// nameMap is passed so that any tool call recorded in conversation history
// is written with the sanitized name Bedrock already knows about.
messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents, nameMap)

// Build inference config
var inferenceConfig *types.InferenceConfiguration
Expand All @@ -147,27 +194,15 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques
})
}

// Build tool configuration
var toolConfig *types.ToolConfiguration
if req.Config != nil && len(req.Config.Tools) > 0 {
tools := convertGenaiToolsToBedrock(req.Config.Tools)
if len(tools) > 0 {
toolConfig = &types.ToolConfiguration{
Tools: tools,
}
}
}

// Build model-specific additional fields (Claude top_k, thinking, etc.)
additionalFields := m.buildAdditionalModelRequestFields()

// Set telemetry attributes
telemetry.SetLLMRequestAttributes(ctx, modelName, req)

if stream {
m.generateStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, yield)
m.generateStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, reverseNameMap, yield)
} else {
m.generateNonStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, yield)
m.generateNonStreaming(ctx, modelName, messages, systemPrompt, inferenceConfig, toolConfig, additionalFields, reverseNameMap, yield)
}
}
}
Expand All @@ -185,7 +220,8 @@ func (m *BedrockModel) buildAdditionalModelRequestFields() document.Interface {

// generateStreaming handles streaming responses from Bedrock ConverseStream.
// It properly handles both text and tool use content blocks during streaming.
func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, yield func(*model.LLMResponse, error) bool) {
// reverseNameMap maps sanitized Bedrock tool names back to their original names.
func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, reverseNameMap map[string]string, yield func(*model.LLMResponse, error) bool) {
output, err := m.Client.ConverseStream(ctx, &bedrockruntime.ConverseStreamInput{
ModelId: aws.String(modelId),
Messages: messages,
Expand Down Expand Up @@ -266,11 +302,17 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
if stop, ok := event.(*types.ConverseStreamOutputMemberContentBlockStop); ok {
blockIdx := aws.ToInt32(stop.Value.ContentBlockIndex)
if tc, ok := toolCalls[blockIdx]; ok {
// Tool use block completed - parse the accumulated JSON and create FunctionCall
// Tool use block completed - parse the accumulated JSON and create FunctionCall.
// Restore the original tool name from the reverse map so the ADK framework
// can dispatch to the correctly registered tool.
originalName := tc.Name
if orig, found := reverseNameMap[tc.Name]; found {
originalName = orig
}
args := tc.parseArgs()
functionCall := &genai.FunctionCall{
ID: tc.ID,
Name: tc.Name,
Name: originalName,
Args: args,
}
completedToolCalls = append(completedToolCalls, &genai.Part{FunctionCall: functionCall})
Expand Down Expand Up @@ -338,7 +380,8 @@ func (tc *streamingToolCall) parseArgs() map[string]any {
}

// generateNonStreaming handles non-streaming responses from Bedrock Converse.
func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, yield func(*model.LLMResponse, error) bool) {
// reverseNameMap maps sanitized Bedrock tool names back to their original names.
func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, messages []types.Message, systemPrompt []types.SystemContentBlock, inferenceConfig *types.InferenceConfiguration, toolConfig *types.ToolConfiguration, additionalFields document.Interface, reverseNameMap map[string]string, yield func(*model.LLMResponse, error) bool) {
output, err := m.Client.Converse(ctx, &bedrockruntime.ConverseInput{
ModelId: aws.String(modelId),
Messages: messages,
Expand Down Expand Up @@ -366,9 +409,15 @@ func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string,
}
// Handle tool use content
if toolUseBlock, ok := block.(*types.ContentBlockMemberToolUse); ok {
// Restore the original tool name so the ADK framework can dispatch
// to the correctly registered tool.
toolName := aws.ToString(toolUseBlock.Value.Name)
if orig, found := reverseNameMap[toolName]; found {
toolName = orig
}
functionCall := &genai.FunctionCall{
ID: aws.ToString(toolUseBlock.Value.ToolUseId),
Name: aws.ToString(toolUseBlock.Value.Name),
Name: toolName,
}
// Convert document.Interface to map using the String() method and JSON parsing
// The document type in AWS SDK implements String() that returns JSON
Expand Down Expand Up @@ -425,7 +474,10 @@ func documentToMap(doc document.Interface) map[string]any {
}

// convertGenaiContentsToBedrockMessages converts genai.Content to Bedrock Converse API message format.
func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.Message, string) {
// nameMap is the original->sanitized tool name map produced by convertGenaiToolsToBedrock.
// Any FunctionCall found in the conversation history is written with the sanitized name so
// that Bedrock can correlate it with the tool spec it already received. A nil nameMap is safe.
func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap map[string]string) ([]types.Message, string) {
var messages []types.Message
var systemInstruction string

Expand Down Expand Up @@ -465,11 +517,17 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.M
continue
}

// Handle function call (tool use in Bedrock terminology)
// Handle function call (tool use in Bedrock terminology).
// Use the sanitized name from nameMap so Bedrock can correlate the
// tool call with the tool spec sent in the same request.
if part.FunctionCall != nil {
callName := part.FunctionCall.Name
if sanitized, ok := nameMap[callName]; ok {
callName = sanitized
}
toolUse := types.ToolUseBlock{
ToolUseId: aws.String(sanitizeBedrockToolID(part.FunctionCall.ID, idMap, &idCounter)),
Name: aws.String(part.FunctionCall.Name),
Name: aws.String(callName),
Input: document.NewLazyDocument(part.FunctionCall.Args),
}
contentBlocks = append(contentBlocks, &types.ContentBlockMemberToolUse{
Expand Down Expand Up @@ -507,11 +565,16 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content) ([]types.M
}

// convertGenaiToolsToBedrock converts genai.Tool to Bedrock Tool format.
func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
// It sanitizes tool names to satisfy Bedrock's [a-zA-Z0-9_-]+ constraint and
// returns the original->sanitized name mapping so callers can apply it to
// conversation history and reverse it when restoring names from responses.
func convertGenaiToolsToBedrock(tools []*genai.Tool) ([]types.Tool, map[string]string) {
if len(tools) == 0 {
return nil
return nil, nil
}

nameMap := make(map[string]string)
nameCounter := 0
var bedrockTools []types.Tool

for _, tool := range tools {
Expand All @@ -525,7 +588,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
}

// Build input schema as JSON document.
// MCP tools and built-in local toolsset ParametersJsonSchema
// MCP tools and built-in local toolsets set ParametersJsonSchema.
var schema map[string]any
if decl.ParametersJsonSchema != nil {
schema = parametersJsonSchemaToMap(decl.ParametersJsonSchema)
Expand All @@ -536,7 +599,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
// then lowercase all type values to match JSON Schema standard.
schema = genaiSchemaToMap(decl.Parameters)
}
// Fallback to empty object if no schema is found
// Fallback to empty object if no schema is found.
if schema == nil {
schema = map[string]any{"type": "object", "properties": map[string]any{}}
}
Expand All @@ -545,8 +608,12 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
Value: document.NewLazyDocument(schema),
}

// Sanitize the tool name: MCP tool names often contain dots, colons,
// or spaces (e.g. "fetch.get_url") that Bedrock rejects.
sanitizedName := sanitizeBedrockToolName(decl.Name, nameMap, &nameCounter)

toolSpec := types.ToolSpecification{
Name: aws.String(decl.Name),
Name: aws.String(sanitizedName),
Description: aws.String(decl.Description),
InputSchema: inputSchema,
}
Expand All @@ -558,7 +625,7 @@ func convertGenaiToolsToBedrock(tools []*genai.Tool) []types.Tool {
}
}

return bedrockTools
return bedrockTools, nameMap
}

// bedrockStopReasonToGenai maps Bedrock stop reason to genai.FinishReason.
Expand Down
95 changes: 90 additions & 5 deletions go/adk/pkg/models/bedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgs, systemText := convertGenaiContentsToBedrockMessages(tt.contents)
msgs, systemText := convertGenaiContentsToBedrockMessages(tt.contents, nil)
if len(msgs) != tt.wantMsgCount {
t.Errorf("expected %d messages, got %d", tt.wantMsgCount, len(msgs))
}
Expand All @@ -124,7 +124,7 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) {
// sources: genai.Schema (declaration-based), map[string]any (MCP), and
// *jsonschema.Schema (functiontool.New).
func TestConvertGenaiToolsToBedrock(t *testing.T) {
extractSchema := func(t *testing.T, tools []types.Tool) map[string]any {
extractSchema := func(t *testing.T, tools []types.Tool, _ map[string]string) map[string]any {
t.Helper()
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
Expand Down Expand Up @@ -162,7 +162,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) {
},
}}}}

schema := extractSchema(t, convertGenaiToolsToBedrock(tools))
bt1, nm1 := convertGenaiToolsToBedrock(tools)
schema := extractSchema(t, bt1, nm1)

props := schema["properties"].(map[string]any)
for prop, want := range map[string]string{"location": "string", "count": "integer", "detailed": "boolean"} {
Expand All @@ -189,7 +190,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) {
},
}}}}

schema := extractSchema(t, convertGenaiToolsToBedrock(tools))
bt2, nm2 := convertGenaiToolsToBedrock(tools)
schema := extractSchema(t, bt2, nm2)
props, ok := schema["properties"].(map[string]any)
if !ok || len(props) == 0 {
t.Fatalf("expected non-empty properties, got %v", schema["properties"])
Expand All @@ -209,7 +211,8 @@ func TestConvertGenaiToolsToBedrock(t *testing.T) {
ParametersJsonSchema: s,
}}}}

schema := extractSchema(t, convertGenaiToolsToBedrock(tools))
bt3, nm3 := convertGenaiToolsToBedrock(tools)
schema := extractSchema(t, bt3, nm3)
props, ok := schema["properties"].(map[string]any)
if !ok || len(props) == 0 {
t.Fatalf("expected non-empty properties (means *jsonschema.Schema was not converted): %v", schema["properties"])
Expand Down Expand Up @@ -310,6 +313,88 @@ func TestSanitizeBedrockToolID(t *testing.T) {
})
}

func TestSanitizeBedrockToolName(t *testing.T) {
tests := []struct {
name string
tool string
want string
}{
{name: "valid name unchanged", tool: "get_weather", want: "get_weather"},
{name: "valid name with hyphen", tool: "fetch-data", want: "fetch-data"},
{name: "dot replaced", tool: "fetch.get_url", want: "fetch_get_url"},
{name: "colon replaced", tool: "filesystem:read_file", want: "filesystem_read_file"},
{name: "space replaced", tool: "my tool", want: "my_tool"},
{name: "multiple invalid chars", tool: "a.b:c d", want: "a_b_c_d"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nameMap := make(map[string]string)
counter := 0
if got := sanitizeBedrockToolName(tt.tool, nameMap, &counter); got != tt.want {
t.Errorf("sanitizeBedrockToolName(%q) = %q, want %q", tt.tool, got, tt.want)
}
})
}

t.Run("empty name gets synthetic", func(t *testing.T) {
nameMap, counter := make(map[string]string), 0
got := sanitizeBedrockToolName("", nameMap, &counter)
if got != "tool_fn_1" {
t.Errorf("expected tool_fn_1, got %q", got)
}
if counter != 1 {
t.Errorf("expected counter=1, got %d", counter)
}
})

t.Run("caching returns same sanitized name", func(t *testing.T) {
nameMap, counter := make(map[string]string), 0
first := sanitizeBedrockToolName("fetch.get_url", nameMap, &counter)
second := sanitizeBedrockToolName("fetch.get_url", nameMap, &counter)
if first != second {
t.Errorf("expected same cached result, got %q and %q", first, second)
}
if counter != 0 {
t.Errorf("expected counter unchanged, got %d", counter)
}
})
}

func TestConvertGenaiToolsToBedrockSanitizesNames(t *testing.T) {
tools := []*genai.Tool{{FunctionDeclarations: []*genai.FunctionDeclaration{
{Name: "fetch.get_url", Description: "Fetch a URL"},
{Name: "filesystem:read_file", Description: "Read a file"},
}}}

bedrockTools, nameMap := convertGenaiToolsToBedrock(tools)
if len(bedrockTools) != 2 {
t.Fatalf("expected 2 tools, got %d", len(bedrockTools))
}

// Verify sanitized names in the Bedrock tool specs.
for i, want := range []string{"fetch_get_url", "filesystem_read_file"} {
tm, ok := bedrockTools[i].(*types.ToolMemberToolSpec)
if !ok {
t.Fatalf("tool %d: expected *types.ToolMemberToolSpec", i)
}
got := ""
if tm.Value.Name != nil {
got = *tm.Value.Name
}
if got != want {
t.Errorf("tool %d: expected name %q, got %q", i, want, got)
}
}

// Verify nameMap contains the mappings.
if nameMap["fetch.get_url"] != "fetch_get_url" {
t.Errorf("nameMap[fetch.get_url] = %q, want fetch_get_url", nameMap["fetch.get_url"])
}
if nameMap["filesystem:read_file"] != "filesystem_read_file" {
t.Errorf("nameMap[filesystem:read_file] = %q, want filesystem_read_file", nameMap["filesystem:read_file"])
}
}

func TestStreamingToolCallParseArgs(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading
Loading