diff --git a/Cargo.lock b/Cargo.lock index 4d70c0b3af..9d43e1dd13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,6 +331,30 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-bedrock" +version = "1.141.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a16484ce62b16cadf941e1c408d9b73afb9e6fc456573e5a9681e38d45fb00a" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-observability", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.4.0", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-bedrockruntime" version = "1.130.0" @@ -1023,7 +1047,7 @@ version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] @@ -1776,7 +1800,7 @@ dependencies = [ "libc", "option-ext", "redox_users 0.5.2", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -1949,7 +1973,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2545,6 +2569,7 @@ dependencies = [ "async-trait", "aws-config", "aws-credential-types", + "aws-sdk-bedrock", "aws-sdk-bedrockruntime", "aws-smithy-async", "aws-smithy-runtime", @@ -4571,7 +4596,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.3", "tokio", "tower-service", "tracing", @@ -4885,7 +4910,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -4937,7 +4962,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde_core", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -5611,7 +5636,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -6427,7 +6452,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls 0.23.40", - "socket2 0.5.10", + "socket2 0.6.3", "thiserror 2.0.18", "tokio", "tracing", @@ -6465,7 +6490,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.3", "tracing", "windows-sys 0.59.0", ] @@ -7077,7 +7102,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.12.1", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -7157,7 +7182,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -8110,7 +8135,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix 1.1.4", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -8173,7 +8198,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "230a1b821ccbd75b185820a1f1ff7b14d21da1e442e22c0863ea5f08771a8874" dependencies = [ "rustix 1.1.4", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -9262,7 +9287,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 208aa18735..002b1870d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ async-recursion = "1.1.1" async-stream = "0.3" async-trait = "0.1.89" aws-config = { version = "1.8.13", features = ["behavior-version-latest", "sso"], default-features = false } +aws-sdk-bedrock = { version = "1.129.0", features = ["behavior-version-latest"], default-features = false } aws-sdk-bedrockruntime = { version = "1.129.0", features = ["behavior-version-latest"], default-features = false } aws-credential-types = "1.2.14" aws-smithy-types = "1.4.3" diff --git a/crates/forge_repo/Cargo.toml b/crates/forge_repo/Cargo.toml index 788afb768f..c0f1dd9d31 100644 --- a/crates/forge_repo/Cargo.toml +++ b/crates/forge_repo/Cargo.toml @@ -31,6 +31,7 @@ forge_eventsource_stream.workspace = true handlebars.workspace = true merge.workspace = true aws-config.workspace = true +aws-sdk-bedrock.workspace = true aws-sdk-bedrockruntime.workspace = true aws-credential-types.workspace = true aws-smithy-types.workspace = true diff --git a/crates/forge_repo/src/provider/bedrock.rs b/crates/forge_repo/src/provider/bedrock.rs index 7b3689c8e0..5a561a9d90 100644 --- a/crates/forge_repo/src/provider/bedrock.rs +++ b/crates/forge_repo/src/provider/bedrock.rs @@ -5,8 +5,8 @@ use aws_sdk_bedrockruntime::Client; use aws_sdk_bedrockruntime::config::Token; use forge_config::RetryConfig; use forge_domain::{ - AuthDetails, ChatCompletionMessage, ChatRepository, Context, Model, ModelId, Provider, - ResultStream, Transformer, + AuthDetails, ChatCompletionMessage, ChatRepository, Context, InputModality, Model, ModelId, + Provider, ResultStream, Transformer, }; use reqwest::Url; use tokio::sync::OnceCell; @@ -33,7 +33,9 @@ struct BedrockProvider { provider: Provider, region: String, auth_mode: BedrockAuthMode, + use_fips: bool, client: OnceCell, + control_client: OnceCell, } impl BedrockProvider { @@ -69,7 +71,58 @@ impl BedrockProvider { .map(|v| v.to_string()) .unwrap_or_else(|| "us-east-1".to_string()); - Ok(Self { provider, region, auth_mode, client: OnceCell::new() }) + // Extract optional FIPS mode for GovCloud + let fips_param: forge_domain::URLParam = "AWS_USE_FIPS".to_string().into(); + let use_fips = credential + .url_params + .get(&fips_param) + .map(|v| v.to_string() == "true") + .unwrap_or(false); + + Ok(Self { provider, region, auth_mode, use_fips, client: OnceCell::new(), control_client: OnceCell::new() }) + } + + /// Initializes and returns the AWS Bedrock control plane client + /// + /// Used for `ListFoundationModels` and `ListInferenceProfiles` API calls. + /// Lazily initialized on first call and reused for subsequent calls. + /// + /// # Errors + /// + /// Returns an error if credentials cannot be resolved. + async fn init_control_client(&self) -> Result<&aws_sdk_bedrock::Client> { + self.control_client + .get_or_try_init(|| async { + match &self.auth_mode { + BedrockAuthMode::BearerToken(token) => { + let mut builder = aws_sdk_bedrock::Config::builder() + .region(aws_sdk_bedrock::config::Region::new( + self.region.clone(), + )) + .bearer_token(aws_sdk_bedrock::config::Token::new( + token.clone(), + None, + )); + if self.use_fips { + builder = builder.use_fips(true); + } + Ok(aws_sdk_bedrock::Client::from_conf(builder.build())) + } + BedrockAuthMode::AwsProfile(profile) => { + let mut config_loader = aws_config::from_env() + .profile_name(profile) + .region(aws_sdk_bedrock::config::Region::new( + self.region.clone(), + )); + if self.use_fips { + config_loader = config_loader.use_fips(true); + } + let sdk_config = config_loader.load().await; + Ok(aws_sdk_bedrock::Client::new(&sdk_config)) + } + } + }) + .await } /// Initializes and returns the AWS Bedrock client @@ -87,22 +140,26 @@ impl BedrockProvider { .get_or_try_init(|| async { match &self.auth_mode { BedrockAuthMode::BearerToken(token) => { - let config = aws_sdk_bedrockruntime::Config::builder() + let mut builder = aws_sdk_bedrockruntime::Config::builder() .region(aws_sdk_bedrockruntime::config::Region::new( self.region.clone(), )) - .bearer_token(Token::new(token.clone(), None)) - .build(); - Ok(aws_sdk_bedrockruntime::Client::from_conf(config)) + .bearer_token(Token::new(token.clone(), None)); + if self.use_fips { + builder = builder.use_fips(true); + } + Ok(aws_sdk_bedrockruntime::Client::from_conf(builder.build())) } BedrockAuthMode::AwsProfile(profile) => { - let sdk_config = aws_config::from_env() + let mut config_loader = aws_config::from_env() .profile_name(profile) .region(aws_sdk_bedrockruntime::config::Region::new( self.region.clone(), - )) - .load() - .await; + )); + if self.use_fips { + config_loader = config_loader.use_fips(true); + } + let sdk_config = config_loader.load().await; Ok(aws_sdk_bedrockruntime::Client::new(&sdk_config)) } } @@ -290,15 +347,257 @@ impl BedrockProvider { Ok(Box::pin(stream)) } - /// Get available models + /// Get available models by querying the AWS Bedrock control plane APIs. + /// + /// Calls `ListFoundationModels` and `ListInferenceProfiles` to get the + /// account/region-aware model list. Models are filtered to those supporting + /// on-demand inference and streaming. Results are deduplicated by model ID. + /// + /// Falls back to hardcoded models from `provider.json` if the API calls + /// fail (e.g., missing IAM permissions). pub async fn models(&self) -> Result> { - // Bedrock doesn't have a models list API - // Return hardcoded models from configuration + match self.fetch_live_models().await { + Ok(models) => Ok(models), + Err(err) => { + tracing::warn!( + error = %err, + "Failed to list models from Bedrock API, falling back to hardcoded list. \ + Ensure your IAM policy includes bedrock:ListFoundationModels permission." + ); + self.hardcoded_models() + } + } + } + + /// Returns the hardcoded model list from provider.json configuration + fn hardcoded_models(&self) -> Result> { match &self.provider.models { Some(forge_domain::ModelSource::Hardcoded(models)) => Ok(models.clone()), _ => Ok(vec![]), } } + + /// Fetches live models from AWS Bedrock APIs + async fn fetch_live_models(&self) -> Result> { + use std::collections::HashSet; + + let client = self.init_control_client().await?; + + // Fetch foundation models + let fm_response = client + .list_foundation_models() + .send() + .await + .context("Failed to call ListFoundationModels")?; + + let mut seen_ids = HashSet::new(); + let mut models = Vec::new(); + + for summary in fm_response.model_summaries() { + // Filter: must support ON_DEMAND inference + let supports_on_demand = summary + .inference_types_supported() + .iter() + .any(|t| *t == aws_sdk_bedrock::types::InferenceType::OnDemand); + if !supports_on_demand { + continue; + } + + // Filter: must support response streaming + if summary.response_streaming_supported() != Some(true) { + continue; + } + + // Filter: skip LEGACY lifecycle status + if let Some(lc) = summary.model_lifecycle() { + if lc.status == aws_sdk_bedrock::types::FoundationModelLifecycleStatus::Legacy { + continue; + } + } + + let model_id = summary.model_id(); + if model_id.is_empty() || !seen_ids.insert(model_id.to_string()) { + continue; + } + + models.push(Self::foundation_model_to_domain(summary)); + } + + // Fetch inference profiles (cross-region profiles like us.anthropic.*) + match client.list_inference_profiles().send().await { + Ok(ip_response) => { + for profile in ip_response.inference_profile_summaries() { + let profile_id = profile.inference_profile_id().to_string(); + if profile_id.is_empty() || !seen_ids.insert(profile_id.clone()) { + continue; + } + + models.push(Model { + id: ModelId::from(profile_id), + name: Some(profile.inference_profile_name().to_string()), + description: profile.description().map(|s| s.to_string()), + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }); + } + } + Err(err) => { + tracing::debug!( + error = %err, + "ListInferenceProfiles failed; continuing with foundation models only" + ); + } + } + + // Enrich models with LiteLLM registry metadata (context_length, capabilities) + Self::enrich_with_litellm_registry(&mut models).await; + + Ok(models) + } + + /// LiteLLM community registry URL for model metadata + const LITELLM_REGISTRY_URL: &'static str = + "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"; + + /// Fetches the LiteLLM community model registry and enriches models with + /// metadata such as context_length, tool support, reasoning support, and + /// vision capabilities. + /// + /// This is best-effort: if the fetch fails, models are left unenriched. + async fn enrich_with_litellm_registry(models: &mut [Model]) { + let registry = match Self::fetch_litellm_registry().await { + Ok(r) => r, + Err(err) => { + tracing::debug!( + error = %err, + "Failed to fetch LiteLLM registry; models will lack metadata enrichment" + ); + return; + } + }; + + for model in models.iter_mut() { + Self::enrich_model_from_registry(model, ®istry); + } + } + + /// Fetches the LiteLLM registry JSON + async fn fetch_litellm_registry( + ) -> Result> { + let response = reqwest::get(Self::LITELLM_REGISTRY_URL) + .await + .context("Failed to fetch LiteLLM registry")?; + + let registry: std::collections::HashMap = response + .json() + .await + .context("Failed to parse LiteLLM registry JSON")?; + + Ok(registry) + } + + /// Enriches a single model with data from the LiteLLM registry. + /// + /// Tries multiple key patterns to find the model in the registry: + /// - `bedrock/` (standard Bedrock prefix in LiteLLM) + /// - Raw `` as fallback + fn enrich_model_from_registry( + model: &mut Model, + registry: &std::collections::HashMap, + ) { + let model_id = model.id.as_str(); + + // Try registry lookup with bedrock/ prefix first, then raw ID + let entry = registry + .get(&format!("bedrock/{}", model_id)) + .or_else(|| registry.get(model_id)); + + let Some(entry) = entry else { + return; + }; + + // Enrich context_length from max_input_tokens + if model.context_length.is_none() { + if let Some(max_input) = entry.get("max_input_tokens").and_then(|v| v.as_u64()) { + model.context_length = Some(max_input); + } + } + + // Enrich tool support + if model.tools_supported.is_none() { + if let Some(supports_fc) = entry + .get("supports_function_calling") + .and_then(|v| v.as_bool()) + { + model.tools_supported = Some(supports_fc); + } + } + + // Enrich reasoning support + if model.supports_reasoning.is_none() { + if let Some(supports_reasoning) = + entry.get("supports_reasoning").and_then(|v| v.as_bool()) + { + model.supports_reasoning = Some(supports_reasoning); + } + } + + // Enrich parallel tool calls support + if model.supports_parallel_tool_calls.is_none() { + if let Some(supports_parallel) = entry + .get("supports_parallel_function_calling") + .and_then(|v| v.as_bool()) + { + model.supports_parallel_tool_calls = Some(supports_parallel); + } + } + + // Enrich vision support -- add Image modality if supported + if let Some(true) = entry.get("supports_vision").and_then(|v| v.as_bool()) { + if !model.input_modalities.contains(&InputModality::Image) { + model.input_modalities.push(InputModality::Image); + } + } + } + + /// Converts a `FoundationModelSummary` to a domain `Model` + fn foundation_model_to_domain( + summary: &aws_sdk_bedrock::types::FoundationModelSummary, + ) -> Model { + let model_id = summary.model_id().to_string(); + let model_name = summary.model_name().map(|s| s.to_string()); + + // Map input modalities + let input_modalities: Vec = summary + .input_modalities() + .iter() + .filter_map(|m| match m { + aws_sdk_bedrock::types::ModelModality::Text => Some(InputModality::Text), + aws_sdk_bedrock::types::ModelModality::Image => Some(InputModality::Image), + _ => None, + }) + .collect(); + + let input_modalities = if input_modalities.is_empty() { + vec![InputModality::Text] + } else { + input_modalities + }; + + Model { + id: ModelId::from(model_id), + name: model_name, + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities, + } + } } /// Converts Bedrock stream events to ChatCompletionMessage @@ -1127,7 +1426,9 @@ mod tests { BedrockProvider { provider: provider_fixture("test-token", Some(region)), auth_mode: BedrockAuthMode::BearerToken("test-token".to_string()), + use_fips: false, client: OnceCell::new(), + control_client: OnceCell::new(), region: region.to_string(), } } @@ -1447,7 +1748,9 @@ mod tests { let bedrock = BedrockProvider { provider: fixture_provider, auth_mode: BedrockAuthMode::BearerToken("token".to_string()), + use_fips: false, client: OnceCell::new(), + control_client: OnceCell::new(), region: "us-east-1".to_string(), }; @@ -1462,7 +1765,9 @@ mod tests { let bedrock = BedrockProvider { provider: fixture, auth_mode: BedrockAuthMode::BearerToken("token".to_string()), + use_fips: false, client: OnceCell::new(), + control_client: OnceCell::new(), region: "us-east-1".to_string(), }; @@ -2255,4 +2560,284 @@ mod tests { } assert!(got_text, "Expected text content in stream response"); } + + #[test] + fn test_foundation_model_to_domain_claude() { + use aws_sdk_bedrock::types::{ + FoundationModelLifecycle, FoundationModelLifecycleStatus, FoundationModelSummary, + InferenceType, ModelModality, + }; + + let fixture = FoundationModelSummary::builder() + .model_id("anthropic.claude-3-5-sonnet-20241022-v2:0") + .model_arn("arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0") + .model_name("Claude 3.5 Sonnet v2") + .input_modalities(ModelModality::Text) + .input_modalities(ModelModality::Image) + .inference_types_supported(InferenceType::OnDemand) + .response_streaming_supported(true) + .model_lifecycle( + FoundationModelLifecycle::builder() + .status(FoundationModelLifecycleStatus::Active) + .build() + .unwrap(), + ) + .build() + .unwrap(); + + let actual = BedrockProvider::foundation_model_to_domain(&fixture); + let expected = Model { + id: ModelId::from("anthropic.claude-3-5-sonnet-20241022-v2:0".to_string()), + name: Some("Claude 3.5 Sonnet v2".to_string()), + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text, InputModality::Image], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_foundation_model_to_domain_nova_text_only() { + use aws_sdk_bedrock::types::{ + FoundationModelSummary, InferenceType, ModelModality, + }; + + let fixture = FoundationModelSummary::builder() + .model_id("amazon.nova-pro-v1:0") + .model_arn("arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-pro-v1:0") + .model_name("Amazon Nova Pro") + .input_modalities(ModelModality::Text) + .inference_types_supported(InferenceType::OnDemand) + .response_streaming_supported(true) + .build() + .unwrap(); + + let actual = BedrockProvider::foundation_model_to_domain(&fixture); + let expected = Model { + id: ModelId::from("amazon.nova-pro-v1:0".to_string()), + name: Some("Amazon Nova Pro".to_string()), + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + assert_eq!(actual, expected); + } + + #[test] + fn test_foundation_model_to_domain_no_modalities_defaults_to_text() { + use aws_sdk_bedrock::types::FoundationModelSummary; + + let fixture = FoundationModelSummary::builder() + .model_id("some.model-v1") + .model_arn("arn:aws:bedrock:us-east-1::foundation-model/some.model-v1") + .model_name("Some Model") + .build() + .unwrap(); + + let actual = BedrockProvider::foundation_model_to_domain(&fixture); + assert_eq!(actual.input_modalities, vec![InputModality::Text]); + } + + #[test] + fn test_enrich_model_from_registry_with_bedrock_prefix() { + let mut fixture = Model { + id: ModelId::from("anthropic.claude-3-5-sonnet-20241022-v2:0".to_string()), + name: Some("Claude 3.5 Sonnet v2".to_string()), + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + + let mut registry = std::collections::HashMap::new(); + registry.insert( + "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0".to_string(), + serde_json::json!({ + "max_input_tokens": 200000, + "max_output_tokens": 8192, + "supports_function_calling": true, + "supports_reasoning": true, + "supports_parallel_function_calling": true, + "supports_vision": true + }), + ); + + BedrockProvider::enrich_model_from_registry(&mut fixture, ®istry); + + assert_eq!(fixture.context_length, Some(200000)); + assert_eq!(fixture.tools_supported, Some(true)); + assert_eq!(fixture.supports_reasoning, Some(true)); + assert_eq!(fixture.supports_parallel_tool_calls, Some(true)); + assert!(fixture.input_modalities.contains(&InputModality::Image)); + } + + #[test] + fn test_enrich_model_from_registry_raw_id_fallback() { + let mut fixture = Model { + id: ModelId::from("meta.llama3-1-405b-instruct-v1:0".to_string()), + name: Some("Llama 3.1 405B".to_string()), + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + + let mut registry = std::collections::HashMap::new(); + registry.insert( + "meta.llama3-1-405b-instruct-v1:0".to_string(), + serde_json::json!({ + "max_input_tokens": 128000, + "supports_function_calling": true, + "supports_vision": false + }), + ); + + BedrockProvider::enrich_model_from_registry(&mut fixture, ®istry); + + assert_eq!(fixture.context_length, Some(128000)); + assert_eq!(fixture.tools_supported, Some(true)); + // Vision is false, so Image should NOT be added + assert!(!fixture.input_modalities.contains(&InputModality::Image)); + } + + #[test] + fn test_enrich_model_from_registry_not_found() { + let mut fixture = Model { + id: ModelId::from("unknown.model-v1".to_string()), + name: None, + description: None, + context_length: None, + tools_supported: None, + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + + let registry = std::collections::HashMap::new(); + + BedrockProvider::enrich_model_from_registry(&mut fixture, ®istry); + + // Nothing should change + assert_eq!(fixture.context_length, None); + assert_eq!(fixture.tools_supported, None); + assert_eq!(fixture.input_modalities, vec![InputModality::Text]); + } + + #[test] + fn test_enrich_model_does_not_overwrite_existing_values() { + let mut fixture = Model { + id: ModelId::from("anthropic.claude-v2".to_string()), + name: None, + description: None, + context_length: Some(100000), // Already set + tools_supported: Some(false), // Already set + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }; + + let mut registry = std::collections::HashMap::new(); + registry.insert( + "bedrock/anthropic.claude-v2".to_string(), + serde_json::json!({ + "max_input_tokens": 200000, + "supports_function_calling": true + }), + ); + + BedrockProvider::enrich_model_from_registry(&mut fixture, ®istry); + + // Should NOT overwrite existing values + assert_eq!(fixture.context_length, Some(100000)); + assert_eq!(fixture.tools_supported, Some(false)); + } + + #[tokio::test] + async fn test_models_fallback_on_api_error() { + use forge_domain::ModelSource; + + // Create a provider with hardcoded models but invalid credentials + // The API call will fail, triggering fallback to hardcoded list + let mut fixture_provider = provider_fixture("invalid-token", None); + let fixture_models = vec![Model { + id: ModelId::from("fallback-model".to_string()), + name: Some("Fallback Model".to_string()), + description: None, + context_length: Some(128000), + tools_supported: Some(true), + supports_parallel_tool_calls: None, + supports_reasoning: None, + input_modalities: vec![InputModality::Text], + }]; + fixture_provider.models = Some(ModelSource::Hardcoded(fixture_models.clone())); + + let bedrock = BedrockProvider { + provider: fixture_provider, + auth_mode: BedrockAuthMode::BearerToken("invalid-token".to_string()), + use_fips: false, + client: OnceCell::new(), + control_client: OnceCell::new(), + region: "us-east-1".to_string(), + }; + + // models() should fall back to hardcoded list when API fails + let actual = bedrock.models().await.unwrap(); + let expected = fixture_models; + assert_eq!(actual, expected); + } + + #[test] + fn test_fips_mode_parsing() { + use forge_domain::{ + ApiKey, AuthCredential, AuthDetails, ProviderId, ProviderResponse, ProviderType, + URLParam, URLParamValue, + }; + + let mut url_params = std::collections::HashMap::new(); + url_params.insert( + URLParam::from("AWS_REGION".to_string()), + URLParamValue::from("us-gov-west-1".to_string()), + ); + url_params.insert( + URLParam::from("AWS_USE_FIPS".to_string()), + URLParamValue::from("true".to_string()), + ); + + let provider = Provider { + id: ProviderId::from("bedrock".to_string()), + provider_type: ProviderType::Llm, + response: Some(ProviderResponse::Bedrock), + url: Url::parse("https://bedrock-runtime.us-gov-west-1.amazonaws.com").unwrap(), + models: None, + auth_methods: vec![], + url_params: vec![], + credential: Some(AuthCredential { + id: ProviderId::from("bedrock".to_string()), + auth_details: AuthDetails::ApiKey(ApiKey::from("token".to_string())), + url_params, + }), + custom_headers: None, + }; + + let actual = BedrockProvider::new(provider).unwrap(); + assert_eq!(actual.region, "us-gov-west-1"); + assert!(actual.use_fips); + } + + #[test] + fn test_fips_mode_defaults_to_false() { + let fixture = provider_fixture("token", Some("us-east-1")); + let actual = BedrockProvider::new(fixture).unwrap(); + assert!(!actual.use_fips); + } } diff --git a/crates/forge_repo/src/provider/provider.json b/crates/forge_repo/src/provider/provider.json index 64af4400ed..0c84f6b9fd 100644 --- a/crates/forge_repo/src/provider/provider.json +++ b/crates/forge_repo/src/provider/provider.json @@ -1292,7 +1292,7 @@ }, { "id": "bedrock", - "url_param_vars": ["AWS_REGION"], + "url_param_vars": ["AWS_REGION", {"name": "AWS_USE_FIPS", "options": ["true", "false"], "optional": true}], "response_type": "Bedrock", "url": "https://bedrock-runtime.{{AWS_REGION}}.amazonaws.com", "models": [