diff --git a/.gitignore b/.gitignore index 46feafd..5103646 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea /target .fastembed_cache **/.claude/settings.local.json diff --git a/Cargo.lock b/Cargo.lock index cafde74..9daefe8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -822,18 +822,18 @@ dependencies = [ [[package]] name = "faiss" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3ffe048432786028b0a30aa1d13e10e08ced380439ba4a83fe5c227d2dd9733" +checksum = "3618cbbe48ebdc63b461cd1bf52f64c27a40e4d294beceedbf2fd4898571e204" dependencies = [ "faiss-sys", ] [[package]] name = "faiss-sys" -version = "0.6.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b9c008fc56422bf34357f17226d9c5a5c2ef6245b4774759c5f67112e46915e" +checksum = "c4a8fa67f7070bb81f41dc71c98813b3e40a66bedab12d6c2ac2af50f0837a7e" [[package]] name = "fastembed" diff --git a/Cargo.toml b/Cargo.toml index cf8a870..bc567f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ tokio = {version = "1.0.0", features = ["full"]} url = "2.5.4" thiserror = "2.0.12" fastembed = "4.9" -faiss = "0.12.1" +faiss = "0.13.0" askama = "0.14.0" prometheus = "0.14.0" chrono = { version = "0.4", features = ["serde"] } diff --git a/src/app_state.rs b/src/app_state.rs index afc676f..79a7ed5 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -6,35 +6,127 @@ use crate::clients::client::Client; use crate::clients::http_client::HttpClient; use crate::embedding::fastembed::FastEmbedService; use crate::embedding::service::EmbeddingService; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum AppStateError { + #[error("Lock poisoned: {0}")] + LockPoisoned(&'static str), + + #[error("Invalid namespace: {0}")] + InvalidNamespace(&'static str), +} pub struct AppState { + // client for upstream LLM requests pub http_client: Box, pub embedding_service: Box, - pub cache: Box>>, + caches: RwLock>>>>>, + similarity_threshold: f32, + eviction_policy: EvictionPolicy, } impl AppState { - pub fn new(semantic_threshold: f32, eviction_policy: EvictionPolicy) -> Self { - // client for upstream LLM requests - let http_client = Box::new(HttpClient::new()); - // cache fields - let embedding_service = Box::new(FastEmbedService::new()); + const MAX_NAMESPACE_LENGTH: usize = 64; + + pub fn new(similarity_threshold: f32, eviction_policy: EvictionPolicy) -> Self { + Self { + http_client: Box::new(HttpClient::new()), + embedding_service: Box::new(FastEmbedService::new()), + caches: RwLock::new(HashMap::new()), + similarity_threshold, + eviction_policy, + } + } + + fn validate_namespace(namespace: &str) -> Result<(), AppStateError> { + if namespace.is_empty() { + return Err(AppStateError::InvalidNamespace("namespace cannot be empty")); + } + + if namespace.len() > Self::MAX_NAMESPACE_LENGTH { + return Err(AppStateError::InvalidNamespace("namespace too long")); + } + + // Check for valid characters: alphanumeric, underscore, hyphen + if !namespace + .chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '-') + { + return Err(AppStateError::InvalidNamespace( + "invalid characters in namespace", + )); + } + + Ok(()) + } + + fn create_cache(&self) -> Box>> { let semantic_store = Box::new(FlatIPFaissStore::new( - embedding_service.get_dimensionality(), + self.embedding_service.get_dimensionality(), )); let response_store = ResponseStore::new(); - // create cache - let cache = Box::new(CacheImpl::new( + Box::new(CacheImpl::new( semantic_store, response_store, - semantic_threshold, - eviction_policy, - )); - // put service dependencies into app state + self.similarity_threshold, + self.eviction_policy, + )) + } + + pub fn get_cache( + &self, + namespace: &str, + ) -> Result>>>, AppStateError> { + Self::validate_namespace(namespace)?; + + // Try read lock first to check if cache exists + { + let read_guard = self + .caches + .read() + .map_err(|_| AppStateError::LockPoisoned("caches read lock poisoned"))?; + if let Some(cache) = read_guard.get(namespace) { + return Ok(Arc::clone(cache)); + } + } + + // Cache doesn't exist, acquire write lock to create it + let mut write_guard = self + .caches + .write() + .map_err(|_| AppStateError::LockPoisoned("caches write lock poisoned"))?; + + // Double-check: another thread might have created it while we were waiting + let cache = write_guard + .entry(namespace.to_string()) + .or_insert_with(|| Arc::new(self.create_cache())) + .clone(); + + Ok(cache) + } +} + +#[cfg(test)] +impl AppState { + pub fn new_with_cache_for_test( + http_client: Box, + embedding_service: Box, + cache: Box>>, + ) -> Self { + use crate::utils::header_utils::DEFAULT_NAMESPACE; + use std::collections::HashMap; + let mut caches = HashMap::new(); + caches.insert(DEFAULT_NAMESPACE.to_string(), Arc::new(cache)); + Self { http_client, embedding_service, - cache, + caches: RwLock::new(caches), + similarity_threshold: 0.9, + eviction_policy: EvictionPolicy::EntryLimit(100), } } } diff --git a/src/cache/cache_impl.rs b/src/cache/cache_impl.rs index 699c5cc..e0dfa82 100644 --- a/src/cache/cache_impl.rs +++ b/src/cache/cache_impl.rs @@ -7,7 +7,7 @@ use crate::cache::response_store::ResponseStore; use crate::metrics::metrics::CACHE_SIZE; use tracing::{debug, info}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub enum EvictionPolicy { EntryLimit(usize), MemoryLimitMb(usize), // Could also implement a "combined" of both limits diff --git a/src/embedding/fastembed.rs b/src/embedding/fastembed.rs index 787cb9c..13b097f 100644 --- a/src/embedding/fastembed.rs +++ b/src/embedding/fastembed.rs @@ -22,16 +22,6 @@ impl FastEmbedService { model_name: EmbeddingModel::AllMiniLML6V2, } } - - pub fn get_dimensionality(&self) -> u32 { - match &self.model_name { - EmbeddingModel::AllMiniLML6V2 => 384, - _ => panic!( - "{}", - EmbeddingError::SetupError(String::from("Embedding model with unknown size",)) - ), - } - } } impl EmbeddingService for FastEmbedService { @@ -43,4 +33,14 @@ impl EmbeddingService for FastEmbedService { Ok(embeddings.into_iter().next().unwrap()) } + + fn get_dimensionality(&self) -> u32 { + match &self.model_name { + EmbeddingModel::AllMiniLML6V2 => 384, + _ => panic!( + "{}", + EmbeddingError::SetupError(String::from("Embedding model with unknown size",)) + ), + } + } } diff --git a/src/embedding/service.rs b/src/embedding/service.rs index 61627db..3ea43d1 100644 --- a/src/embedding/service.rs +++ b/src/embedding/service.rs @@ -5,4 +5,5 @@ use crate::embedding::error::EmbeddingError; #[automock] pub trait EmbeddingService: Send + Sync { fn embed(&self, text: &str) -> Result, EmbeddingError>; + fn get_dimensionality(&self) -> u32; } diff --git a/src/endpoints/cache_aside/handler.rs b/src/endpoints/cache_aside/handler.rs index 25fe335..cd10a6c 100644 --- a/src/endpoints/cache_aside/handler.rs +++ b/src/endpoints/cache_aside/handler.rs @@ -7,7 +7,11 @@ use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use thiserror::Error; -use crate::{app_state::AppState, cache::error::CacheError, embedding::error::EmbeddingError}; +use crate::{ + app_state::{AppState, AppStateError}, + cache::error::CacheError, + embedding::error::EmbeddingError, +}; #[derive(Debug, Error)] pub enum CacheAsideError { @@ -15,6 +19,8 @@ pub enum CacheAsideError { InternalEmbedding(#[from] EmbeddingError), #[error("Error in caching layer: {0}")] InternalCache(#[from] CacheError), + #[error("AppState error: {0}")] + AppStateError(#[from] AppStateError), } impl IntoResponse for CacheAsideError { @@ -28,28 +34,56 @@ impl IntoResponse for CacheAsideError { error!(?err, "returning internal error to user"); (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() } + Self::AppStateError(err) => { + error!(?err, "AppState error"); + match err { + AppStateError::InvalidNamespace(_) => { + (StatusCode::BAD_REQUEST, err.to_string()).into_response() + } + _ => { + (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() + } + } + } } } } +use crate::utils::header_utils::DEFAULT_NAMESPACE; + +fn default_namespace() -> String { + DEFAULT_NAMESPACE.to_string() +} + #[derive(Deserialize, Serialize, Debug)] pub struct GetRequest { pub key: String, + #[serde(default = "default_namespace")] + pub namespace: String, } #[derive(Deserialize, Serialize, Debug)] pub struct PutRequest { pub key: String, pub data: String, + #[serde(default = "default_namespace")] + pub namespace: String, } pub async fn get( State(state): State>, Json(request): Json, ) -> Result { - debug!("cache_aside::GET request received"); + debug!( + "cache_aside::GET request received for namespace: {}", + request.namespace + ); + + // Get cache for namespace + let cache = state.get_cache(&request.namespace)?; + let embedding = state.embedding_service.embed(&request.key)?; - let saved_response = state.cache.get_if_present(&embedding)?; + let saved_response = cache.get_if_present(&embedding)?; let http_response = match saved_response { Some(response_bytes) => (StatusCode::OK, response_bytes).into_response(), None => (StatusCode::NOT_FOUND).into_response(), @@ -61,13 +95,20 @@ pub async fn put( State(state): State>, Json(request): Json, ) -> Result { - debug!("cache_aside::PUT request received"); + debug!( + "cache_aside::PUT request received for namespace: {}", + request.namespace + ); + + // Get cache for namespace + let cache = state.get_cache(&request.namespace)?; + let body: Vec = request.data.into_bytes(); let embedding = state.embedding_service.embed(&request.key)?; // if we already have an entry associated with the prompt, update it - let updated_existing_entry = state.cache.try_update(&embedding, body.clone())?; + let updated_existing_entry = cache.try_update(&embedding, body.clone())?; if !updated_existing_entry { - state.cache.insert(embedding, body)?; + cache.insert(embedding, body)?; } Ok((StatusCode::OK).into_response()) } @@ -113,14 +154,15 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = GetRequest { key: String::from(prompt), + namespace: "default".to_string(), }; // when @@ -155,14 +197,15 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = GetRequest { key: String::from(prompt), + namespace: "default".to_string(), }; // when @@ -204,14 +247,15 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = GetRequest { key: String::from(prompt), + namespace: "default".to_string(), }; // when @@ -251,14 +295,15 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = GetRequest { key: String::from(prompt), + namespace: "default".to_string(), }; // when @@ -297,15 +342,16 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = PutRequest { key: String::from(prompt), data: String::from(body), + namespace: "default".to_string(), }; // when @@ -342,15 +388,16 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = PutRequest { key: String::from(prompt), data: String::from(body), + namespace: "default".to_string(), }; // when @@ -386,15 +433,16 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = PutRequest { key: String::from(prompt), data: String::from(body), + namespace: "default".to_string(), }; // when @@ -438,15 +486,16 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = PutRequest { key: String::from(prompt), data, + namespace: "default".to_string(), }; // when @@ -457,4 +506,117 @@ mod tests { // then assert_eq!(result.status(), StatusCode::OK); } + + #[tokio::test] + async fn namespace_isolation_put_and_get() { + use crate::cache::cache_impl::EvictionPolicy; + + // Create real AppState (not mocked) to test namespace isolation + let app_state = Arc::new(AppState::new(0.9, EvictionPolicy::EntryLimit(100))); + + let key = "What is the capital of France?"; + + // Put different values in different namespaces + let put_ns1 = PutRequest { + key: key.to_string(), + data: "Paris for namespace 1".to_string(), + namespace: "namespace-1".to_string(), + }; + put(State(app_state.clone()), axum::Json(put_ns1)) + .await + .unwrap(); + + let put_ns2 = PutRequest { + key: key.to_string(), + data: "Paris for namespace 2".to_string(), + namespace: "namespace-2".to_string(), + }; + put(State(app_state.clone()), axum::Json(put_ns2)) + .await + .unwrap(); + + // Get from namespace-1 should return namespace-1's value + let get_ns1 = GetRequest { + key: key.to_string(), + namespace: "namespace-1".to_string(), + }; + let result_ns1 = get(State(app_state.clone()), axum::Json(get_ns1)) + .await + .unwrap(); + assert_eq!(result_ns1.status(), StatusCode::OK); + let body_ns1 = axum::body::to_bytes(result_ns1.into_body(), usize::MAX) + .await + .unwrap(); + assert_eq!(body_ns1, "Paris for namespace 1".as_bytes()); + + // Get from namespace-2 should return namespace-2's value + let get_ns2 = GetRequest { + key: key.to_string(), + namespace: "namespace-2".to_string(), + }; + let result_ns2 = get(State(app_state.clone()), axum::Json(get_ns2)) + .await + .unwrap(); + assert_eq!(result_ns2.status(), StatusCode::OK); + let body_ns2 = axum::body::to_bytes(result_ns2.into_body(), usize::MAX) + .await + .unwrap(); + assert_eq!(body_ns2, "Paris for namespace 2".as_bytes()); + } + + #[tokio::test] + async fn namespace_not_found_in_different_namespace() { + use crate::cache::cache_impl::EvictionPolicy; + + let app_state = Arc::new(AppState::new(0.9, EvictionPolicy::EntryLimit(100))); + + // Put in namespace-1 + let put_req = PutRequest { + key: "test key".to_string(), + data: "test value".to_string(), + namespace: "namespace-1".to_string(), + }; + put(State(app_state.clone()), axum::Json(put_req)) + .await + .unwrap(); + + // Try to get from namespace-2 (should not find it) + let get_req = GetRequest { + key: "test key".to_string(), + namespace: "namespace-2".to_string(), + }; + let result = get(State(app_state), axum::Json(get_req)).await.unwrap(); + assert_eq!(result.status(), StatusCode::NOT_FOUND); + } + + #[tokio::test] + async fn invalid_namespace_returns_error() { + use crate::cache::cache_impl::EvictionPolicy; + + let app_state = Arc::new(AppState::new(0.9, EvictionPolicy::EntryLimit(100))); + + // Empty namespace + let get_req = GetRequest { + key: "test".to_string(), + namespace: "".to_string(), + }; + let result = get(State(app_state.clone()), axum::Json(get_req)).await; + assert!(result.is_err()); + + // Invalid characters + let get_req = GetRequest { + key: "test".to_string(), + namespace: "test@namespace".to_string(), + }; + let result = get(State(app_state.clone()), axum::Json(get_req)).await; + assert!(result.is_err()); + + // Too long + let get_req = GetRequest { + key: "test".to_string(), + namespace: "a".repeat(65), + }; + let result = get(State(app_state), axum::Json(get_req)).await; + assert!(result.is_err()); + } } diff --git a/src/endpoints/chat/error.rs b/src/endpoints/chat/error.rs index 6c82ff9..e3e2754 100644 --- a/src/endpoints/chat/error.rs +++ b/src/endpoints/chat/error.rs @@ -5,7 +5,10 @@ use reqwest::StatusCode; use thiserror::Error; use tracing::warn; -use crate::{cache::error::CacheError, embedding::error::EmbeddingError, providers::ProviderError}; +use crate::{ + app_state::AppStateError, cache::error::CacheError, embedding::error::EmbeddingError, + providers::ProviderError, +}; // Error type #[derive(Debug, Error)] @@ -30,6 +33,9 @@ pub enum CompletionError { #[error("Provider error: {0}")] InternalProviderError(#[from] ProviderError), + + #[error("AppState error: {0}")] + AppStateError(#[from] AppStateError), } impl IntoResponse for CompletionError { @@ -81,6 +87,17 @@ impl IntoResponse for CompletionError { ) .into_response() } + Self::AppStateError(err) => { + warn!("AppState error: {}", err); + match err { + AppStateError::InvalidNamespace(_) => { + (StatusCode::BAD_REQUEST, err.to_string()).into_response() + } + _ => { + (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong!").into_response() + } + } + } } } } diff --git a/src/endpoints/chat/handler.rs b/src/endpoints/chat/handler.rs index 1c7e944..db8bf27 100644 --- a/src/endpoints/chat/handler.rs +++ b/src/endpoints/chat/handler.rs @@ -13,7 +13,8 @@ use crate::app_state::AppState; use crate::metrics::metrics::{CACHE_HIT, CACHE_MISS, CacheStatus}; use crate::providers::ProviderType; use crate::utils::{ - header_utils::PROXY_PROMPT_LOCATION_HEADER, json_extract::extract_prompt_from_path, + header_utils::{PROXY_PROMPT_LOCATION_HEADER, extract_namespace}, + json_extract::extract_prompt_from_path, }; pub async fn completions( @@ -22,13 +23,17 @@ pub async fn completions( Json(request_body): Json, provider: ProviderType, ) -> Result { + // Extract namespace from headers, get the appropriate cache + let namespace = extract_namespace(&headers); + let cache = state.get_cache(&namespace)?; + let prompt = extract_prompt_from_path( &request_body, provider.prompt_json_path(headers.get(&PROXY_PROMPT_LOCATION_HEADER))?, )?; let embedding = state.embedding_service.embed(&prompt)?; - if let Some(saved_response) = state.cache.get_if_present(&embedding)? { + if let Some(saved_response) = cache.get_if_present(&embedding)? { // Return cached response with 200 OK and minimal headers let mut response_headers = HeaderMap::new(); response_headers.insert("X-Cache-Status", "hit".parse().unwrap()); @@ -49,9 +54,7 @@ pub async fn completions( // only store the response if the status code of the response is 2XX if upstream_response.status_code.is_success() { - state - .cache - .insert(embedding, upstream_response.response_body.clone())?; + cache.insert(embedding, upstream_response.response_body.clone())?; } let mut response = ( @@ -110,11 +113,11 @@ mod tests { mock_client.expect_post_http_request().times(0); // put mocked objects into the appstate - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = json!({ "messages": [{ @@ -165,11 +168,11 @@ mod tests { let mut mock_client = MockClient::new(); mock_client.expect_post_http_request().times(0); - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = json!({ "messages": [{ @@ -235,11 +238,11 @@ mod tests { .times(0) .returning(|_, _, _| unreachable!()); - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); // Test OpenAI message let request_body = json!({ @@ -346,11 +349,11 @@ mod tests { } }); - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = json!({ "messages": [{ @@ -417,11 +420,11 @@ mod tests { } }); - let app_state = Arc::new(AppState { - embedding_service: Box::new(mock_embed), - cache: Box::new(mock_cache), - http_client: Box::new(mock_client), - }); + let app_state = Arc::new(AppState::new_with_cache_for_test( + Box::new(mock_client), + Box::new(mock_embed), + Box::new(mock_cache), + )); let request_body = json!({ "messages": [{ diff --git a/src/utils/header_utils.rs b/src/utils/header_utils.rs index 2ad6b13..0f306f9 100644 --- a/src/utils/header_utils.rs +++ b/src/utils/header_utils.rs @@ -6,6 +6,7 @@ use axum::http::{HeaderMap, HeaderName}; pub static PROXY_UPSTREAM_HOST_HEADER: HeaderName = HeaderName::from_static("x-llm-proxy-host"); pub static PROXY_UPSTREAM_HEADER: HeaderName = HeaderName::from_static("x-llm-proxy-upstream"); pub static PROXY_PROMPT_LOCATION_HEADER: HeaderName = HeaderName::from_static("x-llm-prompt"); +pub static NAMESPACE_HEADER: HeaderName = HeaderName::from_static("x-semcache-namespace"); pub static HOP_HEADERS: LazyLock<[HeaderName; 12]> = LazyLock::new(|| { [ HeaderName::from_static("connection"), @@ -24,6 +25,9 @@ pub static HOP_HEADERS: LazyLock<[HeaderName; 12]> = LazyLock::new(|| { ] }); +// CONSTANTS +pub const DEFAULT_NAMESPACE: &str = "default"; + pub fn remove_hop_headers(headers: &mut HeaderMap) { for header in &*HOP_HEADERS { headers.remove(header); @@ -38,6 +42,15 @@ pub fn prepare_upstream_headers(headers: HeaderMap) -> HeaderMap { // remove semcache headers upstream_headers.remove(&PROXY_UPSTREAM_HEADER); upstream_headers.remove(&PROXY_PROMPT_LOCATION_HEADER); + upstream_headers.remove(&NAMESPACE_HEADER); upstream_headers } + +pub fn extract_namespace(headers: &HeaderMap) -> String { + headers + .get(&NAMESPACE_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| DEFAULT_NAMESPACE.to_string()) +}