diff --git a/src/ast/mod.rs b/src/ast/mod.rs index c0826f2008..c853c6cad6 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -107,8 +107,8 @@ pub use self::query::{ TableIndexHintType, TableIndexHints, TableIndexType, TableSample, TableSampleBucket, TableSampleKind, TableSampleMethod, TableSampleModifier, TableSampleQuantity, TableSampleSeed, TableSampleSeedModifier, TableSampleUnit, TableVersion, TableWithJoins, Top, TopQuantity, - UpdateTableFromKind, ValueTableMode, Values, WildcardAdditionalOptions, With, WithFill, - XmlNamespaceDefinition, XmlPassingArgument, XmlPassingClause, XmlTableColumn, + UpdateTableFromKind, ValueTableMode, Values, WildcardAdditionalOptions, With, WithExpression, + WithFill, XmlNamespaceDefinition, XmlPassingArgument, XmlPassingClause, XmlTableColumn, XmlTableColumnOption, }; diff --git a/src/ast/query.rs b/src/ast/query.rs index 1de0e0e9db..5df43d4afd 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -754,8 +754,8 @@ pub struct With { pub with_token: AttachedToken, /// Whether the `WITH` is recursive (`WITH RECURSIVE`). pub recursive: bool, - /// The list of CTEs declared by this `WITH` clause. - pub cte_tables: Vec, + /// The expressions declared by this `WITH` clause. + pub exprs: Vec, } impl fmt::Display for With { @@ -764,11 +764,33 @@ impl fmt::Display for With { if self.recursive { f.write_str("RECURSIVE ")?; } - display_comma_separated(&self.cte_tables).fmt(f)?; + display_comma_separated(&self.exprs).fmt(f)?; Ok(()) } } +/// A single expression in a `WITH` clause. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum WithExpression { + /// A common table expression. + Cte(Cte), + /// A common scalar expression. + /// + /// [Clickhouse]: https://clickhouse.com/docs/sql-reference/statements/select/with#common-scalar-expressions + Cse(ExprWithAlias), +} + +impl fmt::Display for WithExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WithExpression::Cte(cte) => cte.fmt(f), + WithExpression::Cse(cse) => cse.fmt(f), + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] diff --git a/src/ast/spans.rs b/src/ast/spans.rs index f6ba895478..674c0589d0 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -47,7 +47,7 @@ use super::{ ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered, TableWithJoins, Update, UpdateTableFromKind, Use, Values, ViewColumnDef, - WhileStatement, WildcardAdditionalOptions, With, WithFill, + WhileStatement, WildcardAdditionalOptions, With, WithExpression, WithFill, }; /// Given an iterator of spans, return the [Span::union] of all spans. @@ -185,12 +185,19 @@ impl Spanned for With { let With { with_token, recursive: _, // bool - cte_tables, + exprs, } = self; - union_spans( - core::iter::once(with_token.0.span).chain(cte_tables.iter().map(|item| item.span())), - ) + union_spans(core::iter::once(with_token.0.span).chain(exprs.iter().map(|item| item.span()))) + } +} + +impl Spanned for WithExpression { + fn span(&self) -> Span { + match self { + WithExpression::Cte(cte) => cte.span(), + WithExpression::Cse(cse) => cse.span(), + } } } @@ -2716,8 +2723,12 @@ pub mod tests { ); let query = test.0.parse_query().unwrap(); - let cte_span = query.clone().with.unwrap().cte_tables[0].span(); - let cte_query_span = query.clone().with.unwrap().cte_tables[0].query.span(); + let cte = match &query.with.as_ref().unwrap().exprs[0] { + WithExpression::Cte(cte) => cte, + _ => panic!("expected a CTE"), + }; + let cte_span = cte.span(); + let cte_query_span = cte.query.span(); let body_span = query.body.span(); // the WITH keyboard is part of the query diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs index 6ee60cc993..d3b99b96ac 100644 --- a/src/dialect/clickhouse.rs +++ b/src/dialect/clickhouse.rs @@ -153,4 +153,9 @@ impl Dialect for ClickHouseDialect { fn supports_comma_separated_trim(&self) -> bool { true } + + /// See + fn supports_common_scalar_expressions(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 9b2ede40d2..04ca7359d0 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -1745,6 +1745,18 @@ pub trait Dialect: Debug + Any { false } + /// Returns true if the dialect supports Common Scalar Expressions in a `WITH` clause. + /// + /// For example: + /// ```sql + /// WITH 42 AS answer SELECT answer FROM t + /// ``` + /// + /// [ClickHouse](https://clickhouse.com/docs/sql-reference/statements/select/with#common-scalar-expressions) + fn supports_common_scalar_expressions(&self) -> bool { + false + } + /// Returns true if the dialect supports parenthesized multi-column /// aliases in SELECT items. For example: /// ```sql diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 07497b04f6..3b4e7487e0 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -14105,7 +14105,7 @@ impl<'a> Parser<'a> { Some(With { with_token: with_token.clone().into(), recursive: self.parse_keyword(Keyword::RECURSIVE), - cte_tables: self.parse_comma_separated(Parser::parse_cte)?, + exprs: self.parse_comma_separated(Parser::parse_with_expression)?, }) } else { None @@ -14639,6 +14639,23 @@ impl<'a> Parser<'a> { Ok(cte) } + /// Parse a single expression in a `WITH` clause. + pub fn parse_with_expression(&mut self) -> Result { + if !self.dialect.supports_common_scalar_expressions() { + return self.parse_cte().map(WithExpression::Cte); + } + if let Some(cte) = self.maybe_parse(|p| p.parse_cte())? { + return Ok(WithExpression::Cte(cte)); + } + let expr = self.parse_expr()?; + self.expect_keyword(Keyword::AS)?; + let alias = self.parse_identifier()?; + Ok(WithExpression::Cse(ExprWithAlias { + expr, + alias: Some(alias), + })) + } + /// Parse a "query body", which is an expression with roughly the /// following grammar: /// ```sql diff --git a/tests/sqlparser_clickhouse.rs b/tests/sqlparser_clickhouse.rs index 716a3919fc..6318404adc 100644 --- a/tests/sqlparser_clickhouse.rs +++ b/tests/sqlparser_clickhouse.rs @@ -1845,6 +1845,56 @@ fn parse_inner_array_join() { } } +#[test] +fn parse_with_clause_common_scalar_expression() { + let dialects = all_dialects_where(|d| d.supports_common_scalar_expressions()); + + // Plain literal scalar. + let query = dialects.verified_query("WITH 42 AS answer SELECT answer FROM t"); + let with = query.with.as_ref().unwrap(); + assert!(!with.recursive); + assert_eq!(with.exprs.len(), 1); + match &with.exprs[0] { + WithExpression::Cse(ExprWithAlias { expr, alias }) => { + assert_eq!(alias.as_ref().unwrap().value, "answer"); + assert!(matches!(expr, Expr::Value(_))); + } + other => panic!("expected a common scalar expression, got {other:?}"), + } + + // String literal scalar (from the ClickHouse docs). + dialects.verified_stmt( + "WITH '2019-08-01 15:23:00' AS ts_upper_bound SELECT * FROM hits \ + WHERE EventDate = toDate(ts_upper_bound) AND EventTime <= ts_upper_bound", + ); + + // Aggregate function call as a common scalar expression. + dialects.verified_stmt( + "WITH sum(bytes) AS s SELECT formatReadableSize(s), \"table\" \ + FROM system.parts GROUP BY \"table\" ORDER BY s", + ); + + // Scalar subquery as the bound expression. + dialects.verified_stmt( + "WITH (SELECT sum(bytes) FROM system.parts WHERE active) AS total_disk_usage \ + SELECT (sum(bytes) / total_disk_usage) * 100 AS table_disk_usage, \"table\" \ + FROM system.parts GROUP BY \"table\" ORDER BY table_disk_usage DESC LIMIT 10", + ); + + // Bare-identifier scalar — disambiguation case (`name AS alias` looks like + // a CTE prefix but the missing `(` after `AS` makes it a scalar). + dialects.verified_stmt("WITH user_id AS uid SELECT uid FROM t"); + + // Mixing a common scalar expression with a real CTE in the same WITH list. + dialects.verified_stmt("WITH 1 AS one, cte AS (SELECT 1) SELECT one FROM cte"); + + // Lambda as the bound expression (also taken from the docs). + dialects.verified_stmt( + "WITH '.txt' AS extension, (id, extension) -> concat(lower(id), extension) AS gen_name \ + SELECT gen_name('test', '.sql') AS file_name", + ); +} + fn clickhouse() -> TestedDialects { TestedDialects::new(vec![Box::new(ClickHouseDialect {})]) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 326fbf678e..9822b34d80 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -7854,18 +7854,21 @@ fn parse_ctes() { fn assert_ctes_in_select(expected: &[&str], sel: &Query) { for (i, exp) in expected.iter().enumerate() { - let Cte { alias, query, .. } = &sel.with.as_ref().unwrap().cte_tables[i]; - assert_eq!(*exp, query.to_string()); - assert_eq!(false, alias.explicit); + let cte = match &sel.with.as_ref().unwrap().exprs[i] { + WithExpression::Cte(cte) => cte, + other => panic!("expected a CTE, got {other:?}"), + }; + assert_eq!(*exp, cte.query.to_string()); + assert_eq!(false, cte.alias.explicit); assert_eq!( if i == 0 { Ident::new("a") } else { Ident::new("b") }, - alias.name + cte.alias.name ); - assert!(alias.columns.is_empty()); + assert!(cte.alias.columns.is_empty()); } } @@ -7898,26 +7901,29 @@ fn parse_ctes() { // CTE in a CTE... let sql = &format!("WITH outer_cte AS ({with}) SELECT * FROM outer_cte"); let select = verified_query(sql); - assert_ctes_in_select(&cte_sqls, &only(&select.with.unwrap().cte_tables).query); + let with = select.with.as_ref().unwrap(); + let outer_cte = match only(&with.exprs) { + WithExpression::Cte(cte) => cte, + other => panic!("expected a CTE, got {other:?}"), + }; + assert_ctes_in_select(&cte_sqls, &outer_cte.query); } #[test] fn parse_cte_renamed_columns() { let sql = "WITH cte (col1, col2) AS (SELECT foo, bar FROM baz) SELECT * FROM cte"; let query = all_dialects().verified_query(sql); + let with = query.with.unwrap(); + let cte = match with.exprs.first().unwrap() { + WithExpression::Cte(cte) => cte, + other => panic!("expected a CTE, got {other:?}"), + }; assert_eq!( vec![ TableAliasColumnDef::from_name("col1"), TableAliasColumnDef::from_name("col2") ], - query - .with - .unwrap() - .cte_tables - .first() - .unwrap() - .alias - .columns + cte.alias.columns ); } @@ -7931,8 +7937,8 @@ fn parse_recursive_cte() { let with = query.with.as_ref().unwrap(); assert!(with.recursive); - assert_eq!(with.cte_tables.len(), 1); - let expected = Cte { + assert_eq!(with.exprs.len(), 1); + let expected = WithExpression::Cte(Cte { alias: TableAlias { explicit: false, name: Ident { @@ -7947,8 +7953,8 @@ fn parse_recursive_cte() { from: None, materialized: None, closing_paren_token: AttachedToken::empty(), - }; - assert_eq!(with.cte_tables.first().unwrap(), &expected); + }); + assert_eq!(with.exprs.first().unwrap(), &expected); } #[test]