Skip to content

Commit 28bb20b

Browse files
authored
fix(functions): add hints for mismatched function args (#19735)
* fix(functions): add hints for mismatched args (#14026) * fix(functions): preserve mismatch error codes * fix(expression): avoid false function mismatch hints * test: align pretty_error cluster expectation * fix(expression): narrow mismatch hint suppression * test: refresh pretty_error hint output * test: align pretty_error standalone output * fix(expression): simplify cast mismatch hints * Revert "fix(expression): simplify cast mismatch hints" This reverts commit 4fdb604. * fix(expression): concretize mismatch hint candidates * test: refresh pretty_error concrete candidates
1 parent 73814d2 commit 28bb20b

8 files changed

Lines changed: 305 additions & 39 deletions

File tree

src/query/expression/src/evaluator.rs

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use crate::expr::*;
3434
use crate::expression::Expr;
3535
use crate::function::EvalContext;
3636
use crate::type_check::check_function;
37+
use crate::type_check::format_function_argument_mismatch_hint;
3738
use crate::type_check::get_simple_cast_function;
3839
use crate::types::BooleanType;
3940
use crate::types::DataType;
@@ -293,10 +294,21 @@ impl<'a> Evaluator<'a> {
293294
child_option.strict_eval = false;
294295
}
295296

296-
let args = args
297-
.iter()
298-
.map(|expr| self.partial_run(expr, validity.clone(), &mut child_option))
299-
.collect::<Result<Vec<_>>>()?;
297+
let mut args_value = Vec::with_capacity(args.len());
298+
for expr in args {
299+
match self.partial_run(expr, validity.clone(), &mut child_option) {
300+
Ok(value) => args_value.push(value),
301+
Err(err) => {
302+
return Err(self.attach_function_argument_mismatch_hint(
303+
err,
304+
&function.signature.name,
305+
id.params(),
306+
args,
307+
));
308+
}
309+
}
310+
}
311+
let args = args_value;
300312

301313
assert!(
302314
args.iter()
@@ -342,6 +354,37 @@ impl<'a> Evaluator<'a> {
342354
Ok(result)
343355
}
344356

357+
fn attach_function_argument_mismatch_hint(
358+
&self,
359+
err: ErrorCode,
360+
name: &str,
361+
params: &[Scalar],
362+
args: &[Expr],
363+
) -> ErrorCode {
364+
let mut has_top_level_cast = false;
365+
let hint_args = args
366+
.iter()
367+
.map(|arg| match arg {
368+
Expr::Cast(Cast { expr, .. }) => {
369+
has_top_level_cast = true;
370+
expr.as_ref().clone()
371+
}
372+
_ => arg.clone(),
373+
})
374+
.collect::<Vec<_>>();
375+
if !has_top_level_cast {
376+
return err;
377+
}
378+
379+
let Some(hint) =
380+
format_function_argument_mismatch_hint(name, params, &hint_args, self.fn_registry)
381+
else {
382+
return err;
383+
};
384+
385+
err.add_message_back(format!("\n\nhint: {hint}"))
386+
}
387+
345388
pub fn run_cast(
346389
&self,
347390
span: Span,

src/query/expression/src/type_check.rs

Lines changed: 116 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ use crate::types::decimal::DecimalSize;
4141
use crate::types::i256;
4242
use crate::visit_expr;
4343

44+
const MAX_FUNCTION_MISMATCH_HINT_CANDIDATES: usize = 10;
45+
4446
#[recursive::recursive]
4547
pub fn check<Index: ColumnIndex>(
4648
expr: &RawExpr<Index>,
@@ -392,22 +394,7 @@ pub fn check_function<Index: ColumnIndex>(
392394
return Ok(checked_candidates.pop().unwrap().0);
393395
}
394396

395-
let mut msg = if params.is_empty() {
396-
format!(
397-
"no function matches signature `{name}({})`, you might need to add explicit type casts.",
398-
args.iter()
399-
.map(|arg| arg.data_type().to_string())
400-
.join(", ")
401-
)
402-
} else {
403-
format!(
404-
"no function matches signature `{name}({})({})`, you might need to add explicit type casts.",
405-
params.iter().join(", "),
406-
args.iter()
407-
.map(|arg| arg.data_type().to_string())
408-
.join(", ")
409-
)
410-
};
397+
let mut msg = format_no_matching_function_signature(name, params, args);
411398

412399
if !candidates.is_empty() {
413400
let candidates_sig: Vec<_> = candidates
@@ -426,21 +413,128 @@ pub fn check_function<Index: ColumnIndex>(
426413
.map(|(sig, err)| format!(" {sig:<max_len$} : {}", err.message()))
427414
.join("\n");
428415

429-
let shorten_msg = if candidates_len > take_len {
430-
format!("\n... and {} more", candidates_len - take_len)
431-
} else {
432-
"".to_string()
433-
};
434416
write!(
435417
&mut msg,
436-
"\n\ncandidate functions:\n{candidates_fail_reason}{shorten_msg}",
418+
"\n\ncandidate functions:\n{candidates_fail_reason}",
437419
)
438420
.unwrap();
421+
if candidates_len > take_len {
422+
write!(&mut msg, "\n... and {} more", candidates_len - take_len).unwrap();
423+
}
439424
};
440425

441426
Err(ErrorCode::SemanticError(msg).set_span(span))
442427
}
443428

429+
fn format_function_signature<Index: ColumnIndex>(
430+
name: &str,
431+
params: &[Scalar],
432+
args: &[Expr<Index>],
433+
) -> String {
434+
let args = args
435+
.iter()
436+
.map(|arg| arg.data_type().to_string())
437+
.collect::<Vec<_>>();
438+
format_function_signature_from_arg_types(name, params, &args)
439+
}
440+
441+
fn format_function_signature_from_arg_types(
442+
name: &str,
443+
params: &[Scalar],
444+
args: &[String],
445+
) -> String {
446+
let args = args.iter().join(", ");
447+
if params.is_empty() {
448+
format!("{name}({args})")
449+
} else {
450+
format!("{name}({})({args})", params.iter().join(", "))
451+
}
452+
}
453+
454+
fn format_no_matching_function_signature<Index: ColumnIndex>(
455+
name: &str,
456+
params: &[Scalar],
457+
args: &[Expr<Index>],
458+
) -> String {
459+
format!(
460+
"no function matches signature `{}`, you might need to add explicit type casts.",
461+
format_function_signature(name, params, args)
462+
)
463+
}
464+
465+
pub fn format_function_argument_mismatch_hint<Index: ColumnIndex>(
466+
name: &str,
467+
params: &[Scalar],
468+
args: &[Expr<Index>],
469+
fn_registry: &FunctionRegistry,
470+
) -> Option<String> {
471+
let name = fn_registry
472+
.aliases
473+
.get(name)
474+
.map(|name| name.as_str())
475+
.unwrap_or(name);
476+
477+
let candidates = fn_registry.search_candidates(name, params, args);
478+
if candidates.is_empty() {
479+
return None;
480+
}
481+
482+
let auto_cast_rules = fn_registry.get_auto_cast_rules(name);
483+
let dynamic_cast_rules = fn_registry.get_dynamic_cast_rules(name);
484+
let mut concrete_candidates_sig = Vec::new();
485+
for (_, func) in &candidates {
486+
let Ok((checked_args, return_type, _)) = try_check_function(
487+
args,
488+
&func.signature,
489+
auto_cast_rules,
490+
&dynamic_cast_rules,
491+
fn_registry,
492+
) else {
493+
continue;
494+
};
495+
496+
if checked_args == args {
497+
return None;
498+
}
499+
500+
concrete_candidates_sig.push(format!(
501+
"{} :: {}",
502+
format_function_signature(name, params, &checked_args),
503+
return_type
504+
));
505+
}
506+
507+
let mut msg = format_no_matching_function_signature(name, params, args);
508+
509+
let candidates_sig = if concrete_candidates_sig.is_empty() {
510+
candidates
511+
.iter()
512+
.map(|(_, func)| func.signature.to_string())
513+
.collect()
514+
} else {
515+
concrete_candidates_sig
516+
};
517+
let (mut candidates_sig, nullable_candidates_sig): (Vec<_>, Vec<_>) = candidates_sig
518+
.into_iter()
519+
.unique()
520+
.partition(|sig| !sig.contains("NULL"));
521+
candidates_sig.extend(nullable_candidates_sig);
522+
523+
let candidates_len = candidates_sig.len();
524+
let take_len = candidates_len.min(MAX_FUNCTION_MISMATCH_HINT_CANDIDATES);
525+
let candidates_sig = candidates_sig
526+
.into_iter()
527+
.take(take_len)
528+
.map(|sig| format!(" {sig}"))
529+
.join("\n");
530+
write!(&mut msg, "\n\ncandidate functions:\n{candidates_sig}").unwrap();
531+
if candidates_len > take_len {
532+
write!(&mut msg, "\n... and {} more", candidates_len - take_len).unwrap();
533+
}
534+
535+
Some(msg)
536+
}
537+
444538
#[derive(Debug)]
445539
pub struct Substitution(pub HashMap<usize, DataType>);
446540

src/query/functions/tests/it/scalars/testdata/array.txt

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@ error:
1111
--> SQL:1:2
1212
|
1313
1 | ['a', 1]
14-
| ^^^ bad decimal literal: unexpected char while evaluating function `to_decimal(38, 5)('a')` in expr `CAST('a' AS Decimal(38, 5))`, during run expr: `array(CAST('a' AS Decimal(38, 5)), CAST(1 AS Decimal(38, 5)))`
14+
| ^^^ bad decimal literal: unexpected char while evaluating function `to_decimal(38, 5)('a')` in expr `CAST('a' AS Decimal(38, 5))`
15+
16+
17+
hint: no function matches signature `array(String, UInt8)`, you might need to add explicit type casts.
18+
19+
candidate functions:
20+
array(Decimal(38, 5), Decimal(38, 5)) :: Array(Decimal(38, 5)), during run expr: `array(CAST('a' AS Decimal(38, 5)), CAST(1 AS Decimal(38, 5)))`
1521

1622

1723

@@ -157,7 +163,14 @@ error:
157163
--> SQL:1:8
158164
|
159165
1 | [1, 2]['a']
160-
| ^^^ invalid digit found in string while evaluating function `to_uint64('a')` in expr `CAST('a' AS UInt64)`, during run expr: `get(CAST(array(1, 2) AS Array(UInt8 NULL)), CAST('a' AS UInt64))`
166+
| ^^^ invalid digit found in string while evaluating function `to_uint64('a')` in expr `CAST('a' AS UInt64)`
167+
168+
169+
hint: no function matches signature `get(Array(UInt8), String)`, you might need to add explicit type casts.
170+
171+
candidate functions:
172+
get(Array(UInt8 NULL), UInt64) :: UInt8 NULL
173+
get(Array(UInt8 NULL) NULL, UInt64 NULL) :: UInt8 NULL, during run expr: `get(CAST(array(1, 2) AS Array(UInt8 NULL)), CAST('a' AS UInt64))`
161174

162175

163176

@@ -805,7 +818,14 @@ error:
805818
--> SQL:1:23
806819
|
807820
1 | array_concat([1,2,3], ['s', null])
808-
| ^^^^^^^^^^^ bad decimal literal: unexpected char while evaluating function `to_decimal(38, 5)('s')` in expr `CAST(array(CAST('s' AS String NULL), CAST(NULL AS String NULL)) AS Array(Decimal(38, 5) NULL))`, during run expr: `array_concat(CAST(array(1, 2, 3) AS Array(Decimal(38, 5) NULL)), CAST(array(CAST('s' AS String NULL), CAST(NULL AS String NULL)) AS Array(Decimal(38, 5) NULL)))`
821+
| ^^^^^^^^^^^ bad decimal literal: unexpected char while evaluating function `to_decimal(38, 5)('s')` in expr `CAST(array(CAST('s' AS String NULL), CAST(NULL AS String NULL)) AS Array(Decimal(38, 5) NULL))`
822+
823+
824+
hint: no function matches signature `array_concat(Array(UInt8), Array(String NULL))`, you might need to add explicit type casts.
825+
826+
candidate functions:
827+
array_concat(Array(Decimal(38, 5) NULL), Array(Decimal(38, 5) NULL)) :: Array(Decimal(38, 5) NULL)
828+
array_concat(Array(Decimal(38, 5) NULL) NULL, Array(Decimal(38, 5) NULL) NULL) :: Array(Decimal(38, 5) NULL) NULL, during run expr: `array_concat(CAST(array(1, 2, 3) AS Array(Decimal(38, 5) NULL)), CAST(array(CAST('s' AS String NULL), CAST(NULL AS String NULL)) AS Array(Decimal(38, 5) NULL)))`
809829

810830

811831

@@ -1058,7 +1078,20 @@ error:
10581078
--> SQL:1:22
10591079
|
10601080
1 | array_indexof([1,2,3,'s'], 's')
1061-
| ^^^ bad decimal literal: unexpected char while evaluating function `to_decimal(38, 5)('s')` in expr `CAST('s' AS Decimal(38, 5))`, during run expr: `array_indexof(array(CAST(1 AS Decimal(38, 5)), CAST(2 AS Decimal(38, 5)), CAST(3 AS Decimal(38, 5)), CAST('s' AS Decimal(38, 5))), CAST('s' AS Decimal(38, 5)))`
1081+
| ^^^ bad decimal literal: unexpected char while evaluating function `to_decimal(38, 5)('s')` in expr `CAST('s' AS Decimal(38, 5))`
1082+
1083+
1084+
hint: no function matches signature `array(UInt8, UInt8, UInt8, String)`, you might need to add explicit type casts.
1085+
1086+
candidate functions:
1087+
array(Decimal(38, 5), Decimal(38, 5), Decimal(38, 5), Decimal(38, 5)) :: Array(Decimal(38, 5))
1088+
1089+
1090+
hint: no function matches signature `array_indexof(Array(Decimal(38, 5)), String)`, you might need to add explicit type casts.
1091+
1092+
candidate functions:
1093+
array_indexof(Array(Decimal(38, 5)), Decimal(38, 5)) :: UInt64
1094+
array_indexof(Array(Decimal(38, 5)) NULL, Decimal(38, 5) NULL) :: UInt64 NULL, during run expr: `array_indexof(array(CAST(1 AS Decimal(38, 5)), CAST(2 AS Decimal(38, 5)), CAST(3 AS Decimal(38, 5)), CAST('s' AS Decimal(38, 5))), CAST('s' AS Decimal(38, 5)))`
10621095

10631096

10641097

src/query/functions/tests/it/scalars/testdata/boolean.txt

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,29 @@ error:
6565
--> SQL:1:1
6666
|
6767
1 | 'a' and 1
68-
| ^^^ cannot parse to type `BOOLEAN` while evaluating function `to_boolean('a')` in expr `CAST('a' AS Boolean)`, during run expr: `(CAST('a' AS Boolean) AND CAST(1 AS Boolean))`
68+
| ^^^ cannot parse to type `BOOLEAN` while evaluating function `to_boolean('a')` in expr `CAST('a' AS Boolean)`
69+
70+
71+
hint: no function matches signature `and(String, UInt8)`, you might need to add explicit type casts.
72+
73+
candidate functions:
74+
and(Boolean, Boolean) :: Boolean
75+
and(Boolean NULL, Boolean NULL) :: Boolean NULL, during run expr: `(CAST('a' AS Boolean) AND CAST(1 AS Boolean))`
6976

7077

7178

7279
error:
7380
--> SQL:1:9
7481
|
7582
1 | NOT NOT 'a'
76-
| ^^^ cannot parse to type `BOOLEAN` while evaluating function `to_boolean('a')` in expr `CAST('a' AS Boolean)`, during run expr: `NOT NOT CAST('a' AS Boolean)`
83+
| ^^^ cannot parse to type `BOOLEAN` while evaluating function `to_boolean('a')` in expr `CAST('a' AS Boolean)`
84+
85+
86+
hint: no function matches signature `not(String)`, you might need to add explicit type casts.
87+
88+
candidate functions:
89+
not(Boolean) :: Boolean
90+
not(Boolean NULL) :: Boolean NULL, during run expr: `NOT NOT CAST('a' AS Boolean)`
7791

7892

7993

src/query/functions/tests/it/scalars/testdata/geo_h3.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,14 @@ error:
791791
--> SQL:1:14
792792
|
793793
1 | h3_to_string('8d11aa6a38826ff')
794-
| ^^^^^^^^^^^^^^^^^ invalid digit found in string while evaluating function `to_uint64('8d11aa6a38826ff')` in expr `CAST('8d11aa6a38826ff' AS UInt64)`, during run expr: `h3_to_string(CAST('8d11aa6a38826ff' AS UInt64))`
794+
| ^^^^^^^^^^^^^^^^^ invalid digit found in string while evaluating function `to_uint64('8d11aa6a38826ff')` in expr `CAST('8d11aa6a38826ff' AS UInt64)`
795+
796+
797+
hint: no function matches signature `h3_to_string(String)`, you might need to add explicit type casts.
798+
799+
candidate functions:
800+
h3_to_string(UInt64) :: String
801+
h3_to_string(UInt64 NULL) :: String NULL, during run expr: `h3_to_string(CAST('8d11aa6a38826ff' AS UInt64))`
795802

796803

797804

0 commit comments

Comments
 (0)