flow/expr/
scalar.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Scalar expressions.
16
17use std::collections::{BTreeMap, BTreeSet};
18use std::sync::Arc;
19
20use arrow::array::{ArrayData, ArrayRef, BooleanArray, make_array};
21use arrow::buffer::BooleanBuffer;
22use arrow::compute::or_kleene;
23use common_error::ext::BoxedError;
24use datafusion::physical_expr_common::datum::compare_with_eq;
25use datafusion_common::DataFusionError;
26use datatypes::prelude::{ConcreteDataType, DataType};
27use datatypes::value::Value;
28use datatypes::vectors::{BooleanVector, Helper, VectorRef};
29use dfir_rs::lattices::cc_traits::Iter;
30use itertools::Itertools;
31use snafu::{OptionExt, ResultExt, ensure};
32
33use crate::error::{
34    DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
35};
36use crate::expr::error::{
37    ArrowSnafu, DataTypeSnafu, EvalError, InvalidArgumentSnafu, OptimizeSnafu, TypeMismatchSnafu,
38};
39use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
40use crate::expr::{Batch, DfScalarFunction};
41use crate::repr::ColumnType;
42/// A scalar expression with a known type.
43#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)]
44pub struct TypedExpr {
45    /// The expression.
46    pub expr: ScalarExpr,
47    /// The type of the expression.
48    pub typ: ColumnType,
49}
50
51impl TypedExpr {
52    pub fn new(expr: ScalarExpr, typ: ColumnType) -> Self {
53        Self { expr, typ }
54    }
55}
56
57/// A scalar expression, which can be evaluated to a value.
58#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
59pub enum ScalarExpr {
60    /// A column of the input row
61    Column(usize),
62    /// A literal value.
63    /// Extra type info to know original type even when it is null
64    Literal(Value, ConcreteDataType),
65    /// A call to an unmaterializable function.
66    ///
67    /// These functions cannot be evaluated by `ScalarExpr::eval`. They must
68    /// be transformed away by a higher layer.
69    CallUnmaterializable(UnmaterializableFunc),
70    CallUnary {
71        func: UnaryFunc,
72        expr: Box<ScalarExpr>,
73    },
74    CallBinary {
75        func: BinaryFunc,
76        expr1: Box<ScalarExpr>,
77        expr2: Box<ScalarExpr>,
78    },
79    CallVariadic {
80        func: VariadicFunc,
81        exprs: Vec<ScalarExpr>,
82    },
83    CallDf {
84        /// invariant: the input args set inside this [`DfScalarFunction`] is
85        /// always col(0) to col(n-1) where n is the length of `expr`
86        df_scalar_fn: DfScalarFunction,
87        exprs: Vec<ScalarExpr>,
88    },
89    /// Conditionally evaluated expressions.
90    ///
91    /// It is important that `then` and `els` only be evaluated if
92    /// `cond` is true or not, respectively. This is the only way
93    /// users can guard execution (other logical operator do not
94    /// short-circuit) and we need to preserve that.
95    If {
96        cond: Box<ScalarExpr>,
97        then: Box<ScalarExpr>,
98        els: Box<ScalarExpr>,
99    },
100    InList {
101        expr: Box<ScalarExpr>,
102        list: Vec<ScalarExpr>,
103    },
104}
105
106impl ScalarExpr {
107    pub fn with_type(self, typ: ColumnType) -> TypedExpr {
108        TypedExpr::new(self, typ)
109    }
110
111    /// try to determine the type of the expression
112    pub fn typ(&self, context: &[ColumnType]) -> Result<ColumnType, Error> {
113        match self {
114            ScalarExpr::Column(i) => context.get(*i).cloned().ok_or_else(|| {
115                UnexpectedSnafu {
116                    reason: format!("column index {} out of range of len={}", i, context.len()),
117                }
118                .build()
119            }),
120            ScalarExpr::Literal(_, typ) => Ok(ColumnType::new_nullable(typ.clone())),
121            ScalarExpr::CallUnmaterializable(func) => {
122                Ok(ColumnType::new_nullable(func.signature().output))
123            }
124            ScalarExpr::CallUnary { func, .. } => {
125                Ok(ColumnType::new_nullable(func.signature().output))
126            }
127            ScalarExpr::CallBinary { func, .. } => {
128                Ok(ColumnType::new_nullable(func.signature().output))
129            }
130            ScalarExpr::CallVariadic { func, .. } => {
131                Ok(ColumnType::new_nullable(func.signature().output))
132            }
133            ScalarExpr::If { then, .. } => then.typ(context),
134            ScalarExpr::CallDf { df_scalar_fn, .. } => {
135                let arrow_typ = df_scalar_fn
136                    .fn_impl
137                    // TODO(discord9): get scheme from args instead?
138                    .data_type(df_scalar_fn.df_schema.as_arrow())
139                    .context({
140                        DatafusionSnafu {
141                            context: "Failed to get data type from datafusion scalar function",
142                        }
143                    })?;
144                let typ = ConcreteDataType::try_from(&arrow_typ)
145                    .map_err(BoxedError::new)
146                    .context(crate::error::ExternalSnafu)?;
147                Ok(ColumnType::new_nullable(typ))
148            }
149            ScalarExpr::InList { expr, .. } => expr.typ(context),
150        }
151    }
152}
153
154impl ScalarExpr {
155    pub fn cast(self, typ: ConcreteDataType) -> Self {
156        ScalarExpr::CallUnary {
157            func: UnaryFunc::Cast(typ),
158            expr: Box::new(self),
159        }
160    }
161
162    /// apply optimization to the expression, like flatten variadic function
163    pub fn optimize(&mut self) {
164        self.flatten_variadic_fn();
165    }
166
167    /// Because Substrait's `And`/`Or` function is binary, but FlowPlan's
168    /// `And`/`Or` function is variadic, we need to flatten the `And` function if multiple `And`/`Or` functions are nested.
169    fn flatten_variadic_fn(&mut self) {
170        if let ScalarExpr::CallVariadic { func, exprs } = self {
171            let mut new_exprs = vec![];
172            for expr in std::mem::take(exprs) {
173                if let ScalarExpr::CallVariadic {
174                    func: inner_func,
175                    exprs: mut inner_exprs,
176                } = expr
177                {
178                    if *func == inner_func {
179                        for inner_expr in inner_exprs.iter_mut() {
180                            inner_expr.flatten_variadic_fn();
181                        }
182                        new_exprs.extend(inner_exprs);
183                    }
184                } else {
185                    new_exprs.push(expr);
186                }
187            }
188            *exprs = new_exprs;
189        }
190    }
191}
192
193impl ScalarExpr {
194    /// Call a unary function on this expression.
195    pub fn call_unary(self, func: UnaryFunc) -> Self {
196        ScalarExpr::CallUnary {
197            func,
198            expr: Box::new(self),
199        }
200    }
201
202    /// Call a binary function on this expression and another.
203    pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
204        ScalarExpr::CallBinary {
205            func,
206            expr1: Box::new(self),
207            expr2: Box::new(other),
208        }
209    }
210
211    pub fn eval_batch(&self, batch: &Batch) -> Result<VectorRef, EvalError> {
212        match self {
213            ScalarExpr::Column(i) => Ok(batch.batch()[*i].clone()),
214            ScalarExpr::Literal(val, dt) => Ok(Helper::try_from_scalar_value(
215                val.try_to_scalar_value(dt).context(DataTypeSnafu {
216                    msg: "Failed to convert literal to scalar value",
217                })?,
218                batch.row_count(),
219            )
220            .context(DataTypeSnafu {
221                msg: "Failed to convert scalar value to vector ref when parsing literal",
222            })?),
223            ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu {
224                reason: "Can't eval unmaterializable function",
225            }
226            .fail()?,
227            ScalarExpr::CallUnary { func, expr } => func.eval_batch(batch, expr),
228            ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval_batch(batch, expr1, expr2),
229            ScalarExpr::CallVariadic { func, exprs } => func.eval_batch(batch, exprs),
230            ScalarExpr::CallDf {
231                df_scalar_fn,
232                exprs,
233            } => df_scalar_fn.eval_batch(batch, exprs),
234            ScalarExpr::If { cond, then, els } => Self::eval_if_then(batch, cond, then, els),
235            ScalarExpr::InList { expr, list } => Self::eval_in_list(batch, expr, list),
236        }
237    }
238
239    fn eval_in_list(
240        batch: &Batch,
241        expr: &ScalarExpr,
242        list: &[ScalarExpr],
243    ) -> Result<VectorRef, EvalError> {
244        let eval_list = list
245            .iter()
246            .map(|e| e.eval_batch(batch))
247            .collect::<Result<Vec<_>, _>>()?;
248        let eval_expr = expr.eval_batch(batch)?;
249
250        ensure!(
251            eval_list
252                .iter()
253                .all(|v| v.data_type() == eval_expr.data_type()),
254            TypeMismatchSnafu {
255                expected: eval_expr.data_type(),
256                actual: eval_list
257                    .iter()
258                    .find(|v| v.data_type() != eval_expr.data_type())
259                    .map(|v| v.data_type())
260                    .unwrap(),
261            }
262        );
263
264        let lhs = eval_expr.to_arrow_array();
265
266        let found = eval_list
267            .iter()
268            .map(|v| v.to_arrow_array())
269            .try_fold(
270                BooleanArray::new(BooleanBuffer::new_unset(batch.row_count()), None),
271                |result, in_list_elem| -> Result<BooleanArray, DataFusionError> {
272                    let rhs = compare_with_eq(&lhs, &in_list_elem, false)?;
273
274                    Ok(or_kleene(&result, &rhs)?)
275                },
276            )
277            .with_context(|_| crate::expr::error::DatafusionSnafu {
278                context: "Failed to compare eval_expr with eval_list",
279            })?;
280
281        let res = BooleanVector::from(found);
282
283        Ok(Arc::new(res))
284    }
285
286    /// NOTE: this if then eval impl assume all given expr are pure, and will not change the state of the world
287    /// since it will evaluate both then and else branch and filter the result
288    fn eval_if_then(
289        batch: &Batch,
290        cond: &ScalarExpr,
291        then: &ScalarExpr,
292        els: &ScalarExpr,
293    ) -> Result<VectorRef, EvalError> {
294        let conds = cond.eval_batch(batch)?;
295        let bool_conds = conds
296            .as_any()
297            .downcast_ref::<BooleanVector>()
298            .context({
299                TypeMismatchSnafu {
300                    expected: ConcreteDataType::boolean_datatype(),
301                    actual: conds.data_type(),
302                }
303            })?
304            .as_boolean_array();
305
306        let indices = bool_conds
307            .into_iter()
308            .enumerate()
309            .map(|(idx, b)| {
310                (
311                    match b {
312                        Some(true) => 0,  // then branch vector
313                        Some(false) => 1, // else branch vector
314                        None => 2,        // null vector
315                    },
316                    idx,
317                )
318            })
319            .collect_vec();
320
321        let then_input_vec = then.eval_batch(batch)?;
322        let else_input_vec = els.eval_batch(batch)?;
323
324        ensure!(
325            then_input_vec.data_type() == else_input_vec.data_type(),
326            TypeMismatchSnafu {
327                expected: then_input_vec.data_type(),
328                actual: else_input_vec.data_type(),
329            }
330        );
331
332        ensure!(
333            then_input_vec.len() == else_input_vec.len()
334                && then_input_vec.len() == batch.row_count(),
335            InvalidArgumentSnafu {
336                reason: format!(
337                    "then and else branch must have the same length(found {} and {}) which equals input batch's row count(which is {})",
338                    then_input_vec.len(),
339                    else_input_vec.len(),
340                    batch.row_count()
341                )
342            }
343        );
344
345        fn new_nulls(dt: &arrow_schema::DataType, len: usize) -> ArrayRef {
346            let data = ArrayData::new_null(dt, len);
347            make_array(data)
348        }
349
350        let null_input_vec = new_nulls(
351            &then_input_vec.data_type().as_arrow_type(),
352            batch.row_count(),
353        );
354
355        let interleave_values = vec![
356            then_input_vec.to_arrow_array(),
357            else_input_vec.to_arrow_array(),
358            null_input_vec,
359        ];
360        let int_ref: Vec<_> = interleave_values.iter().map(|x| x.as_ref()).collect();
361
362        let interleave_res_arr =
363            arrow::compute::interleave(&int_ref, &indices).context(ArrowSnafu {
364                context: "Failed to interleave output arrays",
365            })?;
366        let res_vec = Helper::try_into_vector(interleave_res_arr).context(DataTypeSnafu {
367            msg: "Failed to convert arrow array to vector",
368        })?;
369        Ok(res_vec)
370    }
371
372    /// Eval this expression with the given values.
373    ///
374    /// TODO(discord9): add tests to make sure `eval_batch` is the same as `eval` in
375    /// most cases
376    pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
377        match self {
378            ScalarExpr::Column(index) => Ok(values[*index].clone()),
379            ScalarExpr::Literal(row_res, _ty) => Ok(row_res.clone()),
380            ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu {
381                reason: "Can't eval unmaterializable function".to_string(),
382            }
383            .fail(),
384            ScalarExpr::CallUnary { func, expr } => func.eval(values, expr),
385            ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval(values, expr1, expr2),
386            ScalarExpr::CallVariadic { func, exprs } => func.eval(values, exprs),
387            ScalarExpr::If { cond, then, els } => match cond.eval(values) {
388                Ok(Value::Boolean(true)) => then.eval(values),
389                Ok(Value::Boolean(false)) => els.eval(values),
390                _ => InvalidArgumentSnafu {
391                    reason: "if condition must be boolean".to_string(),
392                }
393                .fail(),
394            },
395            ScalarExpr::CallDf {
396                df_scalar_fn,
397                exprs,
398            } => df_scalar_fn.eval(values, exprs),
399            ScalarExpr::InList { expr, list } => {
400                let eval_expr = expr.eval(values)?;
401                let eval_list = list
402                    .iter()
403                    .map(|v| v.eval(values))
404                    .collect::<Result<Vec<_>, _>>()?;
405                let found = eval_list.iter().any(|item| *item == eval_expr);
406                Ok(Value::Boolean(found))
407            }
408        }
409    }
410
411    /// Rewrites column indices with their value in `permutation`.
412    ///
413    /// This method is applicable even when `permutation` is not a
414    /// strict permutation, and it only needs to have entries for
415    /// each column referenced in `self`.
416    pub fn permute(&mut self, permutation: &[usize]) -> Result<(), Error> {
417        // check first so that we don't end up with a partially permuted expression
418        ensure!(
419            self.get_all_ref_columns()
420                .into_iter()
421                .all(|i| i < permutation.len()),
422            InvalidQuerySnafu {
423                reason: format!(
424                    "permutation {:?} is not a valid permutation for expression {:?}",
425                    permutation, self
426                ),
427            }
428        );
429
430        self.visit_mut_post_nolimit(&mut |e| {
431            if let ScalarExpr::Column(old_i) = e {
432                *old_i = permutation[*old_i];
433            }
434            Ok(())
435        })?;
436        Ok(())
437    }
438
439    /// Rewrites column indices with their value in `permutation`.
440    ///
441    /// This method is applicable even when `permutation` is not a
442    /// strict permutation, and it only needs to have entries for
443    /// each column referenced in `self`.
444    pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) -> Result<(), Error> {
445        // check first so that we don't end up with a partially permuted expression
446        ensure!(
447            self.get_all_ref_columns()
448                .is_subset(&permutation.keys().cloned().collect()),
449            InvalidQuerySnafu {
450                reason: format!(
451                    "permutation {:?} is not a valid permutation for expression {:?}",
452                    permutation, self
453                ),
454            }
455        );
456
457        self.visit_mut_post_nolimit(&mut |e| {
458            if let ScalarExpr::Column(old_i) = e {
459                *old_i = permutation[old_i];
460            }
461            Ok(())
462        })
463    }
464
465    /// Returns the set of columns that are referenced by `self`.
466    pub fn get_all_ref_columns(&self) -> BTreeSet<usize> {
467        let mut support = BTreeSet::new();
468        self.visit_post_nolimit(&mut |e| {
469            if let ScalarExpr::Column(i) = e {
470                support.insert(*i);
471            }
472            Ok(())
473        })
474        .unwrap();
475        support
476    }
477
478    /// Return true if the expression is a column reference.
479    pub fn is_column(&self) -> bool {
480        matches!(self, ScalarExpr::Column(_))
481    }
482
483    /// Cast the expression to a column reference if it is one.
484    pub fn as_column(&self) -> Option<usize> {
485        if let ScalarExpr::Column(i) = self {
486            Some(*i)
487        } else {
488            None
489        }
490    }
491
492    /// Cast the expression to a literal if it is one.
493    pub fn as_literal(&self) -> Option<Value> {
494        if let ScalarExpr::Literal(lit, _column_type) = self {
495            Some(lit.clone())
496        } else {
497            None
498        }
499    }
500
501    /// Return true if the expression is a literal.
502    pub fn is_literal(&self) -> bool {
503        matches!(self, ScalarExpr::Literal(..))
504    }
505
506    /// Return true if the expression is a literal true.
507    pub fn is_literal_true(&self) -> bool {
508        Some(Value::Boolean(true)) == self.as_literal()
509    }
510
511    /// Return true if the expression is a literal false.
512    pub fn is_literal_false(&self) -> bool {
513        Some(Value::Boolean(false)) == self.as_literal()
514    }
515
516    /// Return true if the expression is a literal null.
517    pub fn is_literal_null(&self) -> bool {
518        Some(Value::Null) == self.as_literal()
519    }
520
521    /// Build a literal null
522    pub fn literal_null() -> Self {
523        ScalarExpr::Literal(Value::Null, ConcreteDataType::null_datatype())
524    }
525
526    /// Build a literal from value and type
527    pub fn literal(res: Value, typ: ConcreteDataType) -> Self {
528        ScalarExpr::Literal(res, typ)
529    }
530
531    /// Build a literal false
532    pub fn literal_false() -> Self {
533        ScalarExpr::Literal(Value::Boolean(false), ConcreteDataType::boolean_datatype())
534    }
535
536    /// Build a literal true
537    pub fn literal_true() -> Self {
538        ScalarExpr::Literal(Value::Boolean(true), ConcreteDataType::boolean_datatype())
539    }
540}
541
542impl ScalarExpr {
543    /// visit post-order without stack call limit, but may cause stack overflow
544    fn visit_post_nolimit<F>(&self, f: &mut F) -> Result<(), EvalError>
545    where
546        F: FnMut(&Self) -> Result<(), EvalError>,
547    {
548        self.visit_children(|e| e.visit_post_nolimit(f))?;
549        f(self)
550    }
551
552    fn visit_children<F>(&self, mut f: F) -> Result<(), EvalError>
553    where
554        F: FnMut(&Self) -> Result<(), EvalError>,
555    {
556        match self {
557            ScalarExpr::Column(_)
558            | ScalarExpr::Literal(_, _)
559            | ScalarExpr::CallUnmaterializable(_) => Ok(()),
560            ScalarExpr::CallUnary { expr, .. } => f(expr),
561            ScalarExpr::CallBinary { expr1, expr2, .. } => {
562                f(expr1)?;
563                f(expr2)
564            }
565            ScalarExpr::CallVariadic { exprs, .. } => {
566                for expr in exprs {
567                    f(expr)?;
568                }
569                Ok(())
570            }
571            ScalarExpr::If { cond, then, els } => {
572                f(cond)?;
573                f(then)?;
574                f(els)
575            }
576            ScalarExpr::CallDf {
577                df_scalar_fn: _,
578                exprs,
579            } => {
580                for expr in exprs {
581                    f(expr)?;
582                }
583                Ok(())
584            }
585            ScalarExpr::InList { expr, list } => {
586                f(expr)?;
587                for item in list {
588                    f(item)?;
589                }
590                Ok(())
591            }
592        }
593    }
594
595    fn visit_mut_post_nolimit<F>(&mut self, f: &mut F) -> Result<(), Error>
596    where
597        F: FnMut(&mut Self) -> Result<(), Error>,
598    {
599        self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f))?;
600        f(self)
601    }
602
603    fn visit_mut_children<F>(&mut self, mut f: F) -> Result<(), Error>
604    where
605        F: FnMut(&mut Self) -> Result<(), Error>,
606    {
607        match self {
608            ScalarExpr::Column(_)
609            | ScalarExpr::Literal(_, _)
610            | ScalarExpr::CallUnmaterializable(_) => Ok(()),
611            ScalarExpr::CallUnary { expr, .. } => f(expr),
612            ScalarExpr::CallBinary { expr1, expr2, .. } => {
613                f(expr1)?;
614                f(expr2)
615            }
616            ScalarExpr::CallVariadic { exprs, .. } => {
617                for expr in exprs {
618                    f(expr)?;
619                }
620                Ok(())
621            }
622            ScalarExpr::If { cond, then, els } => {
623                f(cond)?;
624                f(then)?;
625                f(els)
626            }
627            ScalarExpr::CallDf {
628                df_scalar_fn: _,
629                exprs,
630            } => {
631                for expr in exprs {
632                    f(expr)?;
633                }
634                Ok(())
635            }
636            ScalarExpr::InList { expr, list } => {
637                f(expr)?;
638                for item in list {
639                    f(item)?;
640                }
641                Ok(())
642            }
643        }
644    }
645}
646
647impl ScalarExpr {
648    /// if expr contains function `Now`
649    pub fn contains_temporal(&self) -> bool {
650        let mut contains = false;
651        self.visit_post_nolimit(&mut |e| {
652            if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now) = e {
653                contains = true;
654            }
655            Ok(())
656        })
657        .unwrap();
658        contains
659    }
660
661    /// extract lower or upper bound of `Now` for expr, where `lower bound <= expr < upper bound`
662    ///
663    /// returned bool indicates whether the bound is upper bound:
664    ///
665    /// false for lower bound, true for upper bound
666    /// TODO(discord9): allow simple transform like `now() + a < b` to `now() < b - a`
667    pub fn extract_bound(&self) -> Result<(Option<Self>, Option<Self>), Error> {
668        let unsupported_err = |msg: &str| {
669            UnsupportedTemporalFilterSnafu {
670                reason: msg.to_string(),
671            }
672            .fail()
673        };
674
675        let Self::CallBinary {
676            mut func,
677            mut expr1,
678            mut expr2,
679        } = self.clone()
680        else {
681            return unsupported_err("Not a binary expression");
682        };
683
684        // TODO(discord9): support simple transform like `now() + a < b` to `now() < b - a`
685
686        let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
687        let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
688
689        if !(expr1_is_now ^ expr2_is_now) {
690            return unsupported_err("None of the sides of the comparison is `now()`");
691        }
692
693        if expr2_is_now {
694            std::mem::swap(&mut expr1, &mut expr2);
695            func = BinaryFunc::reverse_compare(&func)?;
696        }
697
698        let step = |expr: ScalarExpr| expr.call_unary(UnaryFunc::StepTimestamp);
699        match func {
700            // now == expr2 -> now <= expr2 && now < expr2 + 1
701            BinaryFunc::Eq => Ok((Some(*expr2.clone()), Some(step(*expr2)))),
702            // now < expr2 -> now < expr2
703            BinaryFunc::Lt => Ok((None, Some(*expr2))),
704            // now <= expr2 -> now < expr2 + 1
705            BinaryFunc::Lte => Ok((None, Some(step(*expr2)))),
706            // now > expr2 -> now >= expr2 + 1
707            BinaryFunc::Gt => Ok((Some(step(*expr2)), None)),
708            // now >= expr2 -> now >= expr2
709            BinaryFunc::Gte => Ok((Some(*expr2), None)),
710            _ => unreachable!("Already checked"),
711        }
712    }
713}
714
#[cfg(test)]
mod test {
    use datatypes::vectors::{Int32Vector, Vector};
    use pretty_assertions::assert_eq;

    use super::*;

    /// Builds `now() <func> col(0)`.
    fn now_vs_col0(func: BinaryFunc) -> ScalarExpr {
        ScalarExpr::CallBinary {
            func,
            expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
            expr2: Box::new(ScalarExpr::Column(0)),
        }
    }

    /// Builds `step_timestamp(col(0))`.
    fn stepped_col0() -> ScalarExpr {
        ScalarExpr::CallUnary {
            func: UnaryFunc::StepTimestamp,
            expr: Box::new(ScalarExpr::Column(0)),
        }
    }

    #[test]
    fn test_extract_bound() {
        let col0 = || ScalarExpr::Column(0);
        let cases: Vec<(ScalarExpr, (Option<ScalarExpr>, Option<ScalarExpr>))> = vec![
            // now == col(0)
            (
                now_vs_col0(BinaryFunc::Eq),
                (Some(col0()), Some(stepped_col0())),
            ),
            // now < col(0)
            (now_vs_col0(BinaryFunc::Lt), (None, Some(col0()))),
            // now <= col(0)
            (now_vs_col0(BinaryFunc::Lte), (None, Some(stepped_col0()))),
            // now > col(0) -> now >= col(0) + 1
            (now_vs_col0(BinaryFunc::Gt), (Some(stepped_col0()), None)),
            // now >= col(0)
            (now_vs_col0(BinaryFunc::Gte), (Some(col0()), None)),
        ];

        for (expr, expected) in cases {
            match expr.extract_bound() {
                Ok(actual) => assert_eq!(actual, expected),
                Err(err) => panic!("expected: {:?}, actual: {:?}", expected, err),
            }
        }
    }

    #[test]
    fn test_bad_permute() {
        // Column 4 has no entry in a permutation of length 3.
        let mut by_slice = ScalarExpr::Column(4);
        let res = by_slice.permute(&[1, 2, 3]);
        assert!(matches!(res, Err(Error::InvalidQuery { .. })));

        // Column 0 is not a key of the permutation map.
        let mut by_map = ScalarExpr::Column(0);
        let res = by_map.permute_map(&BTreeMap::from([(1, 2), (3, 4)]));
        assert!(matches!(res, Err(Error::InvalidQuery { .. })));
    }

    #[test]
    fn test_eval_batch_if_then() {
        // TODO(discord9): add more tests
        let int32_lit =
            |v: i32| ScalarExpr::literal(Value::from(v), ConcreteDataType::int32_datatype());

        // if col(0) == 0 then 42 else 37
        let expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(0).call_binary(int32_lit(0), BinaryFunc::Eq)),
            then: Box::new(int32_lit(42)),
            els: Box::new(int32_lit(37)),
        };

        // Evaluates `expr` on a single-column batch built from `input` and
        // compares the result against `output`.
        let check = |input: Vec<Option<i32>>, output: Vec<Option<i32>>| {
            let len = input.len();
            let vectors = vec![Int32Vector::from(input).slice(0, len)];
            let batch = Batch::try_new(vectors, len).unwrap();
            let expected = Int32Vector::from(output).slice(0, len);
            assert_eq!(expr.eval_batch(&batch).unwrap(), expected);
        };

        // Mixed nulls, matches and non-matches.
        check(
            vec![
                None,
                Some(0),
                Some(1),
                None,
                None,
                Some(0),
                Some(0),
                Some(1),
                Some(1),
            ],
            vec![
                None,
                Some(42),
                Some(37),
                None,
                None,
                Some(42),
                Some(42),
                Some(37),
                Some(37),
            ],
        );
        // Single row.
        check(vec![Some(0)], vec![Some(42)]);
        // Empty batch.
        check(vec![], vec![]);
    }
}