flow/expr/
scalar.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Scalar expressions.
16
17use std::collections::{BTreeMap, BTreeSet};
18
19use arrow::array::{make_array, ArrayData, ArrayRef};
20use common_error::ext::BoxedError;
21use datatypes::prelude::{ConcreteDataType, DataType};
22use datatypes::value::Value;
23use datatypes::vectors::{BooleanVector, Helper, VectorRef};
24use dfir_rs::lattices::cc_traits::Iter;
25use itertools::Itertools;
26use snafu::{ensure, OptionExt, ResultExt};
27
28use crate::error::{
29    DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
30};
31use crate::expr::error::{
32    ArrowSnafu, DataTypeSnafu, EvalError, InvalidArgumentSnafu, OptimizeSnafu, TypeMismatchSnafu,
33};
34use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
35use crate::expr::{Batch, DfScalarFunction};
36use crate::repr::ColumnType;
37/// A scalar expression with a known type.
38#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)]
39pub struct TypedExpr {
40    /// The expression.
41    pub expr: ScalarExpr,
42    /// The type of the expression.
43    pub typ: ColumnType,
44}
45
46impl TypedExpr {
47    pub fn new(expr: ScalarExpr, typ: ColumnType) -> Self {
48        Self { expr, typ }
49    }
50}
51
52/// A scalar expression, which can be evaluated to a value.
53#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
54pub enum ScalarExpr {
55    /// A column of the input row
56    Column(usize),
57    /// A literal value.
58    /// Extra type info to know original type even when it is null
59    Literal(Value, ConcreteDataType),
60    /// A call to an unmaterializable function.
61    ///
62    /// These functions cannot be evaluated by `ScalarExpr::eval`. They must
63    /// be transformed away by a higher layer.
64    CallUnmaterializable(UnmaterializableFunc),
65    CallUnary {
66        func: UnaryFunc,
67        expr: Box<ScalarExpr>,
68    },
69    CallBinary {
70        func: BinaryFunc,
71        expr1: Box<ScalarExpr>,
72        expr2: Box<ScalarExpr>,
73    },
74    CallVariadic {
75        func: VariadicFunc,
76        exprs: Vec<ScalarExpr>,
77    },
78    CallDf {
79        /// invariant: the input args set inside this [`DfScalarFunction`] is
80        /// always col(0) to col(n-1) where n is the length of `expr`
81        df_scalar_fn: DfScalarFunction,
82        exprs: Vec<ScalarExpr>,
83    },
84    /// Conditionally evaluated expressions.
85    ///
86    /// It is important that `then` and `els` only be evaluated if
87    /// `cond` is true or not, respectively. This is the only way
88    /// users can guard execution (other logical operator do not
89    /// short-circuit) and we need to preserve that.
90    If {
91        cond: Box<ScalarExpr>,
92        then: Box<ScalarExpr>,
93        els: Box<ScalarExpr>,
94    },
95}
96
97impl ScalarExpr {
98    pub fn with_type(self, typ: ColumnType) -> TypedExpr {
99        TypedExpr::new(self, typ)
100    }
101
102    /// try to determine the type of the expression
103    pub fn typ(&self, context: &[ColumnType]) -> Result<ColumnType, Error> {
104        match self {
105            ScalarExpr::Column(i) => context.get(*i).cloned().ok_or_else(|| {
106                UnexpectedSnafu {
107                    reason: format!("column index {} out of range of len={}", i, context.len()),
108                }
109                .build()
110            }),
111            ScalarExpr::Literal(_, typ) => Ok(ColumnType::new_nullable(typ.clone())),
112            ScalarExpr::CallUnmaterializable(func) => {
113                Ok(ColumnType::new_nullable(func.signature().output))
114            }
115            ScalarExpr::CallUnary { func, .. } => {
116                Ok(ColumnType::new_nullable(func.signature().output))
117            }
118            ScalarExpr::CallBinary { func, .. } => {
119                Ok(ColumnType::new_nullable(func.signature().output))
120            }
121            ScalarExpr::CallVariadic { func, .. } => {
122                Ok(ColumnType::new_nullable(func.signature().output))
123            }
124            ScalarExpr::If { then, .. } => then.typ(context),
125            ScalarExpr::CallDf { df_scalar_fn, .. } => {
126                let arrow_typ = df_scalar_fn
127                    .fn_impl
128                    // TODO(discord9): get scheme from args instead?
129                    .data_type(df_scalar_fn.df_schema.as_arrow())
130                    .context({
131                        DatafusionSnafu {
132                            context: "Failed to get data type from datafusion scalar function",
133                        }
134                    })?;
135                let typ = ConcreteDataType::try_from(&arrow_typ)
136                    .map_err(BoxedError::new)
137                    .context(crate::error::ExternalSnafu)?;
138                Ok(ColumnType::new_nullable(typ))
139            }
140        }
141    }
142}
143
144impl ScalarExpr {
145    pub fn cast(self, typ: ConcreteDataType) -> Self {
146        ScalarExpr::CallUnary {
147            func: UnaryFunc::Cast(typ),
148            expr: Box::new(self),
149        }
150    }
151
152    /// apply optimization to the expression, like flatten variadic function
153    pub fn optimize(&mut self) {
154        self.flatten_variadic_fn();
155    }
156
157    /// Because Substrait's `And`/`Or` function is binary, but FlowPlan's
158    /// `And`/`Or` function is variadic, we need to flatten the `And` function if multiple `And`/`Or` functions are nested.
159    fn flatten_variadic_fn(&mut self) {
160        if let ScalarExpr::CallVariadic { func, exprs } = self {
161            let mut new_exprs = vec![];
162            for expr in std::mem::take(exprs) {
163                if let ScalarExpr::CallVariadic {
164                    func: inner_func,
165                    exprs: mut inner_exprs,
166                } = expr
167                {
168                    if *func == inner_func {
169                        for inner_expr in inner_exprs.iter_mut() {
170                            inner_expr.flatten_variadic_fn();
171                        }
172                        new_exprs.extend(inner_exprs);
173                    }
174                } else {
175                    new_exprs.push(expr);
176                }
177            }
178            *exprs = new_exprs;
179        }
180    }
181}
182
183impl ScalarExpr {
184    /// Call a unary function on this expression.
185    pub fn call_unary(self, func: UnaryFunc) -> Self {
186        ScalarExpr::CallUnary {
187            func,
188            expr: Box::new(self),
189        }
190    }
191
192    /// Call a binary function on this expression and another.
193    pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
194        ScalarExpr::CallBinary {
195            func,
196            expr1: Box::new(self),
197            expr2: Box::new(other),
198        }
199    }
200
201    pub fn eval_batch(&self, batch: &Batch) -> Result<VectorRef, EvalError> {
202        match self {
203            ScalarExpr::Column(i) => Ok(batch.batch()[*i].clone()),
204            ScalarExpr::Literal(val, dt) => Ok(Helper::try_from_scalar_value(
205                val.try_to_scalar_value(dt).context(DataTypeSnafu {
206                    msg: "Failed to convert literal to scalar value",
207                })?,
208                batch.row_count(),
209            )
210            .context(DataTypeSnafu {
211                msg: "Failed to convert scalar value to vector ref when parsing literal",
212            })?),
213            ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu {
214                reason: "Can't eval unmaterializable function",
215            }
216            .fail()?,
217            ScalarExpr::CallUnary { func, expr } => func.eval_batch(batch, expr),
218            ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval_batch(batch, expr1, expr2),
219            ScalarExpr::CallVariadic { func, exprs } => func.eval_batch(batch, exprs),
220            ScalarExpr::CallDf {
221                df_scalar_fn,
222                exprs,
223            } => df_scalar_fn.eval_batch(batch, exprs),
224            ScalarExpr::If { cond, then, els } => Self::eval_if_then(batch, cond, then, els),
225        }
226    }
227
228    /// NOTE: this if then eval impl assume all given expr are pure, and will not change the state of the world
229    /// since it will evaluate both then and else branch and filter the result
230    fn eval_if_then(
231        batch: &Batch,
232        cond: &ScalarExpr,
233        then: &ScalarExpr,
234        els: &ScalarExpr,
235    ) -> Result<VectorRef, EvalError> {
236        let conds = cond.eval_batch(batch)?;
237        let bool_conds = conds
238            .as_any()
239            .downcast_ref::<BooleanVector>()
240            .context({
241                TypeMismatchSnafu {
242                    expected: ConcreteDataType::boolean_datatype(),
243                    actual: conds.data_type(),
244                }
245            })?
246            .as_boolean_array();
247
248        let indices = bool_conds
249            .into_iter()
250            .enumerate()
251            .map(|(idx, b)| {
252                (
253                    match b {
254                        Some(true) => 0,  // then branch vector
255                        Some(false) => 1, // else branch vector
256                        None => 2,        // null vector
257                    },
258                    idx,
259                )
260            })
261            .collect_vec();
262
263        let then_input_vec = then.eval_batch(batch)?;
264        let else_input_vec = els.eval_batch(batch)?;
265
266        ensure!(
267            then_input_vec.data_type() == else_input_vec.data_type(),
268            TypeMismatchSnafu {
269                expected: then_input_vec.data_type(),
270                actual: else_input_vec.data_type(),
271            }
272        );
273
274        ensure!(
275            then_input_vec.len() == else_input_vec.len() && then_input_vec.len() == batch.row_count(),
276            InvalidArgumentSnafu {
277                reason: format!(
278                    "then and else branch must have the same length(found {} and {}) which equals input batch's row count(which is {})",
279                    then_input_vec.len(),
280                    else_input_vec.len(),
281                    batch.row_count()
282                )
283            }
284        );
285
286        fn new_nulls(dt: &arrow_schema::DataType, len: usize) -> ArrayRef {
287            let data = ArrayData::new_null(dt, len);
288            make_array(data)
289        }
290
291        let null_input_vec = new_nulls(
292            &then_input_vec.data_type().as_arrow_type(),
293            batch.row_count(),
294        );
295
296        let interleave_values = vec![
297            then_input_vec.to_arrow_array(),
298            else_input_vec.to_arrow_array(),
299            null_input_vec,
300        ];
301        let int_ref: Vec<_> = interleave_values.iter().map(|x| x.as_ref()).collect();
302
303        let interleave_res_arr =
304            arrow::compute::interleave(&int_ref, &indices).context(ArrowSnafu {
305                context: "Failed to interleave output arrays",
306            })?;
307        let res_vec = Helper::try_into_vector(interleave_res_arr).context(DataTypeSnafu {
308            msg: "Failed to convert arrow array to vector",
309        })?;
310        Ok(res_vec)
311    }
312
313    /// Eval this expression with the given values.
314    ///
315    /// TODO(discord9): add tests to make sure `eval_batch` is the same as `eval` in
316    /// most cases
317    pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
318        match self {
319            ScalarExpr::Column(index) => Ok(values[*index].clone()),
320            ScalarExpr::Literal(row_res, _ty) => Ok(row_res.clone()),
321            ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu {
322                reason: "Can't eval unmaterializable function".to_string(),
323            }
324            .fail(),
325            ScalarExpr::CallUnary { func, expr } => func.eval(values, expr),
326            ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval(values, expr1, expr2),
327            ScalarExpr::CallVariadic { func, exprs } => func.eval(values, exprs),
328            ScalarExpr::If { cond, then, els } => match cond.eval(values) {
329                Ok(Value::Boolean(true)) => then.eval(values),
330                Ok(Value::Boolean(false)) => els.eval(values),
331                _ => InvalidArgumentSnafu {
332                    reason: "if condition must be boolean".to_string(),
333                }
334                .fail(),
335            },
336            ScalarExpr::CallDf {
337                df_scalar_fn,
338                exprs,
339            } => df_scalar_fn.eval(values, exprs),
340        }
341    }
342
343    /// Rewrites column indices with their value in `permutation`.
344    ///
345    /// This method is applicable even when `permutation` is not a
346    /// strict permutation, and it only needs to have entries for
347    /// each column referenced in `self`.
348    pub fn permute(&mut self, permutation: &[usize]) -> Result<(), Error> {
349        // check first so that we don't end up with a partially permuted expression
350        ensure!(
351            self.get_all_ref_columns()
352                .into_iter()
353                .all(|i| i < permutation.len()),
354            InvalidQuerySnafu {
355                reason: format!(
356                    "permutation {:?} is not a valid permutation for expression {:?}",
357                    permutation, self
358                ),
359            }
360        );
361
362        self.visit_mut_post_nolimit(&mut |e| {
363            if let ScalarExpr::Column(old_i) = e {
364                *old_i = permutation[*old_i];
365            }
366            Ok(())
367        })?;
368        Ok(())
369    }
370
371    /// Rewrites column indices with their value in `permutation`.
372    ///
373    /// This method is applicable even when `permutation` is not a
374    /// strict permutation, and it only needs to have entries for
375    /// each column referenced in `self`.
376    pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) -> Result<(), Error> {
377        // check first so that we don't end up with a partially permuted expression
378        ensure!(
379            self.get_all_ref_columns()
380                .is_subset(&permutation.keys().cloned().collect()),
381            InvalidQuerySnafu {
382                reason: format!(
383                    "permutation {:?} is not a valid permutation for expression {:?}",
384                    permutation, self
385                ),
386            }
387        );
388
389        self.visit_mut_post_nolimit(&mut |e| {
390            if let ScalarExpr::Column(old_i) = e {
391                *old_i = permutation[old_i];
392            }
393            Ok(())
394        })
395    }
396
397    /// Returns the set of columns that are referenced by `self`.
398    pub fn get_all_ref_columns(&self) -> BTreeSet<usize> {
399        let mut support = BTreeSet::new();
400        self.visit_post_nolimit(&mut |e| {
401            if let ScalarExpr::Column(i) = e {
402                support.insert(*i);
403            }
404            Ok(())
405        })
406        .unwrap();
407        support
408    }
409
410    /// Return true if the expression is a column reference.
411    pub fn is_column(&self) -> bool {
412        matches!(self, ScalarExpr::Column(_))
413    }
414
415    /// Cast the expression to a column reference if it is one.
416    pub fn as_column(&self) -> Option<usize> {
417        if let ScalarExpr::Column(i) = self {
418            Some(*i)
419        } else {
420            None
421        }
422    }
423
424    /// Cast the expression to a literal if it is one.
425    pub fn as_literal(&self) -> Option<Value> {
426        if let ScalarExpr::Literal(lit, _column_type) = self {
427            Some(lit.clone())
428        } else {
429            None
430        }
431    }
432
433    /// Return true if the expression is a literal.
434    pub fn is_literal(&self) -> bool {
435        matches!(self, ScalarExpr::Literal(..))
436    }
437
438    /// Return true if the expression is a literal true.
439    pub fn is_literal_true(&self) -> bool {
440        Some(Value::Boolean(true)) == self.as_literal()
441    }
442
443    /// Return true if the expression is a literal false.
444    pub fn is_literal_false(&self) -> bool {
445        Some(Value::Boolean(false)) == self.as_literal()
446    }
447
448    /// Return true if the expression is a literal null.
449    pub fn is_literal_null(&self) -> bool {
450        Some(Value::Null) == self.as_literal()
451    }
452
453    /// Build a literal null
454    pub fn literal_null() -> Self {
455        ScalarExpr::Literal(Value::Null, ConcreteDataType::null_datatype())
456    }
457
458    /// Build a literal from value and type
459    pub fn literal(res: Value, typ: ConcreteDataType) -> Self {
460        ScalarExpr::Literal(res, typ)
461    }
462
463    /// Build a literal false
464    pub fn literal_false() -> Self {
465        ScalarExpr::Literal(Value::Boolean(false), ConcreteDataType::boolean_datatype())
466    }
467
468    /// Build a literal true
469    pub fn literal_true() -> Self {
470        ScalarExpr::Literal(Value::Boolean(true), ConcreteDataType::boolean_datatype())
471    }
472}
473
474impl ScalarExpr {
475    /// visit post-order without stack call limit, but may cause stack overflow
476    fn visit_post_nolimit<F>(&self, f: &mut F) -> Result<(), EvalError>
477    where
478        F: FnMut(&Self) -> Result<(), EvalError>,
479    {
480        self.visit_children(|e| e.visit_post_nolimit(f))?;
481        f(self)
482    }
483
484    fn visit_children<F>(&self, mut f: F) -> Result<(), EvalError>
485    where
486        F: FnMut(&Self) -> Result<(), EvalError>,
487    {
488        match self {
489            ScalarExpr::Column(_)
490            | ScalarExpr::Literal(_, _)
491            | ScalarExpr::CallUnmaterializable(_) => Ok(()),
492            ScalarExpr::CallUnary { expr, .. } => f(expr),
493            ScalarExpr::CallBinary { expr1, expr2, .. } => {
494                f(expr1)?;
495                f(expr2)
496            }
497            ScalarExpr::CallVariadic { exprs, .. } => {
498                for expr in exprs {
499                    f(expr)?;
500                }
501                Ok(())
502            }
503            ScalarExpr::If { cond, then, els } => {
504                f(cond)?;
505                f(then)?;
506                f(els)
507            }
508            ScalarExpr::CallDf {
509                df_scalar_fn: _,
510                exprs,
511            } => {
512                for expr in exprs {
513                    f(expr)?;
514                }
515                Ok(())
516            }
517        }
518    }
519
520    fn visit_mut_post_nolimit<F>(&mut self, f: &mut F) -> Result<(), Error>
521    where
522        F: FnMut(&mut Self) -> Result<(), Error>,
523    {
524        self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f))?;
525        f(self)
526    }
527
528    fn visit_mut_children<F>(&mut self, mut f: F) -> Result<(), Error>
529    where
530        F: FnMut(&mut Self) -> Result<(), Error>,
531    {
532        match self {
533            ScalarExpr::Column(_)
534            | ScalarExpr::Literal(_, _)
535            | ScalarExpr::CallUnmaterializable(_) => Ok(()),
536            ScalarExpr::CallUnary { expr, .. } => f(expr),
537            ScalarExpr::CallBinary { expr1, expr2, .. } => {
538                f(expr1)?;
539                f(expr2)
540            }
541            ScalarExpr::CallVariadic { exprs, .. } => {
542                for expr in exprs {
543                    f(expr)?;
544                }
545                Ok(())
546            }
547            ScalarExpr::If { cond, then, els } => {
548                f(cond)?;
549                f(then)?;
550                f(els)
551            }
552            ScalarExpr::CallDf {
553                df_scalar_fn: _,
554                exprs,
555            } => {
556                for expr in exprs {
557                    f(expr)?;
558                }
559                Ok(())
560            }
561        }
562    }
563}
564
565impl ScalarExpr {
566    /// if expr contains function `Now`
567    pub fn contains_temporal(&self) -> bool {
568        let mut contains = false;
569        self.visit_post_nolimit(&mut |e| {
570            if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now) = e {
571                contains = true;
572            }
573            Ok(())
574        })
575        .unwrap();
576        contains
577    }
578
579    /// extract lower or upper bound of `Now` for expr, where `lower bound <= expr < upper bound`
580    ///
581    /// returned bool indicates whether the bound is upper bound:
582    ///
583    /// false for lower bound, true for upper bound
584    /// TODO(discord9): allow simple transform like `now() + a < b` to `now() < b - a`
585    pub fn extract_bound(&self) -> Result<(Option<Self>, Option<Self>), Error> {
586        let unsupported_err = |msg: &str| {
587            UnsupportedTemporalFilterSnafu {
588                reason: msg.to_string(),
589            }
590            .fail()
591        };
592
593        let Self::CallBinary {
594            mut func,
595            mut expr1,
596            mut expr2,
597        } = self.clone()
598        else {
599            return unsupported_err("Not a binary expression");
600        };
601
602        // TODO(discord9): support simple transform like `now() + a < b` to `now() < b - a`
603
604        let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
605        let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
606
607        if !(expr1_is_now ^ expr2_is_now) {
608            return unsupported_err("None of the sides of the comparison is `now()`");
609        }
610
611        if expr2_is_now {
612            std::mem::swap(&mut expr1, &mut expr2);
613            func = BinaryFunc::reverse_compare(&func)?;
614        }
615
616        let step = |expr: ScalarExpr| expr.call_unary(UnaryFunc::StepTimestamp);
617        match func {
618            // now == expr2 -> now <= expr2 && now < expr2 + 1
619            BinaryFunc::Eq => Ok((Some(*expr2.clone()), Some(step(*expr2)))),
620            // now < expr2 -> now < expr2
621            BinaryFunc::Lt => Ok((None, Some(*expr2))),
622            // now <= expr2 -> now < expr2 + 1
623            BinaryFunc::Lte => Ok((None, Some(step(*expr2)))),
624            // now > expr2 -> now >= expr2 + 1
625            BinaryFunc::Gt => Ok((Some(step(*expr2)), None)),
626            // now >= expr2 -> now >= expr2
627            BinaryFunc::Gte => Ok((Some(*expr2), None)),
628            _ => unreachable!("Already checked"),
629        }
630    }
631}
632
#[cfg(test)]
mod test {
    use datatypes::vectors::{Int32Vector, Vector};
    use pretty_assertions::assert_eq;

    use super::*;

    /// Builds `now() <func> col(0)`.
    fn now_vs_col0(func: BinaryFunc) -> ScalarExpr {
        ScalarExpr::CallBinary {
            func,
            expr1: Box::new(ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)),
            expr2: Box::new(ScalarExpr::Column(0)),
        }
    }

    /// Builds `col(0)` wrapped in `StepTimestamp`.
    fn stepped_col0() -> ScalarExpr {
        ScalarExpr::CallUnary {
            func: UnaryFunc::StepTimestamp,
            expr: Box::new(ScalarExpr::Column(0)),
        }
    }

    #[test]
    fn test_extract_bound() {
        let col0 = || ScalarExpr::Column(0);
        // (comparison used in `now() <cmp> col(0)`, expected (lower, upper) bounds)
        let cases = [
            // now == col(0)
            (BinaryFunc::Eq, (Some(col0()), Some(stepped_col0()))),
            // now < col(0)
            (BinaryFunc::Lt, (None, Some(col0()))),
            // now <= col(0)
            (BinaryFunc::Lte, (None, Some(stepped_col0()))),
            // now > col(0) -> now >= col(0) + 1
            (BinaryFunc::Gt, (Some(stepped_col0()), None)),
            // now >= col(0)
            (BinaryFunc::Gte, (Some(col0()), None)),
        ];

        for (func, expected) in cases.into_iter() {
            match now_vs_col0(func).extract_bound() {
                Ok(actual) => assert_eq!(actual, expected),
                failed => panic!(
                    "expected: {:?}, actual: {:?}",
                    Ok::<_, EvalError>(expected),
                    failed
                ),
            }
        }
    }

    #[test]
    fn test_bad_permute() {
        // column 4 has no slot in a three-entry permutation
        let mut out_of_range = ScalarExpr::Column(4);
        let slots = vec![1, 2, 3];
        assert!(matches!(
            out_of_range.permute(&slots),
            Err(Error::InvalidQuery { .. })
        ));

        // column 0 is not a key of the mapping
        let mut unmapped = ScalarExpr::Column(0);
        let mapping = BTreeMap::from([(1, 2), (3, 4)]);
        assert!(matches!(
            unmapped.permute_map(&mapping),
            Err(Error::InvalidQuery { .. })
        ));
    }

    #[test]
    fn test_eval_batch_if_then() {
        // TODO(discord9): add more tests
        let int32_lit =
            |v: i32| ScalarExpr::literal(Value::from(v), ConcreteDataType::int32_datatype());

        // if col(0) == 0 then 42 else 37
        let expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(0).call_binary(int32_lit(0), BinaryFunc::Eq)),
            then: Box::new(int32_lit(42)),
            els: Box::new(int32_lit(37)),
        };

        let to_vector = |vals: Vec<Option<i32>>| {
            let len = vals.len();
            Int32Vector::from(vals).slice(0, len)
        };
        // Evaluates `expr` over a single-column batch and compares the output.
        let check = |input: Vec<Option<i32>>, expected: Vec<Option<i32>>| {
            let rows = input.len();
            let batch = Batch::try_new(vec![to_vector(input)], rows).unwrap();
            assert_eq!(expr.eval_batch(&batch).unwrap(), to_vector(expected));
        };

        // mixed nulls, matches and non-matches
        check(
            vec![
                None,
                Some(0),
                Some(1),
                None,
                None,
                Some(0),
                Some(0),
                Some(1),
                Some(1),
            ],
            vec![
                None,
                Some(42),
                Some(37),
                None,
                None,
                Some(42),
                Some(42),
                Some(37),
                Some(37),
            ],
        );

        // single row
        check(vec![Some(0)], vec![Some(42)]);

        // empty batch
        check(vec![], vec![]);
    }
}