diff --git a/Cargo.lock b/Cargo.lock index 9a40aee..20769c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,7 +72,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -83,7 +83,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -248,7 +248,7 @@ version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -420,7 +420,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -483,7 +483,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -724,6 +724,7 @@ dependencies = [ "serde_yaml", "sha2", "sqlformat", + "sysinfo", "tabled", "tar", "tempfile", @@ -1226,12 +1227,40 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "ntapi" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3b335231dfd352ffb0f8017f3b6027a4917f7df785ea2143d8af2adc66980ae" +dependencies = [ + "winapi", +] + [[package]] name = "number_prefix" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "objc2-core-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" +dependencies = [ + "bitflags", +] + +[[package]] +name = "objc2-io-kit" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33fafba39597d6dc1fb709123dfa8289d39406734be322956a69f0931c73bb15" +dependencies = [ + "libc", + "objc2-core-foundation", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -1668,7 +1697,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.12.1", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -1969,6 +1998,20 @@ dependencies = [ "syn", ] +[[package]] +name = "sysinfo" +version = "0.38.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ab6a2f8bfe508deb3c6406578252e491d299cbbf3bc0529ecc3313aee4a52f" +dependencies = [ + "libc", + "memchr", + "ntapi", + "objc2-core-foundation", + "objc2-io-kit", + "windows", +] + [[package]] name = "system-configuration" version = "0.7.0" @@ -2037,7 +2080,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix 1.1.4", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2481,12 +2524,89 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" +dependencies = [ + "windows-core", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-future" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" +dependencies = [ + "windows-core", + "windows-link", + "windows-threading", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-numerics" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" +dependencies = [ + "windows-core", + "windows-link", +] + [[package]] name = "windows-registry" version = "0.6.1" @@ -2585,6 +2705,15 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows-threading" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index b2e9737..03ab938 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ flate2 = "1" tar = "0.4" semver = "1" sqlformat = "0.5.0" +sysinfo = { version = "0.38.4", default-features = false, features = ["system"] } [dev-dependencies] mockito = "1" diff --git a/src/api.rs b/src/api.rs index 37e7845..e0424f3 100644 --- a/src/api.rs +++ b/src/api.rs @@ -8,6 +8,7 @@ pub struct ApiClient { api_key: String, pub api_url: String, workspace_id: Option, + session_id: Option, } impl ApiClient { @@ -35,6 +36,13 @@ impl ApiClient { api_key, api_url: profile_config.api_url.to_string(), workspace_id: workspace_id.map(String::from), + session_id: std::env::var("HOTDATA_SESSION").ok().or_else(|| { + if crate::sessions::find_session_run_ancestor().is_some() { + eprintln!("error: session has been lost -- restart the process"); + std::process::exit(1); + } + profile_config.session + }), } } @@ -48,6 +56,9 @@ impl ApiClient { if let Some(ref ws) = self.workspace_id { headers.push(("X-Workspace-Id", ws.clone())); } + if let Some(ref sid) = self.session_id { + headers.push(("X-Session-Id", sid.clone())); + } headers } @@ -63,6 +74,9 @@ impl ApiClient { if let Some(ref ws) = self.workspace_id { req = req.header("X-Workspace-Id", ws); } + if let Some(ref sid) = self.session_id { + req = req.header("X-Session-Id", sid); + } req } @@ -235,6 +249,38 @@ impl ApiClient { } } + + /// PATCH request with JSON body, returns parsed response. + pub fn patch(&self, path: &str, body: &serde_json::Value) -> T { + let url = format!("{}{path}", self.api_url); + self.log_request("PATCH", &url, Some(body)); + + let resp = match self.build_request(reqwest::Method::PATCH, &url) + .json(body) + .send() + { + Ok(r) => r, + Err(e) => { + eprintln!("error connecting to API: {e}"); + std::process::exit(1); + } + }; + + let (status, resp_body) = util::debug_response(resp); + if !status.is_success() { + eprintln!("{}", util::api_error(resp_body).red()); + std::process::exit(1); + } + + match serde_json::from_str(&resp_body) { + Ok(v) => v, + Err(e) => { + eprintln!("error parsing response: {e}"); + std::process::exit(1); + } + } + } + /// POST with a custom request body (for file uploads). Returns raw status and body. pub fn post_body( &self, diff --git a/src/command.rs b/src/command.rs index b2a0f3f..1390c89 100644 --- a/src/command.rs +++ b/src/command.rs @@ -172,6 +172,23 @@ pub enum Commands { command: Option, }, + /// Manage work sessions + Sessions { + /// Session ID to show details + id: Option, + + /// Workspace ID (defaults to first workspace from login) + #[arg(long, short = 'w', global = true)] + workspace_id: Option, + + /// Output format + #[arg(long = "output", short = 'o', default_value = "table", value_parser = ["table", "json", "yaml"])] + output: String, + + #[command(subcommand)] + command: Option, + }, + /// Generate shell completions Completions { /// Shell to generate completions for @@ -542,6 +559,64 @@ pub enum QueriesCommands { }, } +#[derive(Subcommand)] +pub enum SessionsCommands { + /// List all sessions in a workspace + List { + /// Output format + #[arg(long = "output", short = 'o', default_value = "table", value_parser = ["table", "json", "yaml"])] + output: String, + }, + + /// Create a new session and set it as active + New { + /// Session name + #[arg(long)] + name: Option, + + /// Output format + #[arg(long = "output", short = 'o', default_value = "table", value_parser = ["table", "json", "yaml"])] + output: String, + }, + + /// Update a session's markdown or name + Update { + /// Session ID (defaults to active session) + id: Option, + + /// New session name + #[arg(long)] + name: Option, + + /// Markdown content + #[arg(long)] + markdown: Option, + + /// Output format + #[arg(long = "output", short = 'o', default_value = "table", value_parser = ["table", "json", "yaml"])] + output: String, + }, + + /// Set the active session (omit ID to clear) + Set { + /// Session ID to set as active (omit to clear) + id: Option, + }, + + /// Run a command with a hotdata session. Creates a new session unless an ID was provided. + /// Example: hotdata sessions run claude + /// Example: hotdata sessions run claude + Run { + /// Session name (only used when creating a new session) + #[arg(long)] + name: Option, + + /// Command and arguments to execute + #[arg(trailing_var_arg = true, required = true)] + cmd: Vec, + }, +} + #[derive(Subcommand)] pub enum TablesCommands { /// List all tables in a workspace diff --git a/src/config.rs b/src/config.rs index 3b80d67..c0aaa2b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -107,6 +107,8 @@ pub struct ProfileConfig { pub api_key_source: ApiKeySource, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub workspaces: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub session: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -213,6 +215,49 @@ pub fn save_default_workspace(profile: &str, workspace: WorkspaceEntry) -> Resul write_config(&config_path, &content) } +pub fn save_session(profile: &str, session_id: &str) -> Result<(), String> { + let config_path = config_path()?; + + let mut config_file: ConfigFile = if config_path.exists() { + let content = fs::read_to_string(&config_path) + .map_err(|e| format!("error reading config file: {e}"))?; + serde_yaml::from_str(&content).map_err(|e| format!("error parsing config file: {e}"))? + } else { + ConfigFile { profiles: HashMap::new() } + }; + + config_file + .profiles + .entry(profile.to_string()) + .or_default() + .session = Some(session_id.to_string()); + + let content = serde_yaml::to_string(&config_file) + .map_err(|e| format!("error serializing config: {e}"))?; + write_config(&config_path, &content) +} + +pub fn clear_session(profile: &str) -> Result<(), String> { + let config_path = config_path()?; + + if !config_path.exists() { + return Ok(()); + } + + let content = fs::read_to_string(&config_path) + .map_err(|e| format!("error reading config file: {e}"))?; + let mut config_file: ConfigFile = + serde_yaml::from_str(&content).map_err(|e| format!("error parsing config file: {e}"))?; + + if let Some(entry) = config_file.profiles.get_mut(profile) { + entry.session = None; + } + + let content = serde_yaml::to_string(&config_file) + .map_err(|e| format!("error serializing config: {e}"))?; + write_config(&config_path, &content) +} + pub fn resolve_workspace_id(provided: Option, profile_config: &ProfileConfig) -> Result { if let Some(id) = provided { return Ok(id); diff --git a/src/datasets.rs b/src/datasets.rs index 82dd1fe..3f7b56d 100644 --- a/src/datasets.rs +++ b/src/datasets.rs @@ -8,15 +8,23 @@ use std::path::Path; struct Dataset { id: String, label: String, + #[serde(default = "default_schema")] + schema_name: String, table_name: String, created_at: String, updated_at: String, } +fn default_schema() -> String { + "main".to_string() +} + #[derive(Deserialize)] struct CreateResponse { id: String, label: String, + #[serde(default = "default_schema")] + schema_name: String, table_name: String, } @@ -231,7 +239,7 @@ fn create_dataset( println!("{}", "Dataset created".green()); println!("id: {}", dataset.id); println!("label: {}", dataset.label); - println!("full_name: datasets.main.{}", dataset.table_name); + println!("full_name: datasets.{}.{}", dataset.schema_name, dataset.table_name); } pub fn create_from_upload( @@ -381,7 +389,7 @@ pub fn list(workspace_id: &str, limit: Option, offset: Option, format: let rows: Vec> = body.datasets.iter().map(|d| vec![ d.id.clone(), d.label.clone(), - format!("datasets.main.{}", d.table_name), + format!("datasets.{}.{}", d.schema_name, d.table_name), crate::util::format_date(&d.created_at), ]).collect(); crate::table::print(&["ID", "LABEL", "FULL NAME", "CREATED AT"], &rows); diff --git a/src/main.rs b/src/main.rs index 05115c6..98cd98b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ mod jobs; mod queries; mod query; mod results; +mod sessions; mod skill; mod table; mod tables; @@ -19,7 +20,7 @@ mod workspace; use anstyle::AnsiColor; use clap::{Parser, builder::Styles}; -use command::{AuthCommands, Commands, ConnectionsCommands, ConnectionsCreateCommands, DatasetsCommands, IndexesCommands, JobsCommands, QueriesCommands, QueryCommands, ResultsCommands, SkillCommands, TablesCommands, WorkspaceCommands}; +use command::{AuthCommands, Commands, ConnectionsCommands, ConnectionsCreateCommands, DatasetsCommands, IndexesCommands, JobsCommands, QueriesCommands, QueryCommands, ResultsCommands, SessionsCommands, SkillCommands, TablesCommands, WorkspaceCommands}; #[derive(Parser)] #[command(name = "hotdata", version, about = concat!("Hotdata CLI - Command line interface for Hotdata (v", env!("CARGO_PKG_VERSION"), ")"), long_about = None, disable_version_flag = true)] @@ -42,6 +43,20 @@ struct Cli { } fn resolve_workspace(provided: Option) -> String { + // HOTDATA_WORKSPACE env var takes priority and blocks --workspace-id flag + if let Ok(ws) = std::env::var("HOTDATA_WORKSPACE") { + if let Some(ref flag) = provided { + if flag != &ws { + eprintln!("error: cannot override workspace -- locked by HOTDATA_WORKSPACE environment variable ({ws})"); + std::process::exit(1); + } + } + return ws; + } + if sessions::find_session_run_ancestor().is_some() { + eprintln!("error: workspace has been lost -- restart the process"); + std::process::exit(1); + } match config::load("default") { Ok(profile) => match config::resolve_workspace_id(provided, &profile) { Ok(id) => id, @@ -319,6 +334,46 @@ fn main() { } } } + Commands::Sessions { id, workspace_id, output, command } => { + let workspace_id = resolve_workspace(workspace_id); + match command { + Some(SessionsCommands::Run { name, cmd }) => { + sessions::run(id.as_deref(), &workspace_id, name.as_deref(), &cmd) + } + Some(SessionsCommands::List { output }) => { + sessions::list(&workspace_id, &output) + } + Some(SessionsCommands::New { name, output }) => { + sessions::new(&workspace_id, name.as_deref(), &output) + } + Some(SessionsCommands::Update { id: update_id, name, markdown, output }) => { + let session_id = update_id.or(id).or_else(|| { + config::load("default").ok().and_then(|p| p.session) + }); + match session_id { + Some(sid) => sessions::update(&workspace_id, &sid, name.as_deref(), markdown.as_deref(), &output), + None => { + eprintln!("error: no session ID provided and no active session set. Use 'sessions new' or 'sessions set '."); + std::process::exit(1); + } + } + } + Some(SessionsCommands::Set { id: set_id }) => { + sessions::set(set_id.as_deref(), &workspace_id) + } + None => { + match id { + Some(id) => sessions::get(&id, &workspace_id, &output), + None => { + use clap::CommandFactory; + let mut cmd = Cli::command(); + cmd.build(); + cmd.find_subcommand_mut("sessions").unwrap().print_help().unwrap(); + } + } + } + } + } Commands::Completions { shell } => { use clap::CommandFactory; use clap_complete::generate; diff --git a/src/sessions.rs b/src/sessions.rs new file mode 100644 index 0000000..7c290d8 --- /dev/null +++ b/src/sessions.rs @@ -0,0 +1,263 @@ +use crate::api::ApiClient; +use crate::config; +use crossterm::style::Stylize; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize)] +struct Session { + public_id: String, + name: String, + markdown: String, + created_at: String, + updated_at: String, +} + +#[derive(Deserialize)] +struct ListResponse { + sessions: Vec, +} + +#[derive(Deserialize)] +struct DetailResponse { + session: Session, +} + +pub fn list(workspace_id: &str, format: &str) { + let api = ApiClient::new(Some(workspace_id)); + let body: ListResponse = api.get("/sessions"); + + let current_session = std::env::var("HOTDATA_SESSION") + .ok() + .or_else(|| config::load("default").ok().and_then(|p| p.session)); + + match format { + "json" => println!("{}", serde_json::to_string_pretty(&body.sessions).unwrap()), + "yaml" => print!("{}", serde_yaml::to_string(&body.sessions).unwrap()), + "table" => { + if body.sessions.is_empty() { + eprintln!("{}", "No sessions found.".dark_grey()); + } else { + let rows: Vec> = body.sessions.iter().map(|s| { + let marker = if current_session.as_deref() == Some(&s.public_id) { "*" } else { "" }; + vec![ + marker.to_string(), + s.public_id.clone(), + s.name.clone(), + crate::util::format_date(&s.updated_at), + ] + }).collect(); + crate::table::print(&["ACTIVE", "ID", "NAME", "UPDATED"], &rows); + } + } + _ => unreachable!(), + } +} + +pub fn get(session_id: &str, workspace_id: &str, format: &str) { + let api = ApiClient::new(Some(workspace_id)); + let path = format!("/sessions/{session_id}"); + let body: DetailResponse = api.get(&path); + let s = &body.session; + + match format { + "json" => println!("{}", serde_json::to_string_pretty(s).unwrap()), + "yaml" => print!("{}", serde_yaml::to_string(s).unwrap()), + "table" => { + let label = |l: &str| format!("{:<12}", l).dark_grey().to_string(); + println!("{}{}", label("id:"), s.public_id); + println!("{}{}", label("name:"), s.name); + println!("{}{}", label("created:"), crate::util::format_date(&s.created_at)); + println!("{}{}", label("updated:"), crate::util::format_date(&s.updated_at)); + if !s.markdown.is_empty() { + println!(); + println!("{}", "Markdown:".dark_grey()); + println!("{}", s.markdown); + } + } + _ => unreachable!(), + } +} + +fn check_session_lock() { + if std::env::var("HOTDATA_SESSION").is_ok() || find_session_run_ancestor().is_some() { + eprintln!("error: session is locked"); + std::process::exit(1); + } +} + +pub fn find_session_run_ancestor() -> Option { + static CACHED: std::sync::OnceLock> = std::sync::OnceLock::new(); + *CACHED.get_or_init(find_session_run_ancestor_inner) +} + +fn find_session_run_ancestor_inner() -> Option { + use sysinfo::{ProcessRefreshKind, RefreshKind, System, UpdateKind}; + + let sys = System::new_with_specifics( + RefreshKind::nothing().with_processes( + ProcessRefreshKind::nothing().with_cmd(UpdateKind::Always), + ), + ); + + let current_pid = sysinfo::get_current_pid().ok()?; + let mut pid = sys.process(current_pid)?.parent()?; + + for _ in 0..64 { + let proc = sys.process(pid)?; + let name = proc.name().to_string_lossy(); + if name == "hotdata" { + if proc.cmd().iter().any(|a| a == "sessions") + && proc.cmd().iter().any(|a| a == "run") + { + return Some(pid); + } + } + pid = proc.parent()?; + } + None +} + +pub fn new(workspace_id: &str, name: Option<&str>, format: &str) { + check_session_lock(); + let api = ApiClient::new(Some(workspace_id)); + + let mut body = serde_json::json!({}); + if let Some(n) = name { + body["name"] = serde_json::json!(n); + } + + let resp: DetailResponse = api.post("/sessions", &body); + let s = &resp.session; + + // Set as the active session in config + if let Err(e) = config::save_session("default", &s.public_id) { + eprintln!("warning: could not save session to config: {e}"); + } + + println!("{}", "Session created".green()); + match format { + "json" => println!("{}", serde_json::to_string_pretty(s).unwrap()), + "yaml" => print!("{}", serde_yaml::to_string(s).unwrap()), + "table" => { + println!("id: {}", s.public_id); + if !s.name.is_empty() { + println!("name: {}", s.name); + } + } + _ => unreachable!(), + } +} + +pub fn update(workspace_id: &str, session_id: &str, name: Option<&str>, markdown: Option<&str>, format: &str) { + if name.is_none() && markdown.is_none() { + eprintln!("error: provide at least one of --name or --markdown."); + std::process::exit(1); + } + + let api = ApiClient::new(Some(workspace_id)); + + let mut body = serde_json::json!({}); + if let Some(n) = name { body["name"] = serde_json::json!(n); } + if let Some(m) = markdown { body["markdown"] = serde_json::json!(m); } + + let path = format!("/sessions/{session_id}"); + let resp: DetailResponse = api.patch(&path, &body); + let s = &resp.session; + + println!("{}", "Session updated".green()); + match format { + "json" => println!("{}", serde_json::to_string_pretty(s).unwrap()), + "yaml" => print!("{}", serde_yaml::to_string(s).unwrap()), + "table" => { + let label = |l: &str| format!("{:<12}", l).dark_grey().to_string(); + println!("{}{}", label("id:"), s.public_id); + println!("{}{}", label("name:"), s.name); + println!("{}{}", label("updated:"), crate::util::format_date(&s.updated_at)); + } + _ => unreachable!(), + } +} + +pub fn run(session_id: Option<&str>, workspace_id: &str, name: Option<&str>, cmd: &[String]) { + check_session_lock(); + let sid = match session_id { + Some(id) => { + // Verify the session exists + let api = ApiClient::new(Some(workspace_id)); + let path = format!("/sessions/{id}"); + let _: DetailResponse = api.get(&path); + id.to_string() + } + None => { + // Create a new session + let api = ApiClient::new(Some(workspace_id)); + let mut body = serde_json::json!({}); + if let Some(n) = name { + body["name"] = serde_json::json!(n); + } + let resp: DetailResponse = api.post("/sessions", &body); + resp.session.public_id + } + }; + + eprintln!("{} {}", "session:".dark_grey(), sid); + eprintln!("{} {}", "workspace:".dark_grey(), workspace_id); + + let status = std::process::Command::new(&cmd[0]) + .args(&cmd[1..]) + .env("HOTDATA_SESSION", &sid) + .env("HOTDATA_WORKSPACE", workspace_id) + .status(); + + match status { + Ok(s) => std::process::exit(s.code().unwrap_or(1)), + Err(e) => { + eprintln!("error: failed to execute '{}': {e}", cmd[0]); + std::process::exit(1); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn find_session_run_ancestor_returns_none_in_test() { + // No `hotdata sessions run` ancestor exists in the test runner + assert!(find_session_run_ancestor_inner().is_none()); + } + + #[test] + fn find_session_run_ancestor_cached_matches_inner() { + // The cached version should agree with the inner function + assert_eq!(find_session_run_ancestor(), find_session_run_ancestor_inner()); + } +} + +pub fn set(session_id: Option<&str>, workspace_id: &str) { + check_session_lock(); + match session_id { + Some(id) => { + // Verify the session exists by fetching it + let api = ApiClient::new(Some(workspace_id)); + let path = format!("/sessions/{id}"); + let _: DetailResponse = api.get(&path); + + if let Err(e) = config::save_session("default", id) { + eprintln!("error saving config: {e}"); + std::process::exit(1); + } + println!("{}", "Active session updated".green()); + println!("id: {}", id); + } + None => { + // Clear the active session + if let Err(e) = config::clear_session("default") { + eprintln!("error saving config: {e}"); + std::process::exit(1); + } + println!("{}", "Active session cleared".green()); + } + } +} diff --git a/src/util.rs b/src/util.rs index 00816c6..cd84ef9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -142,5 +142,11 @@ pub fn api_error(body: String) -> String { serde_json::from_str::(&body) .ok() .and_then(|v| v["error"]["message"].as_str().map(str::to_string)) - .unwrap_or(body) + .unwrap_or_else(|| { + if body.trim_start().starts_with('<') { + "unexpected server error".to_string() + } else { + body + } + }) } diff --git a/src/workspace.rs b/src/workspace.rs index 719a7aa..5b53697 100644 --- a/src/workspace.rs +++ b/src/workspace.rs @@ -17,6 +17,10 @@ struct ListResponse { } pub fn set(workspace_id: Option<&str>) { + if std::env::var("HOTDATA_WORKSPACE").is_ok() || crate::sessions::find_session_run_ancestor().is_some() { + eprintln!("error: workspace is locked"); + std::process::exit(1); + } let api = ApiClient::new(None); let body: ListResponse = api.get("/workspaces"); let workspaces = body.workspaces; @@ -68,7 +72,8 @@ pub fn list(format: &str) { std::process::exit(1); } }; - let default_id = profile_config.workspaces.first().map(|w| w.public_id.as_str()).unwrap_or("").to_string(); + let default_id = std::env::var("HOTDATA_WORKSPACE") + .unwrap_or_else(|_| profile_config.workspaces.first().map(|w| w.public_id.clone()).unwrap_or_default()); let api = ApiClient::new(None); let body: ListResponse = api.get("/workspaces"); diff --git a/tests/session_env.rs b/tests/session_env.rs new file mode 100644 index 0000000..32533b8 --- /dev/null +++ b/tests/session_env.rs @@ -0,0 +1,86 @@ +use std::process::Command; + +fn hotdata() -> Command { + Command::new(env!("CARGO_BIN_EXE_hotdata")) +} + +// --- session lock tests --- + +#[test] +fn sessions_run_blocked_when_hotdata_session_set() { + let output = hotdata() + .args(["sessions", "run", "echo", "hi"]) + .env("HOTDATA_SESSION", "existing-session") + .env("HOTDATA_WORKSPACE", "ws-1") + .output() + .unwrap(); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!(stderr.contains("session is locked"), "stderr: {stderr}"); +} + +#[test] +fn sessions_new_blocked_when_hotdata_session_set() { + let output = hotdata() + .args(["sessions", "new"]) + .env("HOTDATA_SESSION", "existing-session") + .env("HOTDATA_WORKSPACE", "ws-1") + .output() + .unwrap(); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!(stderr.contains("session is locked"), "stderr: {stderr}"); +} + +#[test] +fn sessions_set_blocked_when_hotdata_session_set() { + let output = hotdata() + .args(["sessions", "set", "some-id"]) + .env("HOTDATA_SESSION", "existing-session") + .env("HOTDATA_WORKSPACE", "ws-1") + .output() + .unwrap(); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!(stderr.contains("session is locked"), "stderr: {stderr}"); +} + +// --- workspace env lock tests --- + +#[test] +fn workspace_env_blocks_conflicting_flag() { + let output = hotdata() + .args(["sessions", "-w", "other-ws", "list"]) + .env("HOTDATA_WORKSPACE", "locked-ws") + .env_remove("HOTDATA_SESSION") + .output() + .unwrap(); + + assert!(!output.status.success()); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("locked by HOTDATA_WORKSPACE"), + "stderr: {stderr}" + ); +} + +#[test] +fn workspace_env_allows_matching_flag() { + // When the flag matches the env var, no workspace conflict error. + // Will fail later on auth, but should NOT fail on workspace lock. + let output = hotdata() + .args(["sessions", "-w", "ws-1", "list"]) + .env("HOTDATA_WORKSPACE", "ws-1") + .env_remove("HOTDATA_SESSION") + .output() + .unwrap(); + + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + !stderr.contains("locked by HOTDATA_WORKSPACE"), + "unexpected workspace lock error: {stderr}" + ); +}