From 237ee79cc6ecfdf28b8db878741b6a472b84c9c9 Mon Sep 17 00:00:00 2001 From: Paul Thurlow Date: Tue, 2 Jun 2026 11:10:17 -0700 Subject: [PATCH] add api timeout relaxation and refresh token retry ability --- src/api.rs | 387 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 353 insertions(+), 34 deletions(-) diff --git a/src/api.rs b/src/api.rs index e888f53..1e17b95 100644 --- a/src/api.rs +++ b/src/api.rs @@ -3,8 +3,15 @@ use crate::config; use crate::util; use crossterm::style::Stylize; use serde::de::DeserializeOwned; +use std::sync::{Arc, Mutex}; use std::time::Duration; +/// Mints a fresh bearer token on demand. Returns `None` if no fresh token +/// could be obtained (e.g. the refresh token is dead and there's no API key +/// to re-mint from). Must be `Send + Sync` because `ApiClient` is shared +/// across rayon worker threads (see `indexes.rs`). +pub type TokenRefresher = Arc Option + Send + Sync>; + /// Cap on any single HTTP request. Connection create + synchronous schema /// discovery against a slow remote catalog can take well over a minute, so /// this needs to be generous; 5 minutes leaves headroom while still bounding @@ -25,10 +32,32 @@ fn build_http_client() -> reqwest::blocking::Client { .expect("reqwest blocking client should always build with these defaults") } +/// Client used only for streaming file uploads. Deliberately has **no** +/// request timeout: an upload's duration scales with file size and the +/// user's uplink (a 10 GB parquet on a normal connection takes far longer +/// than the 300s `HTTP_REQUEST_TIMEOUT` that's sized for slow server-side +/// work), so a wall-clock cap would abort healthy-but-slow transfers. TCP +/// keepalive is kept so a genuinely dead peer is still reaped by the OS; a +/// live-but-slow upload runs to completion, and the user can Ctrl-C if it +/// truly stalls. +fn build_upload_client() -> reqwest::blocking::Client { + reqwest::blocking::Client::builder() + .tcp_keepalive(TCP_KEEPALIVE_INTERVAL) + .build() + .expect("reqwest blocking client should always build with these defaults") +} + #[derive(Clone)] pub struct ApiClient { client: reqwest::blocking::Client, - api_key: String, + /// The current bearer token. Wrapped so it can be refreshed in place + /// through a `&self` borrow (every request method takes `&self`), and + /// `Arc>` rather than `RefCell` so the client stays `Send + + /// Sync` for the rayon-parallel paths and `#[derive(Clone)]` keeps + /// clones sharing the same refreshed token. + token: Arc>, + /// How to obtain a fresh token when the current one is rejected. + refresh: TokenRefresher, pub api_url: String, workspace_id: Option, sandbox_id: Option, @@ -119,9 +148,18 @@ impl ApiClient { } }; + // Refresher used when a request comes back 401: reload config (to + // pick up a session that `ensure_access_token` may have just + // persisted) and re-run the same auth-source precedence, best-effort. + let refresh: TokenRefresher = Arc::new(|| { + let pc = config::load("default").ok()?; + resolve_fresh_token(&pc) + }); + Self { client: build_http_client(), - api_key: access_token, + token: Arc::new(Mutex::new(access_token)), + refresh, api_url: profile_config.api_url.to_string(), workspace_id: workspace_id.map(String::from), sandbox_id: std::env::var("HOTDATA_SANDBOX").ok().or_else(|| { @@ -144,11 +182,26 @@ impl ApiClient { } /// Test-only client (no config load). Used with a local mock HTTP server. + /// The refresher returns `None`, so 401s are not retried — matching the + /// behavior of tests that don't exercise the refresh path. #[cfg(test)] pub(crate) fn test_new(api_url: &str, api_key: &str, workspace_id: Option<&str>) -> Self { + Self::test_new_with_refresh(api_url, api_key, workspace_id, Arc::new(|| None)) + } + + /// Test-only client with an injectable token refresher, for exercising the + /// 401-retry path without touching real config or the JWT machinery. + #[cfg(test)] + pub(crate) fn test_new_with_refresh( + api_url: &str, + api_key: &str, + workspace_id: Option<&str>, + refresh: TokenRefresher, + ) -> Self { Self { client: build_http_client(), - api_key: api_key.to_string(), + token: Arc::new(Mutex::new(api_key.to_string())), + refresh, api_url: api_url.to_string(), workspace_id: workspace_id.map(String::from), sandbox_id: None, @@ -179,10 +232,11 @@ impl ApiClient { method: reqwest::Method, url: &str, ) -> reqwest::blocking::RequestBuilder { + let bearer = self.token.lock().expect("token mutex poisoned").clone(); let mut req = self .client .request(method, url) - .header("Authorization", format!("Bearer {}", self.api_key)); + .header("Authorization", format!("Bearer {bearer}")); if let Some(ref ws) = self.workspace_id { req = req.header("X-Workspace-Id", ws); } @@ -214,6 +268,37 @@ impl ApiClient { } } + /// Mint a fresh bearer and swap it in. Returns whether a new token was + /// obtained — `false` means the refresher gave up, so the caller should + /// surface the original failure rather than pointlessly retrying. + fn refresh_token(&self) -> bool { + match (self.refresh)() { + Some(new) => { + *self.token.lock().expect("token mutex poisoned") = new; + true + } + None => false, + } + } + + /// Send a request and, if the server rejects the bearer with 401, mint a + /// fresh token and retry exactly once. `build` reconstructs the request + /// from scratch on each attempt so the retry picks up the refreshed bearer + /// (the Authorization header is baked into an already-built request and + /// can't be mutated). Streaming uploads can't use this — their body is + /// consumed on the first send and is not replayable. + fn send_with_retry( + &self, + build: impl Fn() -> reqwest::blocking::RequestBuilder, + body_for_log: Option<&serde_json::Value>, + ) -> (reqwest::StatusCode, String) { + let (status, body) = self.send(build(), body_for_log); + if status == reqwest::StatusCode::UNAUTHORIZED && self.refresh_token() { + return self.send(build(), body_for_log); + } + (status, body) + } + fn parse_json(body: &str) -> T { match serde_json::from_str(body) { Ok(v) => v, @@ -236,10 +321,13 @@ impl ApiClient { .filter_map(|(k, v)| v.as_ref().map(|val| (*k, val))) .collect(); let url = format!("{}{path}", self.api_url); - let req = self - .build_request(reqwest::Method::GET, &url) - .query(&filtered); - let (status, body) = self.send(req, None); + let (status, body) = self.send_with_retry( + || { + self.build_request(reqwest::Method::GET, &url) + .query(&filtered) + }, + None, + ); if !status.is_success() { self.fail_response(status, body); } @@ -249,8 +337,8 @@ impl ApiClient { /// GET request, returns parsed response. pub fn get(&self, path: &str) -> T { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::GET, &url); - let (status, body) = self.send(req, None); + let (status, body) = + self.send_with_retry(|| self.build_request(reqwest::Method::GET, &url), None); if !status.is_success() { self.fail_response(status, body); } @@ -261,8 +349,8 @@ impl ApiClient { /// [`Self::get`]. Used when probing many paths where a missing resource is normal. pub fn get_none_if_not_found(&self, path: &str) -> Option { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::GET, &url); - let (status, body) = self.send(req, None); + let (status, body) = + self.send_with_retry(|| self.build_request(reqwest::Method::GET, &url), None); if status == reqwest::StatusCode::NOT_FOUND { return None; } @@ -275,8 +363,10 @@ impl ApiClient { /// POST request with JSON body, returns parsed response. pub fn post(&self, path: &str, body: &serde_json::Value) -> T { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::POST, &url).json(body); - let (status, resp_body) = self.send(req, Some(body)); + let (status, resp_body) = self.send_with_retry( + || self.build_request(reqwest::Method::POST, &url).json(body), + Some(body), + ); if !status.is_success() { self.fail_response(status, resp_body); } @@ -288,43 +378,54 @@ impl ApiClient { /// to handle non-2xx responses gracefully instead of aborting. pub fn get_raw(&self, path: &str) -> (reqwest::StatusCode, String) { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::GET, &url); - self.send(req, None) + self.send_with_retry(|| self.build_request(reqwest::Method::GET, &url), None) } /// GET with a custom Accept header; returns raw bytes instead of decoded text. /// Used for binary result formats such as Arrow IPC streams. pub fn get_bytes(&self, path: &str, accept: &str) -> (reqwest::StatusCode, Vec) { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::GET, &url).header("Accept", accept); - match util::send_debug_bytes(&self.client, req) { - Ok(pair) => pair, - Err(e) => { - eprintln!("error connecting to API: {e}"); - std::process::exit(1); + let send = |client: &reqwest::blocking::Client, c: &Self| { + let req = c + .build_request(reqwest::Method::GET, &url) + .header("Accept", accept); + match util::send_debug_bytes(client, req) { + Ok(pair) => pair, + Err(e) => { + eprintln!("error connecting to API: {e}"); + std::process::exit(1); + } } + }; + let (status, bytes) = send(&self.client, self); + if status == reqwest::StatusCode::UNAUTHORIZED && self.refresh_token() { + return send(&self.client, self); } + (status, bytes) } /// POST request with JSON body, exits on error, returns raw (status, body). pub fn post_raw(&self, path: &str, body: &serde_json::Value) -> (reqwest::StatusCode, String) { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::POST, &url).json(body); - self.send(req, Some(body)) + self.send_with_retry( + || self.build_request(reqwest::Method::POST, &url).json(body), + Some(body), + ) } /// DELETE request, exits on connection error, returns raw (status, body). pub fn delete_raw(&self, path: &str) -> (reqwest::StatusCode, String) { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::DELETE, &url); - self.send(req, None) + self.send_with_retry(|| self.build_request(reqwest::Method::DELETE, &url), None) } /// PATCH request with JSON body, returns parsed response. pub fn patch(&self, path: &str, body: &serde_json::Value) -> T { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::PATCH, &url).json(body); - let (status, resp_body) = self.send(req, Some(body)); + let (status, resp_body) = self.send_with_retry( + || self.build_request(reqwest::Method::PATCH, &url).json(body), + Some(body), + ); if !status.is_success() { self.fail_response(status, resp_body); } @@ -334,8 +435,10 @@ impl ApiClient { /// PUT request with JSON body, returns parsed response. pub fn put(&self, path: &str, body: &serde_json::Value) -> T { let url = format!("{}{path}", self.api_url); - let req = self.build_request(reqwest::Method::PUT, &url).json(body); - let (status, resp_body) = self.send(req, Some(body)); + let (status, resp_body) = self.send_with_retry( + || self.build_request(reqwest::Method::PUT, &url).json(body), + Some(body), + ); if !status.is_success() { self.fail_response(status, resp_body); } @@ -343,6 +446,12 @@ impl ApiClient { } /// POST with a custom request body (for file uploads). Returns raw status and body. + /// + /// Unlike the other methods this does **not** retry on 401: the body is a + /// one-shot stream that's consumed on send and can't be replayed. A large + /// upload is exactly the case where the token may expire mid-flight, but + /// the failure that matters surfaces on the *next* request (e.g. the load + /// POST), which does retry. See `databases::tables_load`. pub fn post_body( &self, path: &str, @@ -358,10 +467,41 @@ impl ApiClient { req = req.header("Content-Length", len); } let req = req.body(reqwest::blocking::Body::new(reader)); - // Body is an opaque stream — nothing meaningful to print under - // --debug, so pass `None`. Headers (including the masked - // Authorization) still log. - self.send(req, None) + // Execute on the upload client (no request timeout) rather than the + // default 300s client — `build_request`'s originating client is + // irrelevant once the request is built, since the executing client's + // timeout is what applies. Body is an opaque stream, so pass `None` + // for logging; headers (including the masked Authorization) still log. + let upload_client = build_upload_client(); + match util::send_debug(&upload_client, req, None) { + Ok(pair) => pair, + Err(e) => { + eprintln!("error connecting to API: {e}"); + std::process::exit(1); + } + } + } +} + +/// Best-effort re-resolution of the bearer token, mirroring the auth-source +/// precedence in [`ApiClient::new`] but returning `None` on failure instead +/// of exiting. Used by the 401-retry refresher: at refresh time we're already +/// past startup, so a failure just means "couldn't refresh, surface the +/// original error" rather than a fatal startup diagnostic. +fn resolve_fresh_token(profile_config: &config::ProfileConfig) -> Option { + let api_url = profile_config.api_url.to_string(); + if std::env::var("HOTDATA_DATABASE_TOKEN").is_ok() { + crate::database_session::refresh_from_env(&api_url) + } else if std::env::var("HOTDATA_SANDBOX_TOKEN").is_ok() { + crate::sandbox_session::refresh_from_env(&api_url) + } else if crate::sandbox_session::load().is_some() { + crate::sandbox_session::ensure_access_token(&api_url) + } else { + let api_key_fallback = profile_config + .api_key + .as_deref() + .filter(|k| !k.is_empty() && *k != "PLACEHOLDER"); + crate::jwt::ensure_access_token(profile_config, api_key_fallback).ok() } } @@ -540,6 +680,185 @@ mod tests { assert_eq!(msg, "plain body"); } + #[test] + fn post_raw_retries_once_with_refreshed_token_after_401() { + let mut server = mockito::Server::new(); + // First attempt: stale bearer is rejected. + let stale = server + .mock("POST", "/load") + .match_header("Authorization", "Bearer stale-token") + .with_status(401) + .with_body("Invalid api key") + .create(); + // Retry: client must mint a fresh bearer and the server accepts it. + let fresh = server + .mock("POST", "/load") + .match_header("Authorization", "Bearer fresh-token") + .with_status(200) + .with_body(r#"{"ok":true}"#) + .create(); + + let api = ApiClient::test_new_with_refresh( + &server.url(), + "stale-token", + None, + std::sync::Arc::new(|| Some("fresh-token".to_string())), + ); + let (status, body) = api.post_raw("/load", &serde_json::json!({"upload_id": "u1"})); + + assert_eq!( + status.as_u16(), + 200, + "retry should surface the 200, got body: {body}" + ); + assert!(body.contains("\"ok\":true")); + stale.assert(); + fresh.assert(); + } + + #[test] + fn get_retries_once_with_refreshed_token_after_401() { + let mut server = mockito::Server::new(); + let stale = server + .mock("GET", "/ok") + .match_header("Authorization", "Bearer stale-token") + .with_status(401) + .with_body("Invalid api key") + .create(); + let fresh = server + .mock("GET", "/ok") + .match_header("Authorization", "Bearer fresh-token") + .with_status(200) + .with_body(r#"{"n":7}"#) + .create(); + + let api = ApiClient::test_new_with_refresh( + &server.url(), + "stale-token", + None, + std::sync::Arc::new(|| Some("fresh-token".to_string())), + ); + let got: Probe = api.get("/ok"); + assert_eq!(got.n, 7); + stale.assert(); + fresh.assert(); + } + + #[test] + fn does_not_retry_on_non_401() { + // A 500 is not an auth problem — the client must not refresh or retry. + let mut server = mockito::Server::new(); + let mock = server + .mock("POST", "/load") + .with_status(500) + .with_body("boom") + .expect(1) // exactly one request, no retry + .create(); + + let refreshed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let flag = refreshed.clone(); + let api = ApiClient::test_new_with_refresh( + &server.url(), + "stale-token", + None, + std::sync::Arc::new(move || { + flag.store(true, std::sync::atomic::Ordering::SeqCst); + Some("fresh-token".to_string()) + }), + ); + let (status, _) = api.post_raw("/load", &serde_json::json!({})); + assert_eq!(status.as_u16(), 500); + assert!( + !refreshed.load(std::sync::atomic::Ordering::SeqCst), + "refresher must not be called on a non-401 response" + ); + mock.assert(); + } + + #[test] + fn retries_at_most_once_then_surfaces_401() { + // Both attempts 401 → give up after a single retry (no infinite loop). + let mut server = mockito::Server::new(); + let mock = server + .mock("POST", "/load") + .with_status(401) + .with_body("Invalid api key") + .expect(2) // original + one retry, then stop + .create(); + + let api = ApiClient::test_new_with_refresh( + &server.url(), + "stale-token", + None, + std::sync::Arc::new(|| Some("still-bad-token".to_string())), + ); + let (status, body) = api.post_raw("/load", &serde_json::json!({})); + assert_eq!(status.as_u16(), 401); + assert!(body.contains("Invalid api key")); + mock.assert(); + } + + #[test] + fn does_not_retry_when_refresher_cannot_mint() { + // Refresher returns None (e.g. dead refresh token, no API key) → the + // original 401 is surfaced unchanged, with no second request. + let mut server = mockito::Server::new(); + let mock = server + .mock("POST", "/load") + .with_status(401) + .with_body("Invalid api key") + .expect(1) + .create(); + + let api = ApiClient::test_new_with_refresh( + &server.url(), + "stale-token", + None, + std::sync::Arc::new(|| None), + ); + let (status, _) = api.post_raw("/load", &serde_json::json!({})); + assert_eq!(status.as_u16(), 401); + mock.assert(); + } + + #[test] + fn post_body_does_not_retry_on_401() { + // Streaming uploads can't be replayed, so a 401 here is surfaced as-is + // and the refresher is never consulted. + let mut server = mockito::Server::new(); + let mock = server + .mock("POST", "/files") + .with_status(401) + .with_body("Invalid api key") + .expect(1) + .create(); + + let refreshed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let flag = refreshed.clone(); + let api = ApiClient::test_new_with_refresh( + &server.url(), + "stale-token", + None, + std::sync::Arc::new(move || { + flag.store(true, std::sync::atomic::Ordering::SeqCst); + Some("fresh-token".to_string()) + }), + ); + let data = b"parquet-bytes".to_vec(); + let (status, _) = api.post_body( + "/files", + "application/octet-stream", + std::io::Cursor::new(data), + None, + ); + assert_eq!(status.as_u16(), 401); + assert!( + !refreshed.load(std::sync::atomic::Ordering::SeqCst), + "streaming upload must not trigger a token refresh/retry" + ); + mock.assert(); + } + #[test] fn format_fail_message_4xx_authenticated_probe_shows_server_message() { // Valid key but a genuine client error — upstream message wins.