@@ -35,10 +35,21 @@ pub fn analyze_common_subexpression(
3535 }
3636
3737 let signature_to_exprs = collect_table_signatures ( s_expr, metadata) ;
38+ let mut expr_groups = signature_to_exprs. into_values ( ) . collect :: < Vec < _ > > ( ) ;
39+ // Keep CSE materialization order deterministic by following the first
40+ // occurrence of each candidate group in the plan tree.
41+ expr_groups. sort_by ( |lhs, rhs| lhs[ 0 ] . 0 . cmp ( & rhs[ 0 ] . 0 ) ) ;
3842 let mut replacements = vec ! [ ] ;
3943 let mut materialized_ctes = vec ! [ ] ;
40- for exprs in signature_to_exprs. values ( ) {
41- process_candidate_expressions ( exprs, metadata, & mut replacements, & mut materialized_ctes) ?;
44+ let mut selected_paths = vec ! [ ] ;
45+ for exprs in & expr_groups {
46+ process_candidate_expressions (
47+ exprs,
48+ metadata,
49+ & mut replacements,
50+ & mut materialized_ctes,
51+ & mut selected_paths,
52+ ) ?;
4253 }
4354 Ok ( ( replacements, materialized_ctes) )
4455}
@@ -48,17 +59,22 @@ fn process_candidate_expressions(
4859 metadata : & mut Metadata ,
4960 replacements : & mut Vec < SExprReplacement > ,
5061 materialized_ctes : & mut Vec < SExpr > ,
62+ selected_paths : & mut Vec < Vec < usize > > ,
5163) -> Result < ( ) > {
64+ let candidates = candidates
65+ . iter ( )
66+ . filter ( |( path, _) | {
67+ !selected_paths
68+ . iter ( )
69+ . any ( |selected| paths_overlap ( path, selected) )
70+ } )
71+ . cloned ( )
72+ . collect :: < Vec < _ > > ( ) ;
5273 if candidates. len ( ) < 2 {
5374 return Ok ( ( ) ) ;
5475 }
5576
56- let mut cte_def = candidates[ 0 ] . 1 . clone ( ) ;
57- if let RelOperator :: Scan ( scan) = cte_def. plan . as_ref ( ) {
58- let mut scan = scan. clone ( ) ;
59- scan. scan_id = metadata. next_scan_id ( ) ;
60- cte_def = SExpr :: create_leaf ( Arc :: new ( RelOperator :: Scan ( scan) ) ) ;
61- }
77+ let cte_def = refresh_scan_ids ( & candidates[ 0 ] . 1 , metadata) ?;
6278 let cte_def = Arc :: new ( cte_def) ;
6379
6480 let cte_def_columns = cte_def. derive_relational_prop ( ) ?. output_columns . clone ( ) ;
@@ -83,6 +99,7 @@ fn process_candidate_expressions(
8399 output_columns : cte_ref_columns. iter ( ) . copied ( ) . collect ( ) ,
84100 def : expr. clone ( ) ,
85101 column_mapping,
102+ stat_info : None ,
86103 } ;
87104 let cte_ref_expr = Arc :: new ( SExpr :: create_leaf ( Arc :: new (
88105 RelOperator :: MaterializedCTERef ( cte_ref) ,
@@ -91,14 +108,209 @@ fn process_candidate_expressions(
91108 path : path. clone ( ) ,
92109 new_expr : cte_ref_expr. clone ( ) ,
93110 } ) ;
111+ selected_paths. push ( path) ;
94112 }
95113 Ok ( ( ) )
96114}
97115
116+ #[ recursive:: recursive]
117+ fn refresh_scan_ids ( s_expr : & SExpr , metadata : & mut Metadata ) -> Result < SExpr > {
118+ let new_children = s_expr
119+ . children ( )
120+ . map ( |child| refresh_scan_ids ( child, metadata) )
121+ . collect :: < Result < Vec < _ > > > ( ) ?;
122+
123+ let mut result = if new_children
124+ . iter ( )
125+ . zip ( s_expr. children ( ) )
126+ . any ( |( new, old) | !new. eq ( old) )
127+ {
128+ s_expr. replace_children ( new_children. into_iter ( ) . map ( Arc :: new) )
129+ } else {
130+ s_expr. clone ( )
131+ } ;
132+
133+ if let RelOperator :: Scan ( scan) = result. plan . as_ref ( ) {
134+ let mut scan = scan. clone ( ) ;
135+ scan. scan_id = metadata. next_scan_id ( ) ;
136+ result = result. replace_plan ( Arc :: new ( RelOperator :: Scan ( scan) ) ) ;
137+ }
138+
139+ Ok ( result)
140+ }
141+
142+ fn paths_overlap ( lhs : & [ usize ] , rhs : & [ usize ] ) -> bool {
143+ lhs. starts_with ( rhs) || rhs. starts_with ( lhs)
144+ }
145+
98146fn contains_recursive_cte ( expr : & SExpr ) -> bool {
99147 if matches ! ( expr. plan( ) , RelOperator :: RecursiveCteScan ( _) ) {
100148 return true ;
101149 }
102150
103151 expr. children ( ) . any ( contains_recursive_cte)
104152}
153+
154+ #[ cfg( test) ]
155+ mod tests {
156+ use std:: any:: Any ;
157+
158+ use databend_common_catalog:: table:: Table ;
159+ use databend_common_expression:: TableDataType ;
160+ use databend_common_expression:: TableField ;
161+ use databend_common_expression:: TableSchema ;
162+ use databend_common_expression:: types:: NumberDataType ;
163+ use databend_common_meta_app:: schema:: CatalogInfo ;
164+ use databend_common_meta_app:: schema:: DatabaseType ;
165+ use databend_common_meta_app:: schema:: TableIdent ;
166+ use databend_common_meta_app:: schema:: TableInfo ;
167+ use databend_common_meta_app:: schema:: TableMeta ;
168+
169+ use super :: * ;
170+ use crate :: planner:: metadata:: Metadata ;
171+ use crate :: plans:: Join ;
172+ use crate :: plans:: JoinType ;
173+ use crate :: plans:: RelOperator ;
174+ use crate :: plans:: Scan ;
175+
176+ #[ derive( Debug ) ]
177+ struct FakeTable {
178+ table_info : TableInfo ,
179+ }
180+
181+ #[ async_trait:: async_trait]
182+ impl Table for FakeTable {
183+ fn as_any ( & self ) -> & dyn Any {
184+ self
185+ }
186+
187+ fn get_table_info ( & self ) -> & TableInfo {
188+ & self . table_info
189+ }
190+
191+ fn support_column_projection ( & self ) -> bool {
192+ true
193+ }
194+ }
195+
196+ fn fake_fuse_table ( table_id : u64 , table_name : & str ) -> Arc < dyn Table > {
197+ Arc :: new ( FakeTable {
198+ table_info : TableInfo {
199+ ident : TableIdent :: new ( table_id, 0 ) ,
200+ desc : format ! ( "'default'.'{table_name}'" ) ,
201+ name : table_name. to_string ( ) ,
202+ meta : TableMeta {
203+ schema : Arc :: new ( TableSchema :: new ( vec ! [ TableField :: new(
204+ "a" ,
205+ TableDataType :: Number ( NumberDataType :: UInt64 ) ,
206+ ) ] ) ) ,
207+ engine : "FUSE" . to_string ( ) ,
208+ ..Default :: default ( )
209+ } ,
210+ catalog_info : Arc :: new ( CatalogInfo :: default ( ) ) ,
211+ db_type : DatabaseType :: NormalDB ,
212+ } ,
213+ } )
214+ }
215+
216+ fn add_table ( metadata : & mut Metadata , table : Arc < dyn Table > ) -> usize {
217+ metadata. add_table (
218+ "default" . to_string ( ) ,
219+ "default" . to_string ( ) ,
220+ table,
221+ None ,
222+ None ,
223+ false ,
224+ false ,
225+ false ,
226+ None ,
227+ )
228+ }
229+
230+ fn scan_expr ( metadata : & Metadata , table_index : usize ) -> SExpr {
231+ let columns = metadata
232+ . columns_by_table_index ( table_index)
233+ . into_iter ( )
234+ . map ( |column| column. index ( ) )
235+ . collect ( ) ;
236+ SExpr :: create_leaf ( Arc :: new ( RelOperator :: Scan ( Scan {
237+ table_index,
238+ columns,
239+ ..Default :: default ( )
240+ } ) ) )
241+ }
242+
243+ fn cross_join_expr ( left : SExpr , right : SExpr ) -> SExpr {
244+ SExpr :: create_binary (
245+ Arc :: new ( RelOperator :: Join ( Join {
246+ join_type : JoinType :: Cross ,
247+ ..Default :: default ( )
248+ } ) ) ,
249+ Arc :: new ( left) ,
250+ Arc :: new ( right) ,
251+ )
252+ }
253+
254+ #[ test]
255+ fn test_analyze_common_subexpression_prefers_cross_join_subtree ( ) {
256+ let mut metadata = Metadata :: default ( ) ;
257+ let t1 = fake_fuse_table ( 1 , "t1" ) ;
258+ let t2 = fake_fuse_table ( 2 , "t2" ) ;
259+
260+ let t1_left = add_table ( & mut metadata, t1. clone ( ) ) ;
261+ let t2_left = add_table ( & mut metadata, t2. clone ( ) ) ;
262+ let t1_right = add_table ( & mut metadata, t1) ;
263+ let t2_right = add_table ( & mut metadata, t2) ;
264+
265+ let left = cross_join_expr ( scan_expr ( & metadata, t1_left) , scan_expr ( & metadata, t2_left) ) ;
266+ let right = cross_join_expr (
267+ scan_expr ( & metadata, t1_right) ,
268+ scan_expr ( & metadata, t2_right) ,
269+ ) ;
270+ let root = cross_join_expr ( left, right) ;
271+
272+ let ( replacements, materialized_ctes) =
273+ analyze_common_subexpression ( & root, & mut metadata) . unwrap ( ) ;
274+
275+ assert_eq ! ( replacements. len( ) , 2 ) ;
276+ assert_eq ! ( materialized_ctes. len( ) , 1 ) ;
277+
278+ let cte_def = materialized_ctes[ 0 ] . child ( 0 ) . unwrap ( ) ;
279+ let RelOperator :: Join ( join) = cte_def. plan ( ) else {
280+ panic ! (
281+ "expected cross join materialized cte, got {:?}" ,
282+ cte_def. plan( )
283+ ) ;
284+ } ;
285+ assert_eq ! ( join. join_type, JoinType :: Cross ) ;
286+ }
287+
288+ #[ test]
289+ fn test_analyze_common_subexpression_keeps_cross_join_operand_order ( ) {
290+ let mut metadata = Metadata :: default ( ) ;
291+ let t1 = fake_fuse_table ( 1 , "t1" ) ;
292+ let t2 = fake_fuse_table ( 2 , "t2" ) ;
293+
294+ let t1_left = add_table ( & mut metadata, t1. clone ( ) ) ;
295+ let t2_left = add_table ( & mut metadata, t2. clone ( ) ) ;
296+ let t1_right = add_table ( & mut metadata, t1) ;
297+ let t2_right = add_table ( & mut metadata, t2) ;
298+
299+ let left = cross_join_expr ( scan_expr ( & metadata, t1_left) , scan_expr ( & metadata, t2_left) ) ;
300+ let right = cross_join_expr (
301+ scan_expr ( & metadata, t2_right) ,
302+ scan_expr ( & metadata, t1_right) ,
303+ ) ;
304+ let root = cross_join_expr ( left, right) ;
305+
306+ let ( _replacements, materialized_ctes) =
307+ analyze_common_subexpression ( & root, & mut metadata) . unwrap ( ) ;
308+
309+ assert_eq ! ( materialized_ctes. len( ) , 2 ) ;
310+ assert ! (
311+ materialized_ctes
312+ . iter( )
313+ . all( |cte| matches!( cte. child( 0 ) . unwrap( ) . plan( ) , RelOperator :: Scan ( _) ) )
314+ ) ;
315+ }
316+ }
0 commit comments