Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,19 +623,26 @@ impl<'a, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'a, '_, T
))
} else {
// handle col syntax
let mut got_column = Self::find_column_in_scope(
&self.context,
&mut self.table_schema_buf,
full_name.1.as_str(),
);
let mut find_visible_column =
|context: &BinderContext<'a, T>| -> Result<Option<ScalarExpression>, DatabaseError> {
Ok(context
.using
.get(full_name.1.as_str())
.map(|using_column| using_column.visible_expr())
.transpose()?
.or_else(|| {
Self::find_column_in_scope(
context,
&mut self.table_schema_buf,
full_name.1.as_str(),
)
}))
};
let mut got_column = find_visible_column(&self.context)?;
if got_column.is_none() {
if let Some(parent) = self.parent {
self.context.mark_outer_ref();
got_column = Self::find_column_in_scope(
&parent.context,
&mut self.table_schema_buf,
full_name.1.as_str(),
);
got_column = find_visible_column(&parent.context)?;
}
}
match got_column {
Expand Down
96 changes: 86 additions & 10 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use sqlparser::ast::{
Statement, TableObject,
};
use sqlparser::tokenizer::Span;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{BTreeMap, HashMap};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

Expand All @@ -54,6 +54,7 @@ use crate::planner::{LogicalPlan, SchemaOutput};
use crate::storage::{TableCache, Transaction, ViewCache};
use crate::types::tuple::SchemaRef;
use crate::types::value::DataValue;
use crate::types::LogicalType;

pub enum InputRefType {
AggCall,
Expand Down Expand Up @@ -204,6 +205,69 @@ impl BoundSource<'_> {
}
}

#[derive(Clone)]
pub(crate) struct UsingColumn {
join_type: JoinType,
left_column: ColumnRef,
left_position: usize,
right_column: ColumnRef,
right_position: usize,
}

impl UsingColumn {
fn new(
join_type: JoinType,
left_column: ColumnRef,
left_position: usize,
right_column: ColumnRef,
right_position: usize,
) -> Self {
Self {
join_type,
left_column,
left_position,
right_column,
right_position,
}
}

fn left_expr(&self) -> ScalarExpression {
ScalarExpression::column_expr(self.left_column.clone(), self.left_position)
}

fn right_expr(&self) -> ScalarExpression {
ScalarExpression::column_expr(self.right_column.clone(), self.right_position)
}

pub(crate) fn visible_expr(&self) -> Result<ScalarExpression, DatabaseError> {
match self.join_type {
JoinType::RightOuter => Ok(self.right_expr()),
JoinType::Full => {
let left_expr = self.left_expr();
let right_expr = self.right_expr();
let left_ty = left_expr.return_type();
let right_ty = right_expr.return_type();
let ty = LogicalType::max_logical_type(&left_ty, &right_ty)?.into_owned();

Ok(ScalarExpression::Coalesce {
exprs: vec![left_expr, right_expr],
ty,
})
}
JoinType::Inner | JoinType::LeftOuter | JoinType::Cross => Ok(self.left_expr()),
}
}

pub(crate) fn hides_column(&self, column: &ColumnRef) -> bool {
let hidden_column = if self.join_type.is_right() {
&self.left_column
} else {
&self.right_column
};
hidden_column.same_column(column)
}
}

#[derive(Clone)]
pub struct BinderContext<'a, T: Transaction> {
pub(crate) scala_functions: &'a ScalaFunctions,
Expand All @@ -221,7 +285,7 @@ pub struct BinderContext<'a, T: Transaction> {
group_by_exprs: Vec<ScalarExpression>,
pub(crate) agg_calls: Vec<ScalarExpression>,
// join
using: HashSet<ColumnRef>,
using: HashMap<String, UsingColumn>,

bind_step: QueryBindStep,
sub_queries: HashMap<QueryBindStep, Vec<SubQueryType>>,
Expand Down Expand Up @@ -471,15 +535,27 @@ impl<'a, T: Transaction> BinderContext<'a, T> {

pub fn add_using(
&mut self,
name: String,
join_type: JoinType,
left_expr: &ColumnRef,
right_expr: &ColumnRef,
) {
self.using.insert(if join_type.is_right() {
left_expr.clone()
} else {
right_expr.clone()
});
left_column: &ColumnRef,
left_position: usize,
right_column: &ColumnRef,
right_position: usize,
) -> Result<(), DatabaseError> {
if self.using.contains_key(&name) {
return Err(DatabaseError::UnsupportedStmt(format!(
"duplicate `USING({name})` across joins is not supported"
)));
}
let using_column = UsingColumn::new(
join_type,
left_column.clone(),
left_position,
right_column.clone(),
right_position,
);
self.using.insert(name, using_column);
Ok(())
}

pub fn add_alias(
Expand Down
24 changes: 21 additions & 3 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,11 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
return Some(&table_name) == column.table_name();
}
is_qualified_wildcard
|| Some(&table_name) == column.table_name() && !context.using.contains(column)
|| Some(&table_name) == column.table_name()
&& !context
.using
.values()
.any(|using_column| using_column.hides_column(column))
};

let (schema_ref, position_offset) =
Expand Down Expand Up @@ -1909,7 +1913,14 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
ident,
));
};
self.context.add_using(join_type, left_column, right_column);
self.context.add_using(
name.clone(),
join_type,
left_column,
left_position,
right_column,
left_schema.len() + right_position,
)?;
on_keys.push((
ScalarExpression::column_expr(left_column.clone(), left_position),
ScalarExpression::column_expr(
Expand Down Expand Up @@ -1951,7 +1962,14 @@ impl<'a: 'b, 'b, T: Transaction, A: AsRef<[(&'static str, DataValue)]>> Binder<'
left_schema.len() + right_position,
);

self.context.add_using(join_type, left_column, right_column);
self.context.add_using(
name.to_string(),
join_type,
left_column,
left_position,
right_column,
left_schema.len() + right_position,
)?;
on_keys.push((left_expr, right_expr));
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/execution/dql/join/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,7 @@ mod test {
}

#[test]
fn test_right_join_using_keeps_left_visible_column_binding() -> Result<(), DatabaseError> {
fn test_right_join_using_binds_visible_column_to_right_side() -> Result<(), DatabaseError> {
let temp_dir = TempDir::new().expect("unable to create temporary working directory");
let db = DataBaseBuilder::path(temp_dir.path()).build_in_memory()?;

Expand Down Expand Up @@ -1216,9 +1216,9 @@ mod test {
Some("A".to_string()),
Some("A".to_string())
],
vec![None, None, Some("B".to_string())],
vec![None, None, Some("C".to_string())],
vec![None, None, Some("E".to_string())],
vec![Some("B".to_string()), None, Some("B".to_string())],
vec![Some("C".to_string()), None, Some("C".to_string())],
vec![Some("E".to_string()), None, Some("E".to_string())],
]
);

Expand Down
40 changes: 32 additions & 8 deletions tests/slt/crdb/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -705,9 +705,9 @@ query TTT
SELECT s, str1.s, str2.s FROM str1 RIGHT OUTER JOIN str2 USING(s) order by str2.s
----
A A A
null null B
null null C
null null E
B null B
C null C
E null E

query ITIT
SELECT * FROM str1 LEFT OUTER JOIN str2 ON str1.s = str2.s order by str1.a
Expand Down Expand Up @@ -862,6 +862,31 @@ INSERT INTO l VALUES (1, 1), (2, 1), (3, 1)
statement ok
INSERT INTO r VALUES (2, 1), (3, 1), (4, 1)

query III
SELECT a, l.a, r.a FROM l INNER JOIN r USING(a) WHERE a = 2
----
2 2 2

query III
SELECT a, l.a, r.a FROM l LEFT OUTER JOIN r USING(a) WHERE a = 1
----
1 1 null

query III
SELECT a, l.a, r.a FROM l RIGHT OUTER JOIN r USING(a) WHERE a = 4
----
4 null 4

query III
SELECT a, l.a, r.a FROM l FULL OUTER JOIN r USING(a) WHERE a = 1
----
1 1 null

query III
SELECT a, l.a, r.a FROM l FULL OUTER JOIN r USING(a) WHERE a = 4
----
4 null 4

query III
SELECT * FROM l LEFT OUTER JOIN r USING(a) WHERE a = 1
----
Expand All @@ -877,11 +902,10 @@ SELECT * FROM l RIGHT OUTER JOIN r USING(a) WHERE a = 3
----
1 3 1

# TODO: a= 4 means x on both sides
# query III
# SELECT * FROM l RIGHT OUTER JOIN r USING(a) WHERE a = 4
# ----
# NULL 4 1
query III
SELECT * FROM l RIGHT OUTER JOIN r USING(a) WHERE a = 4
----
null 4 1

statement ok
drop table if exists foo
Expand Down
Loading