|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | use std::cmp::Ordering; |
| 16 | +use std::marker::PhantomData; |
16 | 17 |
|
17 | 18 | use crate::Scalar; |
| 19 | +use crate::property::Domain; |
18 | 20 | use crate::stat_distribution::ArgStat; |
| 21 | +use crate::stat_distribution::BooleanDistribution; |
19 | 22 | use crate::stat_distribution::Ndv; |
| 23 | +use crate::stat_distribution::OwnedDistribution; |
| 24 | +use crate::stat_distribution::ReturnStat; |
20 | 25 | use crate::stat_distribution::StatBinaryArg; |
21 | 26 | use crate::stat_distribution::StatEstimate; |
| 27 | +use crate::types::boolean::BooleanDomain; |
| 28 | +use crate::types::nullable::NullableDomain; |
22 | 29 |
|
23 | 30 | pub trait StatComparisonOp { |
24 | 31 | type Reverse: StatComparisonOp; |
@@ -72,13 +79,13 @@ pub trait StatComparisonOp { |
72 | 79 | } |
73 | 80 | } |
74 | 81 |
|
75 | | -#[derive(Default)] |
| 82 | +#[derive(Default, Clone, Copy)] |
76 | 83 | pub struct LtOp; |
77 | | -#[derive(Default)] |
| 84 | +#[derive(Default, Clone, Copy)] |
78 | 85 | pub struct LteOp; |
79 | | -#[derive(Default)] |
| 86 | +#[derive(Default, Clone, Copy)] |
80 | 87 | pub struct GtOp; |
81 | | -#[derive(Default)] |
| 88 | +#[derive(Default, Clone, Copy)] |
82 | 89 | pub struct GteOp; |
83 | 90 |
|
84 | 91 | impl StatComparisonOp for LtOp { |
@@ -109,72 +116,146 @@ impl StatComparisonOp for GteOp { |
109 | 116 | const INCLUDE_EQUAL: bool = true; |
110 | 117 | } |
111 | 118 |
|
112 | | -pub struct ConstantComparison<'s, 'a> { |
| 119 | +pub struct ConstantComparison<'s, 'a, A: ConstantComparisonAdapter> { |
113 | 120 | pub stat: &'s ArgStat<'a>, |
114 | | - pub constant: Scalar, |
115 | | - pub cardinality: f64, |
| 121 | + pub constant: A::Value, |
| 122 | + pub domain: Option<A::Domain>, |
| 123 | + pub non_null_cardinality: f64, |
| 124 | + pub null_count: u64, |
| 125 | + pub nullable: bool, |
| 126 | + _a: PhantomData<fn(A)>, |
116 | 127 | } |
117 | 128 |
|
118 | | -impl<'s, 'a> ConstantComparison<'s, 'a> { |
119 | | - pub fn from_equality_args(stat: &'s StatBinaryArg<'a>) -> Option<Self> { |
120 | | - Self::from_right_constant(stat).or_else(|| Self::from_left_constant(stat)) |
| 129 | +pub trait ConstantComparisonAdapter { |
| 130 | + type Value; |
| 131 | + type Domain; |
| 132 | + |
| 133 | + fn constant(scalar: Scalar) -> Result<Self::Value, String>; |
| 134 | + |
| 135 | + fn domain(domain: &Domain) -> Result<Self::Domain, String>; |
| 136 | + |
| 137 | + fn compare(left: &Self::Value, right: &Self::Value) -> Ordering; |
| 138 | +} |
| 139 | + |
| 140 | +impl<'s, 'a, A: ConstantComparisonAdapter> ConstantComparison<'s, 'a, A> { |
| 141 | + pub fn from_constant_args(stat: &'s StatBinaryArg<'a>) -> Result<Option<(Self, bool)>, String> { |
| 142 | + if let Some(input) = |
| 143 | + Self::new(&stat.args[0], &stat.args[1], stat.cardinality)?.map(|input| (input, false)) |
| 144 | + { |
| 145 | + return Ok(Some(input)); |
| 146 | + } |
| 147 | + Ok(Self::new(&stat.args[1], &stat.args[0], stat.cardinality)?.map(|input| (input, true))) |
121 | 148 | } |
122 | 149 |
|
123 | | - pub fn from_right_constant(stat: &'s StatBinaryArg<'a>) -> Option<Self> { |
124 | | - Some(Self { |
125 | | - stat: &stat.args[0], |
126 | | - constant: stat.args[1].singleton()?, |
127 | | - cardinality: stat.cardinality, |
128 | | - }) |
| 150 | + fn new( |
| 151 | + stat: &'s ArgStat<'a>, |
| 152 | + constant_stat: &ArgStat<'_>, |
| 153 | + input_cardinality: f64, |
| 154 | + ) -> Result<Option<Self>, String> { |
| 155 | + let Some(constant) = constant_stat.singleton() else { |
| 156 | + return Ok(None); |
| 157 | + }; |
| 158 | + if constant.is_null() { |
| 159 | + return Err( |
| 160 | + "constant comparison null constant was not handled before typed comparison" |
| 161 | + .to_string(), |
| 162 | + ); |
| 163 | + } |
| 164 | + let nullable = stat.domain.is_nullable() || constant_stat.domain.is_nullable(); |
| 165 | + let null_count = stat.null_count.min(input_cardinality.ceil() as u64); |
| 166 | + let non_null_cardinality = (input_cardinality - null_count as f64).max(0.0); |
| 167 | + let domain = match &stat.domain { |
| 168 | + Domain::Nullable(NullableDomain { value: None, .. }) => None, |
| 169 | + Domain::Nullable(NullableDomain { |
| 170 | + value: Some(box domain), |
| 171 | + .. |
| 172 | + }) |
| 173 | + | domain => match A::domain(domain) { |
| 174 | + Ok(domain) => Some(domain), |
| 175 | + Err(err) => { |
| 176 | + return Err(err); |
| 177 | + } |
| 178 | + }, |
| 179 | + }; |
| 180 | + let constant = match A::constant(constant) { |
| 181 | + Ok(constant) => constant, |
| 182 | + Err(err) => { |
| 183 | + return Err(err); |
| 184 | + } |
| 185 | + }; |
| 186 | + |
| 187 | + Ok(Some(Self { |
| 188 | + stat, |
| 189 | + constant, |
| 190 | + domain, |
| 191 | + non_null_cardinality, |
| 192 | + null_count, |
| 193 | + nullable, |
| 194 | + _a: PhantomData, |
| 195 | + })) |
129 | 196 | } |
130 | 197 |
|
131 | | - pub fn from_left_constant(stat: &'s StatBinaryArg<'a>) -> Option<Self> { |
132 | | - Some(Self { |
133 | | - stat: &stat.args[1], |
134 | | - constant: stat.args[0].singleton()?, |
135 | | - cardinality: stat.cardinality, |
136 | | - }) |
| 198 | + pub fn boolean_stat(&self, true_count: StatEstimate) -> ReturnStat { |
| 199 | + let domain = if self.nullable { |
| 200 | + Domain::Nullable(NullableDomain { |
| 201 | + has_null: self.null_count != 0, |
| 202 | + value: Some(Box::new(Domain::Boolean(BooleanDomain { |
| 203 | + has_true: true, |
| 204 | + has_false: true, |
| 205 | + }))), |
| 206 | + }) |
| 207 | + } else { |
| 208 | + Domain::Boolean(BooleanDomain { |
| 209 | + has_true: true, |
| 210 | + has_false: true, |
| 211 | + }) |
| 212 | + }; |
| 213 | + |
| 214 | + ReturnStat { |
| 215 | + domain, |
| 216 | + ndv: Ndv::Stat(2.0), |
| 217 | + null_count: self.null_count, |
| 218 | + distribution: OwnedDistribution::Boolean(BooleanDistribution { true_count }), |
| 219 | + } |
137 | 220 | } |
138 | 221 |
|
139 | | - pub fn equality_true_count( |
| 222 | + pub fn constant_equality_true_count( |
140 | 223 | &self, |
| 224 | + minmax_cmp: Option<(Ordering, Ordering)>, |
141 | 225 | not_eq: bool, |
142 | | - compare: impl Fn(&Scalar, &Scalar) -> Option<Ordering>, |
143 | | - ) -> Option<StatEstimate> { |
144 | | - let Some((min, max)) = self.stat.value_minmax() else { |
145 | | - return Some(StatEstimate::exact(if not_eq { |
146 | | - self.cardinality |
147 | | - } else { |
148 | | - 0.0 |
149 | | - })); |
| 226 | + ) -> StatEstimate { |
| 227 | + let Some((cmp_min, cmp_max)) = minmax_cmp else { |
| 228 | + return estimate_ndv_true_count(self.stat.ndv, not_eq, self.non_null_cardinality); |
150 | 229 | }; |
151 | | - if compare(&self.constant, &min)? == Ordering::Less |
152 | | - || compare(&self.constant, &max)? == Ordering::Greater |
153 | | - { |
154 | | - return Some(StatEstimate::exact(if not_eq { |
155 | | - self.cardinality |
| 230 | + if cmp_min == Ordering::Less || cmp_max == Ordering::Greater { |
| 231 | + return StatEstimate::exact(if not_eq { |
| 232 | + self.non_null_cardinality |
156 | 233 | } else { |
157 | 234 | 0.0 |
158 | | - })); |
| 235 | + }); |
159 | 236 | } |
160 | 237 |
|
161 | | - Some(estimate_ndv_true_count( |
162 | | - self.stat.ndv, |
163 | | - not_eq, |
164 | | - self.cardinality, |
165 | | - )) |
| 238 | + estimate_ndv_true_count(self.stat.ndv, not_eq, self.non_null_cardinality) |
166 | 239 | } |
| 240 | +} |
167 | 241 |
|
168 | | - pub fn minmax_range_true_count<Op: StatComparisonOp>( |
169 | | - &self, |
170 | | - compare: impl Fn(&Scalar, &Scalar) -> Option<Ordering>, |
171 | | - ) -> Option<StatEstimate> { |
172 | | - try { |
173 | | - let (min, max) = self.stat.value_minmax()?; |
174 | | - let cmp_min = compare(&self.constant, &min)?; |
175 | | - let cmp_max = compare(&self.constant, &max)?; |
176 | | - Op::estimate_minmax_range_true_count(self.stat.ndv, self.cardinality, cmp_min, cmp_max)? |
177 | | - } |
| 242 | +pub fn null_comparison_stat(stat: &StatBinaryArg) -> Option<ReturnStat> { |
| 243 | + if stat.args.iter().any(|arg| { |
| 244 | + arg.domain |
| 245 | + .as_singleton() |
| 246 | + .is_some_and(|scalar| scalar.is_null()) |
| 247 | + }) { |
| 248 | + Some(ReturnStat { |
| 249 | + domain: Domain::Nullable(NullableDomain { |
| 250 | + has_null: true, |
| 251 | + value: None, |
| 252 | + }), |
| 253 | + ndv: Ndv::Stat(0.0), |
| 254 | + null_count: stat.cardinality.ceil() as u64, |
| 255 | + distribution: OwnedDistribution::Unknown, |
| 256 | + }) |
| 257 | + } else { |
| 258 | + None |
178 | 259 | } |
179 | 260 | } |
180 | 261 |
|
|
0 commit comments