From 9614c9f690697905370f00c383299cdef4a52a48 Mon Sep 17 00:00:00 2001 From: Carson Loyal Date: Sun, 17 May 2026 17:02:15 -0500 Subject: [PATCH] feat: dynamic Bedrock model listing with LiteLLM metadata enrichment Replace the hardcoded Bedrock model list with live AWS API calls to ListFoundationModels and ListInferenceProfiles. This ensures users only see models they have access to in their account and region, solving the GovCloud model filtering problem. Key changes: - Add aws-sdk-bedrock dependency for control plane API - Implement init_control_client() for model listing (separate from chat client) - Call ListFoundationModels filtered to ON_DEMAND, streaming, ACTIVE - Call ListInferenceProfiles for cross-region inference profiles - Enrich model metadata (context_length, tools, reasoning, vision) from the LiteLLM community registry JSON - Graceful fallback to hardcoded model list when API permissions are missing - Add optional AWS_USE_FIPS URL param for GovCloud FIPS endpoints - Add 10 new unit tests covering model mapping, enrichment, fallback, and FIPS config GovCloud users with us-gov-west-1 region will now see only their available models with correct context window sizes. Commercial users benefit from seeing only models they have enabled. Co-Authored-By: ForgeCode --- Cargo.lock | 53 +- Cargo.toml | 1 + crates/forge_repo/Cargo.toml | 1 + crates/forge_repo/src/provider/bedrock.rs | 613 ++++++++++++++++++- crates/forge_repo/src/provider/provider.json | 2 +- 5 files changed, 641 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4d70c0b3af..9d43e1dd13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,6 +331,30 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-bedrock" +version = "1.141.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a16484ce62b16cadf941e1c408d9b73afb9e6fc456573e5a9681e38d45fb00a" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-observability", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-bedrockruntime" version = "1.130.0" @@ -1023,7 +1047,7 @@ version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] @@ -1776,7 +1800,7 @@ dependencies = [ "libc", "option-ext", "redox_users 0.5.2", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1949,7 +1973,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2545,6 +2569,7 @@ dependencies = [ "async-trait", "aws-config", "aws-credential-types", + "aws-sdk-bedrock", "aws-sdk-bedrockruntime", "aws-smithy-async", "aws-smithy-runtime", @@ -4571,7 +4596,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.3", "tokio", "tower-service", "tracing", @@ -4885,7 +4910,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -4937,7 +4962,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde_core", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -5611,7 +5636,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -6427,7 +6452,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls 0.23.40", - "socket2 0.5.10", + "socket2 0.6.3", "thiserror 2.0.18", "tokio", "tracing", @@ -6465,7 +6490,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.3", "tracing", "windows-sys 0.59.0", ] @@ -7077,7 +7102,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.12.1", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -7157,7 +7182,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -8110,7 +8135,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix 1.1.4", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -8173,7 +8198,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230a1b821ccbd75b185820a1f1ff7b14d21da1e442e22c0863ea5f08771a8874" dependencies = [ "rustix 1.1.4", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -9262,7 +9287,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 208aa18735..002b1870d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ async-recursion = "1.1.1" async-stream = "0.3" async-trait = "0.1.89" aws-config = { version = "1.8.13", features = ["behavior-version-latest", "sso"], default-features = false } +aws-sdk-bedrock = { version = "1.129.0", features = ["behavior-version-latest"], default-features = false } aws-sdk-bedrockruntime = { version = "1.129.0", features = ["behavior-version-latest"], default-features = false } aws-credential-types = "1.2.14" aws-smithy-types = "1.4.3" diff --git a/crates/forge_repo/Cargo.toml b/crates/forge_repo/Cargo.toml index 788afb768f..c0f1dd9d31 100644 --- a/crates/forge_repo/Cargo.toml +++ b/crates/forge_repo/Cargo.toml @@ -31,6 +31,7 @@ forge_eventsource_stream.workspace = true handlebars.workspace = true merge.workspace = true aws-config.workspace = true +aws-sdk-bedrock.workspace = true aws-sdk-bedrockruntime.workspace = true aws-credential-types.workspace = true aws-smithy-types.workspace = true diff --git a/crates/forge_repo/src/provider/bedrock.rs b/crates/forge_repo/src/provider/bedrock.rs index 7b3689c8e0..5a561a9d90 100644 --- a/crates/forge_repo/src/provider/bedrock.rs +++ b/crates/forge_repo/src/provider/bedrock.rs @@ -5,8 +5,8 @@ use aws_sdk_bedrockruntime::Client; use aws_sdk_bedrockruntime::config::Token; use forge_config::RetryConfig; use forge_domain::{ - AuthDetails, ChatCompletionMessage, ChatRepository, Context, Model, ModelId, Provider, - ResultStream, Transformer, + AuthDetails, ChatCompletionMessage, ChatRepository, Context, InputModality, Model, ModelId, + Provider, ResultStream, Transformer, }; use reqwest::Url; use tokio::sync::OnceCell; @@ -33,7 +33,9 @@ struct BedrockProvider { provider: Provider, region: String, auth_mode: BedrockAuthMode, + use_fips: bool, client: OnceCell, + control_client: OnceCell, } impl BedrockProvider { @@ -69,7 +71,58 @@ impl BedrockProvider { .map(|v| v.to_string()) .unwrap_or_else(|| "us-east-1".to_string()); - Ok(Self { provider, region, auth_mode, client: OnceCell::new() }) + // Extract optional FIPS mode for GovCloud + let fips_param: forge_domain::URLParam = "AWS_USE_FIPS".to_string().into(); + let use_fips = credential + .url_params + .get(&fips_param) + .map(|v| v.to_string() == "true") + .unwrap_or(false); + + Ok(Self { provider, region, auth_mode, use_fips, client: OnceCell::new(), control_client: OnceCell::new() }) + } + + /// Initializes and returns the AWS Bedrock control plane client + /// + /// Used for `ListFoundationModels` and `ListInferenceProfiles` API calls. + /// Lazily initialized on first call and reused for subsequent calls. + /// + /// # Errors + /// + /// Returns an error if credentials cannot be resolved. + async fn init_control_client(&self) -> Result<&aws_sdk_bedrock::Client> { + self.control_client + .get_or_try_init(|| async { + match &self.auth_mode { + BedrockAuthMode::BearerToken(token) => { + let mut builder = aws_sdk_bedrock::Config::builder() + .region(aws_sdk_bedrock::config::Region::new( + self.region.clone(), + )) + .bearer_token(aws_sdk_bedrock::config::Token::new( + token.clone(), + None, + )); + if self.use_fips { + builder = builder.use_fips(true); + } + Ok(aws_sdk_bedrock::Client::from_conf(builder.build())) + } + BedrockAuthMode::AwsProfile(profile) => { + let mut config_loader = aws_config::from_env() + .profile_name(profile) + .region(aws_sdk_bedrock::config::Region::new( + self.region.clone(), + )); + if self.use_fips { + config_loader = config_loader.use_fips(true); + } + let sdk_config = config_loader.load().await; + Ok(aws_sdk_bedrock::Client::new(&sdk_config)) + } + } + }) + .await } /// Initializes and returns the AWS Bedrock client @@ -87,22 +140,26 @@ impl BedrockProvider { .get_or_try_init(|| async { match &self.auth_mode { BedrockAuthMode::BearerToken(token) => { - let config = aws_sdk_bedrockruntime::Config::builder() + let mut builder = aws_sdk_bedrockruntime::Config::builder() .region(aws_sdk_bedrockruntime::config::Region::new( self.region.clone(), )) - .bearer_token(Token::new(token.clone(), None)) - .build(); - Ok(aws_sdk_bedrockruntime::Client::from_conf(config)) + .bearer_token(Token::new(token.clone(), None)); + if self.use_fips { + builder = builder.use_fips(true); + } + Ok(aws_sdk_bedrockruntime::Client::from_conf(builder.build())) } BedrockAuthMode::AwsProfile(profile) => { - let sdk_config = aws_config::from_env() + let mut config_loader = aws_config::from_env() .profile_name(profile) .region(aws_sdk_bedrockruntime::config::Region::new( self.region.clone(), - )) - .load() - .await; + )); + if self.use_fips { + config_loader = config_loader.use_fips(true); + } + let sdk_config = config_loader.load().await; Ok(aws_sdk_bedrockruntime::Client::new(&sdk_config)) } } @@ -290,15 +347,257 @@ impl BedrockProvider { Ok(Box::pin(stream)) } - /// Get available models + /// Get available models by querying the AWS Bedrock control plane APIs. + /// + /// Calls `ListFoundationModels` and `ListInferenceProfiles` to get the + /// account/region-aware model list. Models are filtered to those supporting + /// on-demand inference and streaming. Results are deduplicated by model ID. + /// + /// Falls back to hardcoded models from `provider.json` if the API calls + /// fail (e.g., missing IAM permissions). pub async fn models(&self) -> Result> { - // Bedrock doesn't have a models list API - // Return hardcoded models from configuration + match self.fetch_live_models().await { + Ok(models) => Ok(models), + Err(err) => { + tracing::warn!( + error = %err, + "Failed to list models from Bedrock API, falling back to hardcoded list. \ + Ensure your IAM policy includes bedrock:ListFoundationModels permission." + ); + self.hardcoded_models() + } + } + } + + /// Returns the hardcoded model list from provider.json configuration + fn hardcoded_models(&self) -> Result> { match &self.provider.models { Some(forge_domain::ModelSource::Hardcoded(models)) => Ok(models.clone()), _ => Ok(vec![]), } } + + /// Fetches live models from AWS Bedrock APIs + async fn fetch_live_models(&self) -> Result> { + use std::collections::HashSet; + + let client = self.init_control_client().await?; + + // Fetch foundation models + let fm_response = client + .list_foundation_models() + .send() + .await + .context("Failed to call ListFoundationModels")?; + + let mut seen_ids = HashSet::new(); + let mut models = Vec::new(); + + for summary in fm_response.model_summaries() { + // Filter: must support ON_DEMAND inference + let supports_on_demand = summary + .inference_types_supported() + .iter() + .any(|t| *t == aws_sdk_bedrock::types::InferenceType::OnDemand); + if !supports_on_demand { + continue; + } + + // Filter: must support response streaming + if summary.response_streaming_supported() != Some(true) { + continue; + } + + // Filter: skip LEGACY lifecycle status + if let Some(lc) = summary.model_lifecycle() { + if lc.status == aws_sdk_bedrock::types::FoundationModelLifecycleStatus::Legacy { + continue; + } + } + + let model_id = summary.model_id(); + if model_id.is_empty() || !seen_ids.insert(model_id.to_string()) { + continue; + } + + models.push(Self::foundation_model_to_domain(summary)); + } + + // Fetch inference profiles (cross-region profiles like us.anthropic.*) + match client.list_inference_profiles().send().await { + Ok(ip_response) => { + for profile in ip_response.inference_profile_summaries() { + let profile_id = profile.inference_profile_id().to_string(); + if profile_id.is_empty() || !seen_ids.insert(profile_id.clone()) { + continue; + } + + models.push(Model { + id: ModelId::from(profile_id), + name: Some(profile.inference_profile_name().to_string()), + description: profile.description().map(|s| s.to_string()), + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }); + } + } + Err(err) => { + tracing::debug!( + error = %err, + "ListInferenceProfiles failed; continuing with foundation models only" + ); + } + } + + // Enrich models with LiteLLM registry metadata (context_length, capabilities) + Self::enrich_with_litellm_registry(&mut models).await; + + Ok(models) + } + + /// LiteLLM community registry URL for model metadata + const LITELLM_REGISTRY_URL: &'static str = + "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"; + + /// Fetches the LiteLLM community model registry and enriches models with + /// metadata such as context_length, tool support, reasoning support, and + /// vision capabilities. + /// + /// This is best-effort: if the fetch fails, models are left unenriched. + async fn enrich_with_litellm_registry(models: &mut [Model]) { + let registry = match Self::fetch_litellm_registry().await { + Ok(r) => r, + Err(err) => { + tracing::debug!( + error = %err, + "Failed to fetch LiteLLM registry; models will lack metadata enrichment" + ); + return; + } + }; + + for model in models.iter_mut() { + Self::enrich_model_from_registry(model, ®istry); + } + } + + /// Fetches the LiteLLM registry JSON + async fn fetch_litellm_registry( + ) -> Result> { + let response = reqwest::get(Self::LITELLM_REGISTRY_URL) + .await + .context("Failed to fetch LiteLLM registry")?; + + let registry: std::collections::HashMap = response + .json() + .await + .context("Failed to parse LiteLLM registry JSON")?; + + Ok(registry) + } + + /// Enriches a single model with data from the LiteLLM registry. + /// + /// Tries multiple key patterns to find the model in the registry: + /// - `bedrock/` (standard Bedrock prefix in LiteLLM) + /// - Raw `` as fallback + fn enrich_model_from_registry( + model: &mut Model, + registry: &std::collections::HashMap, + ) { + let model_id = model.id.as_str(); + + // Try registry lookup with bedrock/ prefix first, then raw ID + let entry = registry + .get(&format!("bedrock/{}", model_id)) + .or_else(|| registry.get(model_id)); + + let Some(entry) = entry else { + return; + }; + + // Enrich context_length from max_input_tokens + if model.context_length.is_none() { + if let Some(max_input) = entry.get("max_input_tokens").and_then(|v| v.as_u64()) { + model.context_length = Some(max_input); + } + } + + // Enrich tool support + if model.tools_supported.is_none() { + if let Some(supports_fc) = entry + .get("supports_function_calling") + .and_then(|v| v.as_bool()) + { + model.tools_supported = Some(supports_fc); + } + } + + // Enrich reasoning support + if model.supports_reasoning.is_none() { + if let Some(supports_reasoning) = + entry.get("supports_reasoning").and_then(|v| v.as_bool()) + { + model.supports_reasoning = Some(supports_reasoning); + } + } + + // Enrich parallel tool calls support + if model.supports_parallel_tool_calls.is_none() { + if let Some(supports_parallel) = entry + .get("supports_parallel_function_calling") + .and_then(|v| v.as_bool()) + { + model.supports_parallel_tool_calls = Some(supports_parallel); + } + } + + // Enrich vision support -- add Image modality if supported + if let Some(true) = entry.get("supports_vision").and_then(|v| v.as_bool()) { + if !model.input_modalities.contains(&InputModality::Image) { + model.input_modalities.push(InputModality::Image); + } + } + } + + /// Converts a `FoundationModelSummary` to a domain `Model` + fn foundation_model_to_domain( + summary: &aws_sdk_bedrock::types::FoundationModelSummary, + ) -> Model { + let model_id = summary.model_id().to_string(); + let model_name = summary.model_name().map(|s| s.to_string()); + + // Map input modalities + let input_modalities: Vec = summary + .input_modalities() + .iter() + .filter_map(|m| match m { + aws_sdk_bedrock::types::ModelModality::Text => Some(InputModality::Text), + aws_sdk_bedrock::types::ModelModality::Image => Some(InputModality::Image), + _ => None, + }) + .collect(); + + let input_modalities = if input_modalities.is_empty() { + vec![InputModality::Text] + } else { + input_modalities + }; + + Model { + id: ModelId::from(model_id), + name: model_name, + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities, + } + } } /// Converts Bedrock stream events to ChatCompletionMessage @@ -1127,7 +1426,9 @@ mod tests { BedrockProvider { provider: provider_fixture("test-token", Some(region)), auth_mode: BedrockAuthMode::BearerToken("test-token".to_string()), + use_fips: false, client: OnceCell::new(), + control_client: OnceCell::new(), region: region.to_string(), } } @@ -1447,7 +1748,9 @@ mod tests { let bedrock = BedrockProvider { provider: fixture_provider, auth_mode: BedrockAuthMode::BearerToken("token".to_string()), + use_fips: false, client: OnceCell::new(), + control_client: OnceCell::new(), region: "us-east-1".to_string(), }; @@ -1462,7 +1765,9 @@ mod tests { let bedrock = BedrockProvider { provider: fixture, auth_mode: BedrockAuthMode::BearerToken("token".to_string()), + use_fips: false, client: OnceCell::new(), + control_client: OnceCell::new(), region: "us-east-1".to_string(), }; @@ -2255,4 +2560,284 @@ mod tests { } assert!(got_text, "Expected text content in stream response"); } + + #[test] + fn test_foundation_model_to_domain_claude() { + use aws_sdk_bedrock::types::{ + FoundationModelLifecycle, FoundationModelLifecycleStatus, FoundationModelSummary, + InferenceType, ModelModality, + }; + + let fixture = FoundationModelSummary::builder() + .model_id("anthropic.claude-3-5-sonnet-20241022-v2:0") + .model_arn("arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0") + .model_name("Claude 3.5 Sonnet v2") + .input_modalities(ModelModality::Text) + .input_modalities(ModelModality::Image) + .inference_types_supported(InferenceType::OnDemand) + .response_streaming_supported(true) + .model_lifecycle( + FoundationModelLifecycle::builder() + .status(FoundationModelLifecycleStatus::Active) + .build() + .unwrap(), + ) + .build() + .unwrap(); + + let actual = BedrockProvider::foundation_model_to_domain(&fixture); + let expected = Model { + id: ModelId::from("anthropic.claude-3-5-sonnet-20241022-v2:0".to_string()), + name: Some("Claude 3.5 Sonnet v2".to_string()), + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text, InputModality::Image], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_foundation_model_to_domain_nova_text_only() { + use aws_sdk_bedrock::types::{ + FoundationModelSummary, InferenceType, ModelModality, + }; + + let fixture = FoundationModelSummary::builder() + .model_id("amazon.nova-pro-v1:0") + .model_arn("arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-pro-v1:0") + .model_name("Amazon Nova Pro") + .input_modalities(ModelModality::Text) + .inference_types_supported(InferenceType::OnDemand) + .response_streaming_supported(true) + .build() + .unwrap(); + + let actual = BedrockProvider::foundation_model_to_domain(&fixture); + let expected = Model { + id: ModelId::from("amazon.nova-pro-v1:0".to_string()), + name: Some("Amazon Nova Pro".to_string()), + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_foundation_model_to_domain_no_modalities_defaults_to_text() { + use aws_sdk_bedrock::types::FoundationModelSummary; + + let fixture = FoundationModelSummary::builder() + .model_id("some.model-v1") + .model_arn("arn:aws:bedrock:us-east-1::foundation-model/some.model-v1") + .model_name("Some Model") + .build() + .unwrap(); + + let actual = BedrockProvider::foundation_model_to_domain(&fixture); + assert_eq!(actual.input_modalities, vec![InputModality::Text]); + } + + #[test] + fn test_enrich_model_from_registry_with_bedrock_prefix() { + let mut fixture = Model { + id: ModelId::from("anthropic.claude-3-5-sonnet-20241022-v2:0".to_string()), + name: Some("Claude 3.5 Sonnet v2".to_string()), + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + + let mut registry = std::collections::HashMap::new(); + registry.insert( + "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0".to_string(), + serde_json::json!({ + "max_input_tokens": 200000, + "max_output_tokens": 8192, + "supports_function_calling": true, + "supports_reasoning": true, + "supports_parallel_function_calling": true, + "supports_vision": true + }), + ); + + BedrockProvider::enrich_model_from_registry(&mut fixture, ®istry); + + assert_eq!(fixture.context_length, Some(200000)); + assert_eq!(fixture.tools_supported, Some(true)); + assert_eq!(fixture.supports_reasoning, Some(true)); + assert_eq!(fixture.supports_parallel_tool_calls, Some(true)); + assert!(fixture.input_modalities.contains(&InputModality::Image)); + } + + #[test] + fn test_enrich_model_from_registry_raw_id_fallback() { + let mut fixture = Model { + id: ModelId::from("meta.llama3-1-405b-instruct-v1:0".to_string()), + name: Some("Llama 3.1 405B".to_string()), + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + + let mut registry = std::collections::HashMap::new(); + registry.insert( + "meta.llama3-1-405b-instruct-v1:0".to_string(), + serde_json::json!({ + "max_input_tokens": 128000, + "supports_function_calling": true, + "supports_vision": false + }), + ); + + BedrockProvider::enrich_model_from_registry(&mut fixture, ®istry); + + assert_eq!(fixture.context_length, Some(128000)); + assert_eq!(fixture.tools_supported, Some(true)); + // Vision is false, so Image should NOT be added + assert!(!fixture.input_modalities.contains(&InputModality::Image)); + } + + #[test] + fn test_enrich_model_from_registry_not_found() { + let mut fixture = Model { + id: ModelId::from("unknown.model-v1".to_string()), + name: None, + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + + let registry = std::collections::HashMap::new(); + + BedrockProvider::enrich_model_from_registry(&mut fixture, ®istry); + + // Nothing should change + assert_eq!(fixture.context_length, None); + assert_eq!(fixture.tools_supported, None); + assert_eq!(fixture.input_modalities, vec![InputModality::Text]); + } + + #[test] + fn test_enrich_model_does_not_overwrite_existing_values() { + let mut fixture = Model { + id: ModelId::from("anthropic.claude-v2".to_string()), + name: None, + description: None, + context_length: Some(100000), // Already set + tools_supported: Some(false), // Already set + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + + let mut registry = std::collections::HashMap::new(); + registry.insert( + "bedrock/anthropic.claude-v2".to_string(), + serde_json::json!({ + "max_input_tokens": 200000, + "supports_function_calling": true + }), + ); + + BedrockProvider::enrich_model_from_registry(&mut fixture, ®istry); + + // Should NOT overwrite existing values + assert_eq!(fixture.context_length, Some(100000)); + assert_eq!(fixture.tools_supported, Some(false)); + } + + #[tokio::test] + async fn test_models_fallback_on_api_error() { + use forge_domain::ModelSource; + + // Create a provider with hardcoded models but invalid credentials + // The API call will fail, triggering fallback to hardcoded list + let mut fixture_provider = provider_fixture("invalid-token", None); + let fixture_models = vec![Model { + id: ModelId::from("fallback-model".to_string()), + name: Some("Fallback Model".to_string()), + description: None, + context_length: Some(128000), + tools_supported: Some(true), + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }]; + fixture_provider.models = Some(ModelSource::Hardcoded(fixture_models.clone())); + + let bedrock = BedrockProvider { + provider: fixture_provider, + auth_mode: BedrockAuthMode::BearerToken("invalid-token".to_string()), + use_fips: false, + client: OnceCell::new(), + control_client: OnceCell::new(), + region: "us-east-1".to_string(), + }; + + // models() should fall back to hardcoded list when API fails + let actual = bedrock.models().await.unwrap(); + let expected = fixture_models; + assert_eq!(actual, expected); + } + + #[test] + fn test_fips_mode_parsing() { + use forge_domain::{ + ApiKey, AuthCredential, AuthDetails, ProviderId, ProviderResponse, ProviderType, + URLParam, URLParamValue, + }; + + let mut url_params = std::collections::HashMap::new(); + url_params.insert( + URLParam::from("AWS_REGION".to_string()), + URLParamValue::from("us-gov-west-1".to_string()), + ); + url_params.insert( + URLParam::from("AWS_USE_FIPS".to_string()), + URLParamValue::from("true".to_string()), + ); + + let provider = Provider { + id: ProviderId::from("bedrock".to_string()), + provider_type: ProviderType::Llm, + response: Some(ProviderResponse::Bedrock), + url: Url::parse("https://bedrock-runtime.us-gov-west-1.amazonaws.com").unwrap(), + models: None, + auth_methods: vec![], + url_params: vec![], + credential: Some(AuthCredential { + id: ProviderId::from("bedrock".to_string()), + auth_details: AuthDetails::ApiKey(ApiKey::from("token".to_string())), + url_params, + }), + custom_headers: None, + }; + + let actual = BedrockProvider::new(provider).unwrap(); + assert_eq!(actual.region, "us-gov-west-1"); + assert!(actual.use_fips); + } + + #[test] + fn test_fips_mode_defaults_to_false() { + let fixture = provider_fixture("token", Some("us-east-1")); + let actual = BedrockProvider::new(fixture).unwrap(); + assert!(!actual.use_fips); + } } diff --git a/crates/forge_repo/src/provider/provider.json b/crates/forge_repo/src/provider/provider.json index 64af4400ed..0c84f6b9fd 100644 --- a/crates/forge_repo/src/provider/provider.json +++ b/crates/forge_repo/src/provider/provider.json @@ -1292,7 +1292,7 @@ }, { "id": "bedrock", - "url_param_vars": ["AWS_REGION"], + "url_param_vars": ["AWS_REGION", {"name": "AWS_USE_FIPS", "options": ["true", "false"], "optional": true}], "response_type": "Bedrock", "url": "https://bedrock-runtime.{{AWS_REGION}}.amazonaws.com", "models": [