diff --git a/asap-query-engine/src/engines/simple_engine/sql.rs b/asap-query-engine/src/engines/simple_engine/sql.rs index d9fe864..bf2eb49 100644 --- a/asap-query-engine/src/engines/simple_engine/sql.rs +++ b/asap-query-engine/src/engines/simple_engine/sql.rs @@ -137,6 +137,50 @@ fn sort_and_truncate_instant_vector( results } +/// Detect a SQL top-k query and return its `k`. +/// +/// Recognises the heavy-hitter shape that `CountMinSketchWithHeap` serves: +/// +/// ```sql +/// SELECT , COUNT() AS +/// FROM WHERE <1s window> +/// GROUP BY +/// ORDER BY DESC +/// LIMIT k +/// ``` +/// +/// The grouping key (``) becomes the *aggregated* dimension inside the +/// sketch's heap — not a precompute partition key — so a single sketch per +/// window tracks the top keys by event count. +/// +/// The SQL parser only accepts identifier ORDER BY targets, so the descending +/// order must reference the aggregate's alias (e.g. `transfer_events`), not the +/// `COUNT(col)` expression itself. +pub(crate) fn detect_sql_topk(query_data: &SQLQueryData) -> Option { + let k = query_data.limit?; + // Need a GROUP BY key to rank and an ORDER BY to define "top". + if query_data.labels.is_empty() || query_data.order_by.is_empty() { + return None; + } + // CountMinSketchWithHeap tracks heavy hitters by COUNT. + if !query_data + .aggregation_info + .get_name() + .eq_ignore_ascii_case("count") + { + return None; + } + // Primary ordering must be the aggregate alias, descending (largest first). + let primary = &query_data.order_by[0]; + if primary.ascending { + return None; + } + if query_data.aggregation_alias.as_deref() != Some(primary.column.as_str()) { + return None; + } + Some(k) +} + impl SimpleEngine { /// Finds the query configuration for a SQL query using structural pattern matching. /// @@ -268,6 +312,7 @@ impl SimpleEngine { &self, match_result: &SQLQuery, query_pattern_type: QueryPatternType, + topk_k: Option, ) -> QueryRequirements { let query_data = match_result .outer_data() @@ -284,9 +329,18 @@ impl SimpleEngine { _ => query_data.aggregation_info.get_name().to_lowercase(), }; - let statistics: Vec = Self::parse_single_statistic(&statistic_name) - .into_iter() - .collect(); + // For top-k the requirement is `Statistic::Topk` (→ CountMinSketchWithHeap) + // and the grouping is empty: the GROUP BY column is the sketch's + // *aggregated* (heavy-hitter) dimension, held inside one sketch per + // window, not a precompute partition key. + let is_topk = topk_k.is_some(); + let statistics: Vec = if is_topk { + vec![Statistic::Topk] + } else { + Self::parse_single_statistic(&statistic_name) + .into_iter() + .collect() + }; let data_range_ms = match query_pattern_type { QueryPatternType::OnlySpatial => None, @@ -305,7 +359,11 @@ impl SimpleEngine { } }; - let grouping_labels = KeyByLabelNames::new(query_data.labels.clone().into_iter().collect()); + let grouping_labels = if is_topk { + KeyByLabelNames::empty() + } else { + KeyByLabelNames::new(query_data.labels.clone().into_iter().collect()) + }; QueryRequirements { metric, @@ -323,8 +381,17 @@ impl SimpleEngine { ) -> Option<(KeyByLabelNames, QueryResult)> { let (context, post) = self.build_query_execution_context_sql_with_post_processing(query, time)?; - let (output_labels, result) = self.execute_context(context, false, false)?; - let result = post.apply(&output_labels, result); + let is_topk = context.metadata.statistic_to_compute == Statistic::Topk; + // Top-k: enable heap-based limiting (truncate to k) but NOT PromQL-style + // metric-name formatting; the sketch heap already produces the ranked + // `(group-by key, count)` rows, so SQL ORDER BY / LIMIT post-processing + // would be redundant and is skipped. + let (output_labels, result) = self.execute_context(context, is_topk, false)?; + let result = if is_topk { + result + } else { + post.apply(&output_labels, result) + }; Some((output_labels, result)) } @@ -502,15 +569,26 @@ impl SimpleEngine { } }; - let statistic_to_compute = Self::parse_single_statistic(&statistic_name)?; + // Top-k detection takes precedence: `... ORDER BY DESC LIMIT k` + // is served by CountMinSketchWithHeap (Statistic::Topk) rather than the + // plain COUNT path, so the sketch heap drives the result set. + let topk_k = detect_sql_topk(&query_data); + let statistic_to_compute = if topk_k.is_some() { + Statistic::Topk + } else { + Self::parse_single_statistic(&statistic_name)? + }; - let query_kwargs = self + let mut query_kwargs = self .build_query_kwargs_sql(&statistic_to_compute, &match_result) .map_err(|e| { warn!("{}", e); e }) .ok()?; + if let Some(k) = topk_k { + query_kwargs.insert("k".to_string(), k.to_string()); + } // Create query metadata let metadata = QueryMetadata { @@ -524,24 +602,24 @@ impl SimpleEngine { self.calculate_query_timestamps_sql(query_time, query_pattern_type, &match_result); // Resolve aggregation: try pre-configured query_configs first, fall back to capability matching. - let agg_info: AggregationIdInfo = if let Some(config) = - self.find_query_config_sql(&query_data) - { - self.get_aggregation_id_info(&config) - .map_err(|e| { - warn!("{}", e); - e - }) - .ok()? - } else { - warn!("No query_config entry for SQL query. Attempting capability-based matching."); - let requirements = self.build_query_requirements_sql(&match_result, query_pattern_type); - self.streaming_config - .read() - .unwrap() - .clone() - .find_compatible_aggregation(&requirements)? - }; + let agg_info: AggregationIdInfo = + if let Some(config) = self.find_query_config_sql(&query_data) { + self.get_aggregation_id_info(&config) + .map_err(|e| { + warn!("{}", e); + e + }) + .ok()? + } else { + warn!("No query_config entry for SQL query. Attempting capability-based matching."); + let requirements = + self.build_query_requirements_sql(&match_result, query_pattern_type, topk_k); + self.streaming_config + .read() + .unwrap() + .clone() + .find_compatible_aggregation(&requirements)? + }; let metric = &match_result.outer_data()?.metric; @@ -687,8 +765,11 @@ impl SimpleEngine { warn!( "No query_config entry for SQL spatio-temporal query. Attempting capability-based matching." ); - let requirements = - self.build_query_requirements_sql(match_result, QueryPatternType::OnlyTemporal); + let requirements = self.build_query_requirements_sql( + match_result, + QueryPatternType::OnlyTemporal, + None, + ); self.streaming_config .read() .unwrap() @@ -709,6 +790,111 @@ impl SimpleEngine { } } +#[cfg(test)] +mod detect_topk_tests { + use super::detect_sql_topk; + use sql_utilities::ast_matching::SQLPatternParser; + use sql_utilities::sqlhelper::{SQLSchema, Table}; + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + use std::collections::HashSet; + + /// Parse a SQL string into `SQLQueryData` against a netflow-shaped schema. + /// Returns `None` if the parser rejects the query (e.g. unsupported ORDER BY). + fn parse(sql: &str) -> Option { + let value_cols: HashSet = ["pkt_len"].iter().map(|s| s.to_string()).collect(); + let labels: HashSet = ["srcip", "dstip", "proto"] + .iter() + .map(|s| s.to_string()) + .collect(); + let table = Table::new( + "netflow_table".to_string(), + "time".to_string(), + value_cols, + labels, + ); + let schema = SQLSchema::new(vec![table]); + let statements = Parser::parse_sql(&GenericDialect {}, sql).ok()?; + SQLPatternParser::new(&schema, 0.0).parse_query(&statements) + } + + const WINDOW: &str = + "WHERE time BETWEEN DATEADD(s, -1, '2025-10-01 00:00:10') AND '2025-10-01 00:00:10'"; + + #[test] + fn count_order_by_alias_desc_limit_is_topk() { + let sql = format!( + "SELECT srcip, COUNT(pkt_len) AS transfer_events FROM netflow_table {WINDOW} \ + GROUP BY srcip ORDER BY transfer_events DESC LIMIT 10" + ); + let qd = parse(&sql).expect("valid topk query should parse"); + assert_eq!(detect_sql_topk(&qd), Some(10)); + } + + #[test] + fn missing_limit_is_not_topk() { + let sql = format!( + "SELECT srcip, COUNT(pkt_len) AS transfer_events FROM netflow_table {WINDOW} \ + GROUP BY srcip ORDER BY transfer_events DESC" + ); + let qd = parse(&sql).expect("query should parse"); + assert_eq!(detect_sql_topk(&qd), None, "no LIMIT ⇒ not top-k"); + } + + #[test] + fn ascending_order_is_not_topk() { + let sql = format!( + "SELECT srcip, COUNT(pkt_len) AS transfer_events FROM netflow_table {WINDOW} \ + GROUP BY srcip ORDER BY transfer_events ASC LIMIT 10" + ); + let qd = parse(&sql).expect("query should parse"); + assert_eq!( + detect_sql_topk(&qd), + None, + "ASC ordering is bottom-k, not top-k" + ); + } + + #[test] + fn no_order_by_is_not_topk() { + let sql = format!( + "SELECT srcip, COUNT(pkt_len) AS transfer_events FROM netflow_table {WINDOW} \ + GROUP BY srcip LIMIT 10" + ); + let qd = parse(&sql).expect("query should parse"); + assert_eq!( + detect_sql_topk(&qd), + None, + "LIMIT without ORDER BY is not top-k" + ); + } + + #[test] + fn sum_aggregate_is_not_topk() { + let sql = format!( + "SELECT srcip, SUM(pkt_len) AS total FROM netflow_table {WINDOW} \ + GROUP BY srcip ORDER BY total DESC LIMIT 10" + ); + let qd = parse(&sql).expect("query should parse"); + assert_eq!( + detect_sql_topk(&qd), + None, + "only COUNT maps to CMS-with-heap top-k" + ); + } + + #[test] + fn order_by_group_key_is_not_topk() { + // Ordering by the group-by key (not the count) is a plain sorted listing. + let sql = format!( + "SELECT srcip, COUNT(pkt_len) AS transfer_events FROM netflow_table {WINDOW} \ + GROUP BY srcip ORDER BY srcip DESC LIMIT 10" + ); + let qd = parse(&sql).expect("query should parse"); + assert_eq!(detect_sql_topk(&qd), None); + } +} + #[cfg(test)] mod sort_and_truncate_tests { use super::sort_and_truncate_instant_vector; @@ -901,3 +1087,212 @@ mod sort_and_truncate_tests { assert_eq!(values, vec![5.0, 3.0]); } } + +/// End-to-end tests for SQL top-k queries served by `CountMinSketchWithHeap`. +/// +/// Exercises the full path for `SELECT srcip, COUNT(pkt_len) AS k FROM +/// netflow_table WHERE <1s window> GROUP BY srcip ORDER BY k DESC LIMIT n`: +/// * SQL detection promotes it to `Statistic::Topk`. +/// * The single `CountMinSketchWithHeap` aggregation resolves self-keyed +/// (key id == value id), so the sketch heap enumerates candidate `srcip`s. +/// * The pipeline sorts by count descending and truncates to `n`, without +/// PromQL-style metric-name prefixing (rows stay bare `(srcip, count)`). +/// +/// Lives here alongside `detect_topk_tests` / `sort_and_truncate_tests` so all +/// SQL top-k coverage is co-located in the SQL handler. Unlike those pure-fn +/// modules this one builds a real `SimpleEngine` + store and runs the pipeline, +/// since the top-k execution path skips `SqlPostProcessing::apply` (its ordering +/// happens in `format_final_results` and truncation in `execute_query_pipeline`). +#[cfg(test)] +mod topk_pipeline_tests { + use super::SimpleEngine; + use crate::data_model::{ + AggregationConfig, AggregationReference, AggregationType, CleanupPolicy, InferenceConfig, + PrecomputedOutput, QueryConfig, QueryLanguage, SchemaConfig, StreamingConfig, WindowType, + }; + use crate::precompute_operators::CountMinSketchWithHeapAccumulator; + use crate::stores::simple_map_store::SimpleMapStore; + use crate::stores::Store; + use promql_utilities::data_model::KeyByLabelNames; + use promql_utilities::query_logics::enums::Statistic; + use sql_utilities::sqlhelper::{SQLSchema, Table}; + use std::collections::{HashMap, HashSet}; + use std::sync::Arc; + + const AGG_ID: u64 = 101; + const METRIC: &str = "netflow_table"; + // '2025-10-01 00:00:10' (UTC) in seconds. + const QUERY_TIME: f64 = 1_759_276_810.0; + + /// Build a SQL engine whose only aggregation is a self-keyed + /// `CountMinSketchWithHeap` over `netflow_table`, grouped globally (no + /// partition labels) and aggregating the `srcip` heavy-hitter dimension. + /// Returns the engine plus a handle to the shared store for inserting + /// precomputed sketches. + fn build_topk_engine() -> (SimpleEngine, Arc) { + // Template stored in the inference config. Matches incoming top-k queries + // structurally (ORDER BY / LIMIT / aliases are ignored by SQL pattern + // matching), and references a single `CountMinSketchWithHeap` aggregation + // so the engine resolves it self-keyed. + let template = "SELECT srcip, COUNT(pkt_len) FROM netflow_table \ + WHERE time BETWEEN DATEADD(s, -1, NOW()) AND NOW() GROUP BY srcip"; + + let value_cols: HashSet = ["pkt_len"].iter().map(|s| s.to_string()).collect(); + let labels: HashSet = ["srcip", "dstip", "proto"] + .iter() + .map(|s| s.to_string()) + .collect(); + let table = Table::new(METRIC.to_string(), "time".to_string(), value_cols, labels); + let sql_schema = SQLSchema::new(vec![table]); + + let query_config = QueryConfig::new(template.to_string()) + .add_aggregation(AggregationReference::new(AGG_ID, None)); + + let inference_config = InferenceConfig { + schema: SchemaConfig::SQL(sql_schema), + query_configs: vec![query_config], + cleanup_policy: CleanupPolicy::NoCleanup, + }; + + let agg_config = AggregationConfig { + aggregation_id: AGG_ID, + aggregation_type: AggregationType::CountMinSketchWithHeap, + aggregation_sub_type: String::new(), + parameters: HashMap::new(), + // Empty grouping: one global sketch. The GROUP BY column (`srcip`) + // is the sketch's *aggregated* heavy-hitter dimension, not a + // precompute partition key. + grouping_labels: KeyByLabelNames::empty(), + aggregated_labels: KeyByLabelNames::new(vec!["srcip".to_string()]), + rollup_labels: KeyByLabelNames::empty(), + original_yaml: String::new(), + window_size: 1, + slide_interval: 1, + window_type: WindowType::Tumbling, + spatial_filter: String::new(), + spatial_filter_normalized: String::new(), + metric: METRIC.to_string(), + num_aggregates_to_retain: None, + read_count_threshold: None, + table_name: None, + value_column: None, + }; + + let mut agg_configs = HashMap::new(); + agg_configs.insert(AGG_ID, agg_config); + let streaming_config = Arc::new(StreamingConfig { + aggregation_configs: agg_configs, + }); + + let store = Arc::new(SimpleMapStore::new( + streaming_config.clone(), + CleanupPolicy::NoCleanup, + )); + + let engine = SimpleEngine::new( + store.clone(), + inference_config, + streaming_config, + 1, // 1s scrape interval ⇒ the 1s window classifies as OnlySpatial + QueryLanguage::sql, + ); + (engine, store) + } + + /// Incoming top-k query over a 1-second absolute window. + fn topk_query(limit: u64) -> String { + format!( + "SELECT srcip, COUNT(pkt_len) AS transfer_events FROM netflow_table \ + WHERE time BETWEEN DATEADD(s, -1, '2025-10-01 00:00:10') AND '2025-10-01 00:00:10' \ + GROUP BY srcip ORDER BY transfer_events DESC LIMIT {limit}" + ) + } + + #[test] + fn detects_topk_and_resolves_self_keyed_heap() { + let (engine, _store) = build_topk_engine(); + let context = engine + .build_query_execution_context_sql(topk_query(10), QUERY_TIME) + .expect("top-k query should build a context via the query_config path"); + + assert_eq!( + context.metadata.statistic_to_compute, + Statistic::Topk, + "ORDER BY DESC LIMIT n must be promoted to Topk", + ); + assert_eq!( + context.metadata.query_kwargs.get("k").map(String::as_str), + Some("10"), + "LIMIT should be threaded through as the `k` kwarg", + ); + // Self-keyed: the heap supplies both keys and counts, so no separate + // key aggregation / keys query is planned. + assert_eq!( + context.agg_info.aggregation_id_for_key, + context.agg_info.aggregation_id_for_value, + ); + assert!(context.store_plan.keys_query.is_none()); + } + + #[test] + fn returns_top_k_srcips_sorted_descending() { + let (engine, store) = build_topk_engine(); + + // Build the context first so we can insert the sketch into exactly the + // window the store plan will query. + let context = engine + .build_query_execution_context_sql(topk_query(10), QUERY_TIME) + .expect("context should build"); + let window = &context.store_plan.values_query; + + // 15 distinct srcips with strictly increasing counts 10, 20, ... 150. + // A width-1024 / depth-3 sketch makes collisions among 15 keys + // effectively impossible, so estimates equal the inserted counts. + let mut sketch = CountMinSketchWithHeapAccumulator::new(3, 1024, 32); + for i in 1..=15u64 { + let srcip = format!("10.0.0.{i}"); + sketch.inner.update(&srcip, (i * 10) as f64); + } + + let output = + PrecomputedOutput::new(window.start_timestamp, window.end_timestamp, None, AGG_ID); + store + .insert_precomputed_output(output, Box::new(sketch)) + .expect("insert should succeed"); + + // enable_topk_limiting=true (truncate to k via heap), formatting=false + // (SQL rows stay bare, no __name__ prefix). + let results = engine + .execute_query_pipeline(&context, true, false) + .expect("pipeline should produce results"); + + assert_eq!(results.len(), 10, "LIMIT 10 must truncate to 10 rows"); + + // Sorted by count descending. + for pair in results.windows(2) { + assert!( + pair[0].value >= pair[1].value, + "results must be sorted by count descending: {} then {}", + pair[0].value, + pair[1].value, + ); + } + + // Highest count first; bare single-label rows (no metric-name prefix). + assert_eq!(results[0].labels.labels, vec!["10.0.0.15".to_string()]); + assert_eq!(results[0].value, 150.0); + for element in &results { + assert_eq!( + element.labels.labels.len(), + 1, + "SQL top-k rows carry only the GROUP BY column, never a metric prefix", + ); + } + + // The returned set is exactly the 10 largest srcips (6..=15). + let returned: HashSet = + results.iter().map(|e| e.labels.labels[0].clone()).collect(); + let expected: HashSet = (6..=15u64).map(|i| format!("10.0.0.{i}")).collect(); + assert_eq!(returned, expected); + } +} diff --git a/asap-query-engine/src/precompute_engine/accumulator_factory.rs b/asap-query-engine/src/precompute_engine/accumulator_factory.rs index e6005e9..68a0b89 100644 --- a/asap-query-engine/src/precompute_engine/accumulator_factory.rs +++ b/asap-query-engine/src/precompute_engine/accumulator_factory.rs @@ -1,8 +1,9 @@ use crate::data_model::{AggregateCore, AggregationType, KeyByLabelValues, Measurement}; use crate::precompute_operators::{ - CountMinSketchAccumulator, DatasketchesKLLAccumulator, HllAccumulator, - HydraKllSketchAccumulator, IncreaseAccumulator, MinMaxAccumulator, MultipleIncreaseAccumulator, - MultipleMinMaxAccumulator, MultipleSumAccumulator, SumAccumulator, DEFAULT_HLL_PRECISION, + CountMinSketchAccumulator, CountMinSketchWithHeapAccumulator, DatasketchesKLLAccumulator, + HllAccumulator, HydraKllSketchAccumulator, IncreaseAccumulator, MinMaxAccumulator, + MultipleIncreaseAccumulator, MultipleMinMaxAccumulator, MultipleSumAccumulator, SumAccumulator, + DEFAULT_HLL_PRECISION, }; use asap_types::aggregation_config::AggregationConfig; @@ -563,6 +564,67 @@ impl AccumulatorUpdater for CmsAccumulatorUpdater { } } +// --------------------------------------------------------------------------- +// CmsWithHeapAccumulatorUpdater (CountMinSketchWithHeap — top-k) +// --------------------------------------------------------------------------- + +/// Keyed updater backing `Statistic::Topk`. Wraps a `CountMinSketchWithHeap` +/// (CMS + heavy-hitter heap) so the query engine can enumerate the top-k keys +/// from the heap and read each key's frequency estimate from the sketch. +/// +/// `count_events` selects the per-sample weight fed into the sketch: +/// * `true` → weight 1 per observation (COUNT semantics, e.g. `COUNT(pkt_len)`), +/// * `false` → the sample value itself (SUM-of-value semantics). +pub struct CmsWithHeapAccumulatorUpdater { + acc: CountMinSketchWithHeapAccumulator, + row_num: usize, + col_num: usize, + heap_size: usize, + count_events: bool, +} + +impl CmsWithHeapAccumulatorUpdater { + pub fn new(row_num: usize, col_num: usize, heap_size: usize, count_events: bool) -> Self { + Self { + acc: CountMinSketchWithHeapAccumulator::new(row_num, col_num, heap_size), + row_num, + col_num, + heap_size, + count_events, + } + } +} + +impl AccumulatorUpdater for CmsWithHeapAccumulatorUpdater { + fn update_single(&mut self, _value: f64, _timestamp_ms: i64) { + debug_assert!( + false, + "update_single called on keyed updater; use update_keyed" + ); + } + + fn update_keyed(&mut self, key: &KeyByLabelValues, value: f64, _timestamp_ms: i64) { + let weight = if self.count_events { 1.0 } else { value }; + self.acc.inner.update(&key.to_semicolon_str(), weight); + } + + impl_accumulator_methods!(acc); + + fn reset(&mut self) { + self.acc = + CountMinSketchWithHeapAccumulator::new(self.row_num, self.col_num, self.heap_size); + } + + fn is_keyed(&self) -> bool { + true + } + + fn memory_usage_bytes(&self) -> usize { + std::mem::size_of::() + + self.row_num * self.col_num * std::mem::size_of::() + } +} + // --------------------------------------------------------------------------- // HydraKllAccumulatorUpdater // --------------------------------------------------------------------------- @@ -669,6 +731,35 @@ fn hydra_kll_params(config: &AggregationConfig) -> (usize, usize, u16) { (row_num, col_num, kll_k_param(config)) } +/// Extract `(row_num, col_num, heap_size)` for CountMinSketchWithHeap configs. +/// +/// Accepts the planner/Arroyo-canonical `depth`/`width`/`heapsize` names first, +/// then falls back to the `row_num`/`col_num`/`heap_size` aliases. Defaults +/// mirror the planner sketch defaults (depth 3, width 1024) with a heap of 32. +fn cms_heap_params(config: &AggregationConfig) -> (usize, usize, usize) { + let read = |names: &[&str], default: u64| -> usize { + names + .iter() + .find_map(|n| config.parameters.get(*n).and_then(|v| v.as_u64())) + .unwrap_or(default) as usize + }; + let row_num = read(&["depth", "row_num"], 3); + let col_num = read(&["width", "col_num"], 1024); + let heap_size = read(&["heapsize", "heap_size"], 32); + (row_num, col_num, heap_size) +} + +/// Whether a CountMinSketchWithHeap config should count events (weight 1 per +/// observation, COUNT semantics) rather than summing the sample value. +/// Defaults to `true` so `COUNT(...)` top-k works out of the box. +fn cms_count_events(config: &AggregationConfig) -> bool { + config + .parameters + .get("count_events") + .and_then(|v| v.as_bool()) + .unwrap_or(true) +} + /// Extract the HLL `precision` parameter from a config. Falls back to /// `DEFAULT_HLL_PRECISION` (14) when absent or non-numeric. The valid range is /// 4..=18 per the underlying `HllSketch` storage; out-of-range values are @@ -750,10 +841,19 @@ pub fn create_accumulator_updater(config: &AggregationConfig) -> Box Box::new(IncreaseAccumulatorUpdater::new()), - AggregationType::CountMinSketch | AggregationType::CountMinSketchWithHeap => { + AggregationType::CountMinSketch => { let (row_num, col_num) = cms_params(config); Box::new(CmsAccumulatorUpdater::new(row_num, col_num)) } + AggregationType::CountMinSketchWithHeap => { + let (row_num, col_num, heap_size) = cms_heap_params(config); + Box::new(CmsWithHeapAccumulatorUpdater::new( + row_num, + col_num, + heap_size, + cms_count_events(config), + )) + } AggregationType::HydraKLL => { let (row_num, col_num, k) = hydra_kll_params(config); Box::new(HydraKllAccumulatorUpdater::new(row_num, col_num, k)) @@ -1120,4 +1220,123 @@ mod tests { .expect("should be KLL"); assert_eq!(kll.inner.k, 50, "k should be 50 from capital-K param"); } + + fn cms_heap_config( + parameters: std::collections::HashMap, + ) -> AggregationConfig { + AggregationConfig::new( + 101, + AggregationType::CountMinSketchWithHeap, + "topk".to_string(), + parameters, + promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]), + promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![ + "srcip".to_string() + ]), + promql_utilities::data_model::key_by_label_names::KeyByLabelNames::new(vec![]), + String::new(), + 1, + 0, + WindowType::Tumbling, + "netflow_table".to_string(), + "netflow_table".to_string(), + None, + None, + None, + None, + ) + } + + #[test] + fn test_cms_with_heap_factory_routes_to_heap_accumulator_and_is_keyed() { + // CountMinSketchWithHeap must build a CmsWithHeapAccumulatorUpdater whose + // accumulator exposes the heap (get_keys), NOT a plain CMS (no heap). + let config = cms_heap_config(std::collections::HashMap::new()); + let updater = create_accumulator_updater(&config); + assert!(updater.is_keyed(), "CMS-with-heap top-k is keyed by srcip"); + + let acc = updater.snapshot_accumulator(); + assert_eq!(acc.type_name(), "CountMinSketchWithHeapAccumulator"); + assert_eq!( + acc.get_accumulator_type(), + AggregationType::CountMinSketchWithHeap + ); + assert!( + acc.get_keys().is_some(), + "heap accumulator must enumerate top-k candidate keys" + ); + } + + #[test] + fn test_cms_with_heap_count_events_uses_unit_weight() { + // count_events (the default) → each observation contributes weight 1, so + // the per-key estimate is the EVENT COUNT, not the sum of sample values. + let config = cms_heap_config(std::collections::HashMap::new()); + let mut updater = create_accumulator_updater(&config); + + let key = KeyByLabelValues::new_with_labels(vec!["10.0.0.1".to_string()]); + // Feed 5 events with large values; count semantics must yield ~5, not ~Σvalue. + for _ in 0..5 { + updater.update_keyed(&key, 1000.0, 0); + } + let acc = updater.take_accumulator(); + let cms = acc + .as_any() + .downcast_ref::() + .expect("CountMinSketchWithHeap accumulator"); + assert_eq!( + cms.query_key(&key), + 5.0, + "count_events should count events (5), not sum values (5000)" + ); + } + + #[test] + fn test_cms_with_heap_count_events_false_sums_values() { + // count_events=false → weight is the sample value, giving SUM semantics. + let mut params = std::collections::HashMap::new(); + params.insert("count_events".to_string(), serde_json::json!(false)); + let config = cms_heap_config(params); + let mut updater = create_accumulator_updater(&config); + + let key = KeyByLabelValues::new_with_labels(vec!["10.0.0.1".to_string()]); + for _ in 0..5 { + updater.update_keyed(&key, 10.0, 0); + } + let acc = updater.take_accumulator(); + let cms = acc + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(cms.query_key(&key), 50.0, "sum of 5×10 == 50"); + } + + #[test] + fn test_cms_heap_params_reads_depth_width_heapsize() { + let mut params = std::collections::HashMap::new(); + params.insert("depth".to_string(), serde_json::json!(4)); + params.insert("width".to_string(), serde_json::json!(2048)); + params.insert("heapsize".to_string(), serde_json::json!(40)); + let config = cms_heap_config(params); + assert_eq!(cms_heap_params(&config), (4, 2048, 40)); + assert!(cms_count_events(&config), "count_events defaults to true"); + } + + #[test] + fn test_cms_with_heap_reset_clears_state() { + let config = cms_heap_config(std::collections::HashMap::new()); + let mut updater = create_accumulator_updater(&config); + let key = KeyByLabelValues::new_with_labels(vec!["k".to_string()]); + for _ in 0..10 { + updater.update_keyed(&key, 1.0, 0); + } + updater.reset(); + let acc = updater.take_accumulator(); + let cms = acc + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(cms.query_key(&key), 0.0, "reset must clear the sketch"); + assert!(cms.get_topk_keys().is_empty(), "reset must clear the heap"); + } }