1use std::collections::{BTreeMap, BTreeSet};
18
19use arrow::array::{make_array, ArrayData, ArrayRef};
20use common_error::ext::BoxedError;
21use datatypes::prelude::{ConcreteDataType, DataType};
22use datatypes::value::Value;
23use datatypes::vectors::{BooleanVector, Helper, VectorRef};
24use dfir_rs::lattices::cc_traits::Iter;
25use itertools::Itertools;
26use snafu::{ensure, OptionExt, ResultExt};
27
28use crate::error::{
29 DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
30};
31use crate::expr::error::{
32 ArrowSnafu, DataTypeSnafu, EvalError, InvalidArgumentSnafu, OptimizeSnafu, TypeMismatchSnafu,
33};
34use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
35use crate::expr::{Batch, DfScalarFunction};
36use crate::repr::ColumnType;
/// A scalar expression bundled with the column type attached to it.
///
/// The derived `Ord`/`PartialOrd` compare `expr` first and `typ` second, so the
/// field order below is significant.
#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TypedExpr {
    /// The expression itself.
    pub expr: ScalarExpr,
    /// The column type associated with `expr`.
    pub typ: ColumnType,
}
45
46impl TypedExpr {
47 pub fn new(expr: ScalarExpr, typ: ColumnType) -> Self {
48 Self { expr, typ }
49 }
50}
51
/// A scalar expression: it is evaluated against one row of values (`eval`) or against
/// a whole batch of column vectors (`eval_batch`).
///
/// The derived `Ord`/`PartialOrd`/`Hash` depend on the variant and field order below.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ScalarExpr {
    /// A reference to an input column, by index into the row / batch / type context.
    Column(usize),
    /// A constant value together with its data type.
    Literal(Value, ConcreteDataType),
    /// A call that `eval`/`eval_batch` cannot evaluate (they return an error for it),
    /// e.g. `UnmaterializableFunc::Now`.
    CallUnmaterializable(UnmaterializableFunc),
    /// A function applied to a single argument expression.
    CallUnary {
        func: UnaryFunc,
        expr: Box<ScalarExpr>,
    },
    /// A function applied to two argument expressions.
    CallBinary {
        func: BinaryFunc,
        expr1: Box<ScalarExpr>,
        expr2: Box<ScalarExpr>,
    },
    /// A function applied to any number of argument expressions.
    CallVariadic {
        func: VariadicFunc,
        exprs: Vec<ScalarExpr>,
    },
    /// A call to a `DfScalarFunction` (presumably a wrapped DataFusion scalar
    /// function — confirm against `crate::expr::DfScalarFunction`).
    CallDf {
        /// The wrapped function; its output type is derived from `fn_impl` and `df_schema`.
        df_scalar_fn: DfScalarFunction,
        /// Argument expressions passed to the function.
        exprs: Vec<ScalarExpr>,
    },
    /// Conditional expression. In batch evaluation a row takes `then` when `cond` is
    /// true, `els` when it is false, and null when `cond` is null.
    If {
        cond: Box<ScalarExpr>,
        then: Box<ScalarExpr>,
        els: Box<ScalarExpr>,
    },
}
96
97impl ScalarExpr {
98 pub fn with_type(self, typ: ColumnType) -> TypedExpr {
99 TypedExpr::new(self, typ)
100 }
101
102 pub fn typ(&self, context: &[ColumnType]) -> Result<ColumnType, Error> {
104 match self {
105 ScalarExpr::Column(i) => context.get(*i).cloned().ok_or_else(|| {
106 UnexpectedSnafu {
107 reason: format!("column index {} out of range of len={}", i, context.len()),
108 }
109 .build()
110 }),
111 ScalarExpr::Literal(_, typ) => Ok(ColumnType::new_nullable(typ.clone())),
112 ScalarExpr::CallUnmaterializable(func) => {
113 Ok(ColumnType::new_nullable(func.signature().output))
114 }
115 ScalarExpr::CallUnary { func, .. } => {
116 Ok(ColumnType::new_nullable(func.signature().output))
117 }
118 ScalarExpr::CallBinary { func, .. } => {
119 Ok(ColumnType::new_nullable(func.signature().output))
120 }
121 ScalarExpr::CallVariadic { func, .. } => {
122 Ok(ColumnType::new_nullable(func.signature().output))
123 }
124 ScalarExpr::If { then, .. } => then.typ(context),
125 ScalarExpr::CallDf { df_scalar_fn, .. } => {
126 let arrow_typ = df_scalar_fn
127 .fn_impl
128 .data_type(df_scalar_fn.df_schema.as_arrow())
130 .context({
131 DatafusionSnafu {
132 context: "Failed to get data type from datafusion scalar function",
133 }
134 })?;
135 let typ = ConcreteDataType::try_from(&arrow_typ)
136 .map_err(BoxedError::new)
137 .context(crate::error::ExternalSnafu)?;
138 Ok(ColumnType::new_nullable(typ))
139 }
140 }
141 }
142}
143
144impl ScalarExpr {
145 pub fn cast(self, typ: ConcreteDataType) -> Self {
146 ScalarExpr::CallUnary {
147 func: UnaryFunc::Cast(typ),
148 expr: Box::new(self),
149 }
150 }
151
152 pub fn optimize(&mut self) {
154 self.flatten_variadic_fn();
155 }
156
157 fn flatten_variadic_fn(&mut self) {
160 if let ScalarExpr::CallVariadic { func, exprs } = self {
161 let mut new_exprs = vec![];
162 for expr in std::mem::take(exprs) {
163 if let ScalarExpr::CallVariadic {
164 func: inner_func,
165 exprs: mut inner_exprs,
166 } = expr
167 {
168 if *func == inner_func {
169 for inner_expr in inner_exprs.iter_mut() {
170 inner_expr.flatten_variadic_fn();
171 }
172 new_exprs.extend(inner_exprs);
173 }
174 } else {
175 new_exprs.push(expr);
176 }
177 }
178 *exprs = new_exprs;
179 }
180 }
181}
182
183impl ScalarExpr {
184 pub fn call_unary(self, func: UnaryFunc) -> Self {
186 ScalarExpr::CallUnary {
187 func,
188 expr: Box::new(self),
189 }
190 }
191
192 pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
194 ScalarExpr::CallBinary {
195 func,
196 expr1: Box::new(self),
197 expr2: Box::new(other),
198 }
199 }
200
201 pub fn eval_batch(&self, batch: &Batch) -> Result<VectorRef, EvalError> {
202 match self {
203 ScalarExpr::Column(i) => Ok(batch.batch()[*i].clone()),
204 ScalarExpr::Literal(val, dt) => Ok(Helper::try_from_scalar_value(
205 val.try_to_scalar_value(dt).context(DataTypeSnafu {
206 msg: "Failed to convert literal to scalar value",
207 })?,
208 batch.row_count(),
209 )
210 .context(DataTypeSnafu {
211 msg: "Failed to convert scalar value to vector ref when parsing literal",
212 })?),
213 ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu {
214 reason: "Can't eval unmaterializable function",
215 }
216 .fail()?,
217 ScalarExpr::CallUnary { func, expr } => func.eval_batch(batch, expr),
218 ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval_batch(batch, expr1, expr2),
219 ScalarExpr::CallVariadic { func, exprs } => func.eval_batch(batch, exprs),
220 ScalarExpr::CallDf {
221 df_scalar_fn,
222 exprs,
223 } => df_scalar_fn.eval_batch(batch, exprs),
224 ScalarExpr::If { cond, then, els } => Self::eval_if_then(batch, cond, then, els),
225 }
226 }
227
228 fn eval_if_then(
231 batch: &Batch,
232 cond: &ScalarExpr,
233 then: &ScalarExpr,
234 els: &ScalarExpr,
235 ) -> Result<VectorRef, EvalError> {
236 let conds = cond.eval_batch(batch)?;
237 let bool_conds = conds
238 .as_any()
239 .downcast_ref::<BooleanVector>()
240 .context({
241 TypeMismatchSnafu {
242 expected: ConcreteDataType::boolean_datatype(),
243 actual: conds.data_type(),
244 }
245 })?
246 .as_boolean_array();
247
248 let indices = bool_conds
249 .into_iter()
250 .enumerate()
251 .map(|(idx, b)| {
252 (
253 match b {
254 Some(true) => 0, Some(false) => 1, None => 2, },
258 idx,
259 )
260 })
261 .collect_vec();
262
263 let then_input_vec = then.eval_batch(batch)?;
264 let else_input_vec = els.eval_batch(batch)?;
265
266 ensure!(
267 then_input_vec.data_type() == else_input_vec.data_type(),
268 TypeMismatchSnafu {
269 expected: then_input_vec.data_type(),
270 actual: else_input_vec.data_type(),
271 }
272 );
273
274 ensure!(
275 then_input_vec.len() == else_input_vec.len() && then_input_vec.len() == batch.row_count(),
276 InvalidArgumentSnafu {
277 reason: format!(
278 "then and else branch must have the same length(found {} and {}) which equals input batch's row count(which is {})",
279 then_input_vec.len(),
280 else_input_vec.len(),
281 batch.row_count()
282 )
283 }
284 );
285
286 fn new_nulls(dt: &arrow_schema::DataType, len: usize) -> ArrayRef {
287 let data = ArrayData::new_null(dt, len);
288 make_array(data)
289 }
290
291 let null_input_vec = new_nulls(
292 &then_input_vec.data_type().as_arrow_type(),
293 batch.row_count(),
294 );
295
296 let interleave_values = vec![
297 then_input_vec.to_arrow_array(),
298 else_input_vec.to_arrow_array(),
299 null_input_vec,
300 ];
301 let int_ref: Vec<_> = interleave_values.iter().map(|x| x.as_ref()).collect();
302
303 let interleave_res_arr =
304 arrow::compute::interleave(&int_ref, &indices).context(ArrowSnafu {
305 context: "Failed to interleave output arrays",
306 })?;
307 let res_vec = Helper::try_into_vector(interleave_res_arr).context(DataTypeSnafu {
308 msg: "Failed to convert arrow array to vector",
309 })?;
310 Ok(res_vec)
311 }
312
313 pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
318 match self {
319 ScalarExpr::Column(index) => Ok(values[*index].clone()),
320 ScalarExpr::Literal(row_res, _ty) => Ok(row_res.clone()),
321 ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu {
322 reason: "Can't eval unmaterializable function".to_string(),
323 }
324 .fail(),
325 ScalarExpr::CallUnary { func, expr } => func.eval(values, expr),
326 ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval(values, expr1, expr2),
327 ScalarExpr::CallVariadic { func, exprs } => func.eval(values, exprs),
328 ScalarExpr::If { cond, then, els } => match cond.eval(values) {
329 Ok(Value::Boolean(true)) => then.eval(values),
330 Ok(Value::Boolean(false)) => els.eval(values),
331 _ => InvalidArgumentSnafu {
332 reason: "if condition must be boolean".to_string(),
333 }
334 .fail(),
335 },
336 ScalarExpr::CallDf {
337 df_scalar_fn,
338 exprs,
339 } => df_scalar_fn.eval(values, exprs),
340 }
341 }
342
343 pub fn permute(&mut self, permutation: &[usize]) -> Result<(), Error> {
349 ensure!(
351 self.get_all_ref_columns()
352 .into_iter()
353 .all(|i| i < permutation.len()),
354 InvalidQuerySnafu {
355 reason: format!(
356 "permutation {:?} is not a valid permutation for expression {:?}",
357 permutation, self
358 ),
359 }
360 );
361
362 self.visit_mut_post_nolimit(&mut |e| {
363 if let ScalarExpr::Column(old_i) = e {
364 *old_i = permutation[*old_i];
365 }
366 Ok(())
367 })?;
368 Ok(())
369 }
370
371 pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) -> Result<(), Error> {
377 ensure!(
379 self.get_all_ref_columns()
380 .is_subset(&permutation.keys().cloned().collect()),
381 InvalidQuerySnafu {
382 reason: format!(
383 "permutation {:?} is not a valid permutation for expression {:?}",
384 permutation, self
385 ),
386 }
387 );
388
389 self.visit_mut_post_nolimit(&mut |e| {
390 if let ScalarExpr::Column(old_i) = e {
391 *old_i = permutation[old_i];
392 }
393 Ok(())
394 })
395 }
396
397 pub fn get_all_ref_columns(&self) -> BTreeSet<usize> {
399 let mut support = BTreeSet::new();
400 self.visit_post_nolimit(&mut |e| {
401 if let ScalarExpr::Column(i) = e {
402 support.insert(*i);
403 }
404 Ok(())
405 })
406 .unwrap();
407 support
408 }
409
410 pub fn is_column(&self) -> bool {
412 matches!(self, ScalarExpr::Column(_))
413 }
414
415 pub fn as_column(&self) -> Option<usize> {
417 if let ScalarExpr::Column(i) = self {
418 Some(*i)
419 } else {
420 None
421 }
422 }
423
424 pub fn as_literal(&self) -> Option<Value> {
426 if let ScalarExpr::Literal(lit, _column_type) = self {
427 Some(lit.clone())
428 } else {
429 None
430 }
431 }
432
433 pub fn is_literal(&self) -> bool {
435 matches!(self, ScalarExpr::Literal(..))
436 }
437
438 pub fn is_literal_true(&self) -> bool {
440 Some(Value::Boolean(true)) == self.as_literal()
441 }
442
443 pub fn is_literal_false(&self) -> bool {
445 Some(Value::Boolean(false)) == self.as_literal()
446 }
447
448 pub fn is_literal_null(&self) -> bool {
450 Some(Value::Null) == self.as_literal()
451 }
452
453 pub fn literal_null() -> Self {
455 ScalarExpr::Literal(Value::Null, ConcreteDataType::null_datatype())
456 }
457
458 pub fn literal(res: Value, typ: ConcreteDataType) -> Self {
460 ScalarExpr::Literal(res, typ)
461 }
462
463 pub fn literal_false() -> Self {
465 ScalarExpr::Literal(Value::Boolean(false), ConcreteDataType::boolean_datatype())
466 }
467
468 pub fn literal_true() -> Self {
470 ScalarExpr::Literal(Value::Boolean(true), ConcreteDataType::boolean_datatype())
471 }
472}
473
474impl ScalarExpr {
475 fn visit_post_nolimit<F>(&self, f: &mut F) -> Result<(), EvalError>
477 where
478 F: FnMut(&Self) -> Result<(), EvalError>,
479 {
480 self.visit_children(|e| e.visit_post_nolimit(f))?;
481 f(self)
482 }
483
484 fn visit_children<F>(&self, mut f: F) -> Result<(), EvalError>
485 where
486 F: FnMut(&Self) -> Result<(), EvalError>,
487 {
488 match self {
489 ScalarExpr::Column(_)
490 | ScalarExpr::Literal(_, _)
491 | ScalarExpr::CallUnmaterializable(_) => Ok(()),
492 ScalarExpr::CallUnary { expr, .. } => f(expr),
493 ScalarExpr::CallBinary { expr1, expr2, .. } => {
494 f(expr1)?;
495 f(expr2)
496 }
497 ScalarExpr::CallVariadic { exprs, .. } => {
498 for expr in exprs {
499 f(expr)?;
500 }
501 Ok(())
502 }
503 ScalarExpr::If { cond, then, els } => {
504 f(cond)?;
505 f(then)?;
506 f(els)
507 }
508 ScalarExpr::CallDf {
509 df_scalar_fn: _,
510 exprs,
511 } => {
512 for expr in exprs {
513 f(expr)?;
514 }
515 Ok(())
516 }
517 }
518 }
519
520 fn visit_mut_post_nolimit<F>(&mut self, f: &mut F) -> Result<(), Error>
521 where
522 F: FnMut(&mut Self) -> Result<(), Error>,
523 {
524 self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f))?;
525 f(self)
526 }
527
528 fn visit_mut_children<F>(&mut self, mut f: F) -> Result<(), Error>
529 where
530 F: FnMut(&mut Self) -> Result<(), Error>,
531 {
532 match self {
533 ScalarExpr::Column(_)
534 | ScalarExpr::Literal(_, _)
535 | ScalarExpr::CallUnmaterializable(_) => Ok(()),
536 ScalarExpr::CallUnary { expr, .. } => f(expr),
537 ScalarExpr::CallBinary { expr1, expr2, .. } => {
538 f(expr1)?;
539 f(expr2)
540 }
541 ScalarExpr::CallVariadic { exprs, .. } => {
542 for expr in exprs {
543 f(expr)?;
544 }
545 Ok(())
546 }
547 ScalarExpr::If { cond, then, els } => {
548 f(cond)?;
549 f(then)?;
550 f(els)
551 }
552 ScalarExpr::CallDf {
553 df_scalar_fn: _,
554 exprs,
555 } => {
556 for expr in exprs {
557 f(expr)?;
558 }
559 Ok(())
560 }
561 }
562 }
563}
564
565impl ScalarExpr {
566 pub fn contains_temporal(&self) -> bool {
568 let mut contains = false;
569 self.visit_post_nolimit(&mut |e| {
570 if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now) = e {
571 contains = true;
572 }
573 Ok(())
574 })
575 .unwrap();
576 contains
577 }
578
579 pub fn extract_bound(&self) -> Result<(Option<Self>, Option<Self>), Error> {
586 let unsupported_err = |msg: &str| {
587 UnsupportedTemporalFilterSnafu {
588 reason: msg.to_string(),
589 }
590 .fail()
591 };
592
593 let Self::CallBinary {
594 mut func,
595 mut expr1,
596 mut expr2,
597 } = self.clone()
598 else {
599 return unsupported_err("Not a binary expression");
600 };
601
602 let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
605 let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
606
607 if !(expr1_is_now ^ expr2_is_now) {
608 return unsupported_err("None of the sides of the comparison is `now()`");
609 }
610
611 if expr2_is_now {
612 std::mem::swap(&mut expr1, &mut expr2);
613 func = BinaryFunc::reverse_compare(&func)?;
614 }
615
616 let step = |expr: ScalarExpr| expr.call_unary(UnaryFunc::StepTimestamp);
617 match func {
618 BinaryFunc::Eq => Ok((Some(*expr2.clone()), Some(step(*expr2)))),
620 BinaryFunc::Lt => Ok((None, Some(*expr2))),
622 BinaryFunc::Lte => Ok((None, Some(step(*expr2)))),
624 BinaryFunc::Gt => Ok((Some(step(*expr2)), None)),
626 BinaryFunc::Gte => Ok((Some(*expr2), None)),
628 _ => unreachable!("Already checked"),
629 }
630 }
631}
632
#[cfg(test)]
mod test {
    use datatypes::vectors::{Int32Vector, Vector};
    use pretty_assertions::assert_eq;

    use super::*;

    /// Builds the comparison `now() <func> column(0)`.
    fn now_vs_col0(func: BinaryFunc) -> ScalarExpr {
        ScalarExpr::CallBinary {
            func,
            expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
            expr2: Box::new(ScalarExpr::Column(0)),
        }
    }

    /// Builds `StepTimestamp(column(0))`.
    fn stepped_col0() -> ScalarExpr {
        ScalarExpr::CallUnary {
            func: UnaryFunc::StepTimestamp,
            expr: Box::new(ScalarExpr::Column(0)),
        }
    }

    /// Each supported comparison against `now()` yields the expected (lower, upper) pair.
    #[test]
    fn test_extract_bound() {
        let col0 = || ScalarExpr::Column(0);
        let test_list: [(ScalarExpr, Result<_, EvalError>); 5] = [
            (
                now_vs_col0(BinaryFunc::Eq),
                Ok((Some(col0()), Some(stepped_col0()))),
            ),
            (now_vs_col0(BinaryFunc::Lt), Ok((None, Some(col0())))),
            (now_vs_col0(BinaryFunc::Lte), Ok((None, Some(stepped_col0())))),
            (now_vs_col0(BinaryFunc::Gt), Ok((Some(stepped_col0()), None))),
            (now_vs_col0(BinaryFunc::Gte), Ok((Some(col0()), None))),
        ];
        for (expr, expected) in test_list.into_iter() {
            match (expr.extract_bound(), expected) {
                (Ok(l), Ok(r)) => assert_eq!(l, r),
                (l, r) => panic!("expected: {:?}, actual: {:?}", r, l),
            }
        }
    }

    /// Permutations that do not cover every referenced column are rejected.
    #[test]
    fn test_bad_permute() {
        // Column 4 is outside a permutation of length 3.
        let permutation = [1, 2, 3];
        let mut out_of_range = ScalarExpr::Column(4);
        assert!(matches!(
            out_of_range.permute(&permutation),
            Err(Error::InvalidQuery { .. })
        ));

        // Column 0 is not a key of the map.
        let permute_map = BTreeMap::from([(1, 2), (3, 4)]);
        let mut unmapped = ScalarExpr::Column(0);
        assert!(matches!(
            unmapped.permute_map(&permute_map),
            Err(Error::InvalidQuery { .. })
        ));
    }

    /// `if col0 == 0 then 42 else 37` over batches of several sizes, including nulls.
    #[test]
    fn test_eval_batch_if_then() {
        let int32_literal =
            |v: i32| ScalarExpr::literal(Value::from(v), ConcreteDataType::int32_datatype());
        let expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(0).call_binary(int32_literal(0), BinaryFunc::Eq)),
            then: Box::new(int32_literal(42)),
            els: Box::new(int32_literal(37)),
        };

        // (input column, expected output)
        let cases: Vec<(Vec<Option<i32>>, Vec<Option<i32>>)> = vec![
            (
                vec![
                    None,
                    Some(0),
                    Some(1),
                    None,
                    None,
                    Some(0),
                    Some(0),
                    Some(1),
                    Some(1),
                ],
                vec![
                    None,
                    Some(42),
                    Some(37),
                    None,
                    None,
                    Some(42),
                    Some(42),
                    Some(37),
                    Some(37),
                ],
            ),
            (vec![Some(0)], vec![Some(42)]),
            (vec![], vec![]),
        ];

        for (raw, want) in cases.into_iter() {
            let raw_len = raw.len();
            let vectors = vec![Int32Vector::from(raw).slice(0, raw_len)];
            let batch = Batch::try_new(vectors, raw_len).unwrap();
            let expected = Int32Vector::from(want).slice(0, raw_len);
            assert_eq!(expr.eval_batch(&batch).unwrap(), expected);
        }
    }
}