1111// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212// See the License for the specific language governing permissions and
1313// limitations under the License.
14+ use std:: collections:: BTreeSet ;
1415
15-
16- use databend_common_expression:: type_check:: check;
17- use databend_common_expression:: types:: DataType ;
18- use databend_common_expression:: types:: NumberDataType ;
19- use databend_common_expression:: types:: NumberScalar ;
2016use databend_common_expression:: DataField ;
2117use databend_common_expression:: DataSchemaRefExt ;
2218use databend_common_expression:: Expr ;
2319use databend_common_expression:: RawExpr ;
2420use databend_common_expression:: Scalar ;
21+ use databend_common_expression:: type_check:: check;
22+ use databend_common_expression:: types:: DataType ;
23+ use databend_common_expression:: types:: NumberDataType ;
24+ use databend_common_expression:: types:: NumberScalar ;
2525use databend_common_functions:: BUILTIN_FUNCTIONS ;
26- use databend_common_sql:: evaluator:: apply_cse;
2726use databend_common_sql:: evaluator:: BlockOperator ;
28- use databend_common_sql:: optimizer :: ir :: ColumnSet ;
27+ use databend_common_sql:: evaluator :: apply_cse ;
2928use itertools:: Itertools ;
3029
30+ fn count_function_calls ( expr : & Expr , fn_name : & str ) -> usize {
31+ match expr {
32+ Expr :: FunctionCall ( call) => {
33+ usize:: from ( call. function . signature . name == fn_name)
34+ + call
35+ . args
36+ . iter ( )
37+ . map ( |arg| count_function_calls ( arg, fn_name) )
38+ . sum :: < usize > ( )
39+ }
40+ Expr :: LambdaFunctionCall ( call) => call
41+ . args
42+ . iter ( )
43+ . map ( |arg| count_function_calls ( arg, fn_name) )
44+ . sum ( ) ,
45+ Expr :: Cast ( cast) => count_function_calls ( & cast. expr , fn_name) ,
46+ Expr :: Constant ( _) | Expr :: ColumnRef ( _) => 0 ,
47+ }
48+ }
49+
3150#[ test]
3251fn test_cse ( ) {
3352 let schema = DataSchemaRefExt :: create ( vec ! [ DataField :: new(
@@ -36,7 +55,7 @@ fn test_cse() {
3655 ) ] ) ;
3756
3857 // a + 1, (a + 1) *2
39- let exprs = vec ! [
58+ let exprs = [
4059 RawExpr :: FunctionCall {
4160 span : None ,
4261 name : "plus" . to_string ( ) ,
@@ -45,7 +64,7 @@ fn test_cse() {
4564 RawExpr :: ColumnRef {
4665 span: None ,
4766 id: 0usize ,
48- data_type: schema. field( 0 ) . data_type( ) ,
67+ data_type: schema. field( 0 ) . data_type( ) . clone ( ) ,
4968 display_name: schema. field( 0 ) . name( ) . clone( ) ,
5069 } ,
5170 RawExpr :: Constant {
@@ -68,7 +87,7 @@ fn test_cse() {
6887 RawExpr :: ColumnRef {
6988 span: None ,
7089 id: 0usize ,
71- data_type: schema. field( 0 ) . data_type( ) ,
90+ data_type: schema. field( 0 ) . data_type( ) . clone ( ) ,
7291 display_name: schema. field( 0 ) . name( ) . clone( ) ,
7392 } ,
7493 RawExpr :: Constant {
@@ -92,7 +111,7 @@ fn test_cse() {
92111 . map ( |expr| check ( expr, & BUILTIN_FUNCTIONS ) . unwrap ( ) )
93112 . collect ( ) ;
94113
95- let mut projections = ColumnSet :: new ( ) ;
114+ let mut projections = BTreeSet :: new ( ) ;
96115 projections. insert ( 1 ) ;
97116 projections. insert ( 2 ) ;
98117 let operators = vec ! [ BlockOperator :: Map {
@@ -123,3 +142,82 @@ fn test_cse() {
123142 _ => unreachable ! ( ) ,
124143 }
125144}
145+
/// Checks `apply_cse` on two `get(parse_json(repo), <key>)` expressions that
/// share the same `parse_json(repo)` sub-expression.
///
/// After the rewrite the test expects a single `Map` operator holding three
/// expressions with exactly one `parse_json` call among them, and the
/// projections `[2, 3]`.
#[test]
fn test_cse_parse_json_reuse() {
    let schema = DataSchemaRefExt::create(vec![DataField::new("repo", DataType::String)]);

    // `parse_json(repo)`: the sub-expression both outputs have in common.
    let parse_repo = || RawExpr::FunctionCall {
        span: None,
        name: "parse_json".to_string(),
        params: vec![],
        args: vec![RawExpr::ColumnRef {
            span: None,
            id: 0usize,
            data_type: schema.field(0).data_type().clone(),
            display_name: schema.field(0).name().clone(),
        }],
    };

    // `get(parse_json(repo), key)` for a given string key.
    let get_field = |key: &str| RawExpr::FunctionCall {
        span: None,
        name: "get".to_string(),
        params: vec![],
        args: vec![parse_repo(), RawExpr::Constant {
            span: None,
            scalar: Scalar::String(key.to_string()),
            data_type: None,
        }],
    };

    let raw_exprs = [get_field("name"), get_field("url")];
    let exprs: Vec<Expr> = raw_exprs
        .iter()
        .map(|raw| check(raw, &BUILTIN_FUNCTIONS).unwrap())
        .collect();

    let input_operators = vec![BlockOperator::Map {
        exprs,
        projections: Some(BTreeSet::from([1, 2])),
    }];

    // NOTE(review): the second argument is presumably the number of input
    // columns (the schema has one field) — confirm against `apply_cse`.
    let mut rewritten = apply_cse(input_operators, 1);
    assert_eq!(rewritten.len(), 1);

    let BlockOperator::Map { exprs, projections } = rewritten.pop().unwrap() else {
        unreachable!()
    };

    // Three expressions in total, with `parse_json` appearing only once.
    assert_eq!(exprs.len(), 3);
    let parse_json_calls: usize = exprs
        .iter()
        .map(|expr| count_function_calls(expr, "parse_json"))
        .sum();
    assert_eq!(parse_json_calls, 1);

    let projected: Vec<_> = projections.unwrap().into_iter().sorted().collect();
    assert_eq!(projected, vec![2, 3]);
}
0 commit comments