flow/expr/
scalar.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Scalar expressions.
16
17use std::collections::{BTreeMap, BTreeSet};
18use std::sync::Arc;
19
20use arrow::array::{make_array, ArrayData, ArrayRef, BooleanArray};
21use arrow::buffer::BooleanBuffer;
22use arrow::compute::or_kleene;
23use common_error::ext::BoxedError;
24use datafusion::physical_expr_common::datum::compare_with_eq;
25use datafusion_common::DataFusionError;
26use datatypes::prelude::{ConcreteDataType, DataType};
27use datatypes::value::Value;
28use datatypes::vectors::{BooleanVector, Helper, VectorRef};
29use dfir_rs::lattices::cc_traits::Iter;
30use itertools::Itertools;
31use snafu::{ensure, OptionExt, ResultExt};
32
33use crate::error::{
34    DatafusionSnafu, Error, InvalidQuerySnafu, UnexpectedSnafu, UnsupportedTemporalFilterSnafu,
35};
36use crate::expr::error::{
37    ArrowSnafu, DataTypeSnafu, EvalError, InvalidArgumentSnafu, OptimizeSnafu, TypeMismatchSnafu,
38};
39use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc};
40use crate::expr::{Batch, DfScalarFunction};
41use crate::repr::ColumnType;
42/// A scalar expression with a known type.
43#[derive(Ord, PartialOrd, Clone, Debug, Eq, PartialEq, Hash)]
44pub struct TypedExpr {
45    /// The expression.
46    pub expr: ScalarExpr,
47    /// The type of the expression.
48    pub typ: ColumnType,
49}
50
51impl TypedExpr {
52    pub fn new(expr: ScalarExpr, typ: ColumnType) -> Self {
53        Self { expr, typ }
54    }
55}
56
57/// A scalar expression, which can be evaluated to a value.
58#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
59pub enum ScalarExpr {
60    /// A column of the input row
61    Column(usize),
62    /// A literal value.
63    /// Extra type info to know original type even when it is null
64    Literal(Value, ConcreteDataType),
65    /// A call to an unmaterializable function.
66    ///
67    /// These functions cannot be evaluated by `ScalarExpr::eval`. They must
68    /// be transformed away by a higher layer.
69    CallUnmaterializable(UnmaterializableFunc),
70    CallUnary {
71        func: UnaryFunc,
72        expr: Box<ScalarExpr>,
73    },
74    CallBinary {
75        func: BinaryFunc,
76        expr1: Box<ScalarExpr>,
77        expr2: Box<ScalarExpr>,
78    },
79    CallVariadic {
80        func: VariadicFunc,
81        exprs: Vec<ScalarExpr>,
82    },
83    CallDf {
84        /// invariant: the input args set inside this [`DfScalarFunction`] is
85        /// always col(0) to col(n-1) where n is the length of `expr`
86        df_scalar_fn: DfScalarFunction,
87        exprs: Vec<ScalarExpr>,
88    },
89    /// Conditionally evaluated expressions.
90    ///
91    /// It is important that `then` and `els` only be evaluated if
92    /// `cond` is true or not, respectively. This is the only way
93    /// users can guard execution (other logical operator do not
94    /// short-circuit) and we need to preserve that.
95    If {
96        cond: Box<ScalarExpr>,
97        then: Box<ScalarExpr>,
98        els: Box<ScalarExpr>,
99    },
100    InList {
101        expr: Box<ScalarExpr>,
102        list: Vec<ScalarExpr>,
103    },
104}
105
106impl ScalarExpr {
107    pub fn with_type(self, typ: ColumnType) -> TypedExpr {
108        TypedExpr::new(self, typ)
109    }
110
111    /// try to determine the type of the expression
112    pub fn typ(&self, context: &[ColumnType]) -> Result<ColumnType, Error> {
113        match self {
114            ScalarExpr::Column(i) => context.get(*i).cloned().ok_or_else(|| {
115                UnexpectedSnafu {
116                    reason: format!("column index {} out of range of len={}", i, context.len()),
117                }
118                .build()
119            }),
120            ScalarExpr::Literal(_, typ) => Ok(ColumnType::new_nullable(typ.clone())),
121            ScalarExpr::CallUnmaterializable(func) => {
122                Ok(ColumnType::new_nullable(func.signature().output))
123            }
124            ScalarExpr::CallUnary { func, .. } => {
125                Ok(ColumnType::new_nullable(func.signature().output))
126            }
127            ScalarExpr::CallBinary { func, .. } => {
128                Ok(ColumnType::new_nullable(func.signature().output))
129            }
130            ScalarExpr::CallVariadic { func, .. } => {
131                Ok(ColumnType::new_nullable(func.signature().output))
132            }
133            ScalarExpr::If { then, .. } => then.typ(context),
134            ScalarExpr::CallDf { df_scalar_fn, .. } => {
135                let arrow_typ = df_scalar_fn
136                    .fn_impl
137                    // TODO(discord9): get scheme from args instead?
138                    .data_type(df_scalar_fn.df_schema.as_arrow())
139                    .context({
140                        DatafusionSnafu {
141                            context: "Failed to get data type from datafusion scalar function",
142                        }
143                    })?;
144                let typ = ConcreteDataType::try_from(&arrow_typ)
145                    .map_err(BoxedError::new)
146                    .context(crate::error::ExternalSnafu)?;
147                Ok(ColumnType::new_nullable(typ))
148            }
149            ScalarExpr::InList { expr, .. } => expr.typ(context),
150        }
151    }
152}
153
154impl ScalarExpr {
155    pub fn cast(self, typ: ConcreteDataType) -> Self {
156        ScalarExpr::CallUnary {
157            func: UnaryFunc::Cast(typ),
158            expr: Box::new(self),
159        }
160    }
161
162    /// apply optimization to the expression, like flatten variadic function
163    pub fn optimize(&mut self) {
164        self.flatten_variadic_fn();
165    }
166
167    /// Because Substrait's `And`/`Or` function is binary, but FlowPlan's
168    /// `And`/`Or` function is variadic, we need to flatten the `And` function if multiple `And`/`Or` functions are nested.
169    fn flatten_variadic_fn(&mut self) {
170        if let ScalarExpr::CallVariadic { func, exprs } = self {
171            let mut new_exprs = vec![];
172            for expr in std::mem::take(exprs) {
173                if let ScalarExpr::CallVariadic {
174                    func: inner_func,
175                    exprs: mut inner_exprs,
176                } = expr
177                {
178                    if *func == inner_func {
179                        for inner_expr in inner_exprs.iter_mut() {
180                            inner_expr.flatten_variadic_fn();
181                        }
182                        new_exprs.extend(inner_exprs);
183                    }
184                } else {
185                    new_exprs.push(expr);
186                }
187            }
188            *exprs = new_exprs;
189        }
190    }
191}
192
193impl ScalarExpr {
194    /// Call a unary function on this expression.
195    pub fn call_unary(self, func: UnaryFunc) -> Self {
196        ScalarExpr::CallUnary {
197            func,
198            expr: Box::new(self),
199        }
200    }
201
202    /// Call a binary function on this expression and another.
203    pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self {
204        ScalarExpr::CallBinary {
205            func,
206            expr1: Box::new(self),
207            expr2: Box::new(other),
208        }
209    }
210
211    pub fn eval_batch(&self, batch: &Batch) -> Result<VectorRef, EvalError> {
212        match self {
213            ScalarExpr::Column(i) => Ok(batch.batch()[*i].clone()),
214            ScalarExpr::Literal(val, dt) => Ok(Helper::try_from_scalar_value(
215                val.try_to_scalar_value(dt).context(DataTypeSnafu {
216                    msg: "Failed to convert literal to scalar value",
217                })?,
218                batch.row_count(),
219            )
220            .context(DataTypeSnafu {
221                msg: "Failed to convert scalar value to vector ref when parsing literal",
222            })?),
223            ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu {
224                reason: "Can't eval unmaterializable function",
225            }
226            .fail()?,
227            ScalarExpr::CallUnary { func, expr } => func.eval_batch(batch, expr),
228            ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval_batch(batch, expr1, expr2),
229            ScalarExpr::CallVariadic { func, exprs } => func.eval_batch(batch, exprs),
230            ScalarExpr::CallDf {
231                df_scalar_fn,
232                exprs,
233            } => df_scalar_fn.eval_batch(batch, exprs),
234            ScalarExpr::If { cond, then, els } => Self::eval_if_then(batch, cond, then, els),
235            ScalarExpr::InList { expr, list } => Self::eval_in_list(batch, expr, list),
236        }
237    }
238
239    fn eval_in_list(
240        batch: &Batch,
241        expr: &ScalarExpr,
242        list: &[ScalarExpr],
243    ) -> Result<VectorRef, EvalError> {
244        let eval_list = list
245            .iter()
246            .map(|e| e.eval_batch(batch))
247            .collect::<Result<Vec<_>, _>>()?;
248        let eval_expr = expr.eval_batch(batch)?;
249
250        ensure!(
251            eval_list
252                .iter()
253                .all(|v| v.data_type() == eval_expr.data_type()),
254            TypeMismatchSnafu {
255                expected: eval_expr.data_type(),
256                actual: eval_list
257                    .iter()
258                    .find(|v| v.data_type() != eval_expr.data_type())
259                    .map(|v| v.data_type())
260                    .unwrap(),
261            }
262        );
263
264        let lhs = eval_expr.to_arrow_array();
265
266        let found = eval_list
267            .iter()
268            .map(|v| v.to_arrow_array())
269            .try_fold(
270                BooleanArray::new(BooleanBuffer::new_unset(batch.row_count()), None),
271                |result, in_list_elem| -> Result<BooleanArray, DataFusionError> {
272                    let rhs = compare_with_eq(&lhs, &in_list_elem, false)?;
273
274                    Ok(or_kleene(&result, &rhs)?)
275                },
276            )
277            .with_context(|_| crate::expr::error::DatafusionSnafu {
278                context: "Failed to compare eval_expr with eval_list",
279            })?;
280
281        let res = BooleanVector::from(found);
282
283        Ok(Arc::new(res))
284    }
285
286    /// NOTE: this if then eval impl assume all given expr are pure, and will not change the state of the world
287    /// since it will evaluate both then and else branch and filter the result
288    fn eval_if_then(
289        batch: &Batch,
290        cond: &ScalarExpr,
291        then: &ScalarExpr,
292        els: &ScalarExpr,
293    ) -> Result<VectorRef, EvalError> {
294        let conds = cond.eval_batch(batch)?;
295        let bool_conds = conds
296            .as_any()
297            .downcast_ref::<BooleanVector>()
298            .context({
299                TypeMismatchSnafu {
300                    expected: ConcreteDataType::boolean_datatype(),
301                    actual: conds.data_type(),
302                }
303            })?
304            .as_boolean_array();
305
306        let indices = bool_conds
307            .into_iter()
308            .enumerate()
309            .map(|(idx, b)| {
310                (
311                    match b {
312                        Some(true) => 0,  // then branch vector
313                        Some(false) => 1, // else branch vector
314                        None => 2,        // null vector
315                    },
316                    idx,
317                )
318            })
319            .collect_vec();
320
321        let then_input_vec = then.eval_batch(batch)?;
322        let else_input_vec = els.eval_batch(batch)?;
323
324        ensure!(
325            then_input_vec.data_type() == else_input_vec.data_type(),
326            TypeMismatchSnafu {
327                expected: then_input_vec.data_type(),
328                actual: else_input_vec.data_type(),
329            }
330        );
331
332        ensure!(
333            then_input_vec.len() == else_input_vec.len() && then_input_vec.len() == batch.row_count(),
334            InvalidArgumentSnafu {
335                reason: format!(
336                    "then and else branch must have the same length(found {} and {}) which equals input batch's row count(which is {})",
337                    then_input_vec.len(),
338                    else_input_vec.len(),
339                    batch.row_count()
340                )
341            }
342        );
343
344        fn new_nulls(dt: &arrow_schema::DataType, len: usize) -> ArrayRef {
345            let data = ArrayData::new_null(dt, len);
346            make_array(data)
347        }
348
349        let null_input_vec = new_nulls(
350            &then_input_vec.data_type().as_arrow_type(),
351            batch.row_count(),
352        );
353
354        let interleave_values = vec![
355            then_input_vec.to_arrow_array(),
356            else_input_vec.to_arrow_array(),
357            null_input_vec,
358        ];
359        let int_ref: Vec<_> = interleave_values.iter().map(|x| x.as_ref()).collect();
360
361        let interleave_res_arr =
362            arrow::compute::interleave(&int_ref, &indices).context(ArrowSnafu {
363                context: "Failed to interleave output arrays",
364            })?;
365        let res_vec = Helper::try_into_vector(interleave_res_arr).context(DataTypeSnafu {
366            msg: "Failed to convert arrow array to vector",
367        })?;
368        Ok(res_vec)
369    }
370
371    /// Eval this expression with the given values.
372    ///
373    /// TODO(discord9): add tests to make sure `eval_batch` is the same as `eval` in
374    /// most cases
375    pub fn eval(&self, values: &[Value]) -> Result<Value, EvalError> {
376        match self {
377            ScalarExpr::Column(index) => Ok(values[*index].clone()),
378            ScalarExpr::Literal(row_res, _ty) => Ok(row_res.clone()),
379            ScalarExpr::CallUnmaterializable(_) => OptimizeSnafu {
380                reason: "Can't eval unmaterializable function".to_string(),
381            }
382            .fail(),
383            ScalarExpr::CallUnary { func, expr } => func.eval(values, expr),
384            ScalarExpr::CallBinary { func, expr1, expr2 } => func.eval(values, expr1, expr2),
385            ScalarExpr::CallVariadic { func, exprs } => func.eval(values, exprs),
386            ScalarExpr::If { cond, then, els } => match cond.eval(values) {
387                Ok(Value::Boolean(true)) => then.eval(values),
388                Ok(Value::Boolean(false)) => els.eval(values),
389                _ => InvalidArgumentSnafu {
390                    reason: "if condition must be boolean".to_string(),
391                }
392                .fail(),
393            },
394            ScalarExpr::CallDf {
395                df_scalar_fn,
396                exprs,
397            } => df_scalar_fn.eval(values, exprs),
398            ScalarExpr::InList { expr, list } => {
399                let eval_expr = expr.eval(values)?;
400                let eval_list = list
401                    .iter()
402                    .map(|v| v.eval(values))
403                    .collect::<Result<Vec<_>, _>>()?;
404                let found = eval_list.iter().any(|item| *item == eval_expr);
405                Ok(Value::Boolean(found))
406            }
407        }
408    }
409
410    /// Rewrites column indices with their value in `permutation`.
411    ///
412    /// This method is applicable even when `permutation` is not a
413    /// strict permutation, and it only needs to have entries for
414    /// each column referenced in `self`.
415    pub fn permute(&mut self, permutation: &[usize]) -> Result<(), Error> {
416        // check first so that we don't end up with a partially permuted expression
417        ensure!(
418            self.get_all_ref_columns()
419                .into_iter()
420                .all(|i| i < permutation.len()),
421            InvalidQuerySnafu {
422                reason: format!(
423                    "permutation {:?} is not a valid permutation for expression {:?}",
424                    permutation, self
425                ),
426            }
427        );
428
429        self.visit_mut_post_nolimit(&mut |e| {
430            if let ScalarExpr::Column(old_i) = e {
431                *old_i = permutation[*old_i];
432            }
433            Ok(())
434        })?;
435        Ok(())
436    }
437
438    /// Rewrites column indices with their value in `permutation`.
439    ///
440    /// This method is applicable even when `permutation` is not a
441    /// strict permutation, and it only needs to have entries for
442    /// each column referenced in `self`.
443    pub fn permute_map(&mut self, permutation: &BTreeMap<usize, usize>) -> Result<(), Error> {
444        // check first so that we don't end up with a partially permuted expression
445        ensure!(
446            self.get_all_ref_columns()
447                .is_subset(&permutation.keys().cloned().collect()),
448            InvalidQuerySnafu {
449                reason: format!(
450                    "permutation {:?} is not a valid permutation for expression {:?}",
451                    permutation, self
452                ),
453            }
454        );
455
456        self.visit_mut_post_nolimit(&mut |e| {
457            if let ScalarExpr::Column(old_i) = e {
458                *old_i = permutation[old_i];
459            }
460            Ok(())
461        })
462    }
463
464    /// Returns the set of columns that are referenced by `self`.
465    pub fn get_all_ref_columns(&self) -> BTreeSet<usize> {
466        let mut support = BTreeSet::new();
467        self.visit_post_nolimit(&mut |e| {
468            if let ScalarExpr::Column(i) = e {
469                support.insert(*i);
470            }
471            Ok(())
472        })
473        .unwrap();
474        support
475    }
476
477    /// Return true if the expression is a column reference.
478    pub fn is_column(&self) -> bool {
479        matches!(self, ScalarExpr::Column(_))
480    }
481
482    /// Cast the expression to a column reference if it is one.
483    pub fn as_column(&self) -> Option<usize> {
484        if let ScalarExpr::Column(i) = self {
485            Some(*i)
486        } else {
487            None
488        }
489    }
490
491    /// Cast the expression to a literal if it is one.
492    pub fn as_literal(&self) -> Option<Value> {
493        if let ScalarExpr::Literal(lit, _column_type) = self {
494            Some(lit.clone())
495        } else {
496            None
497        }
498    }
499
500    /// Return true if the expression is a literal.
501    pub fn is_literal(&self) -> bool {
502        matches!(self, ScalarExpr::Literal(..))
503    }
504
505    /// Return true if the expression is a literal true.
506    pub fn is_literal_true(&self) -> bool {
507        Some(Value::Boolean(true)) == self.as_literal()
508    }
509
510    /// Return true if the expression is a literal false.
511    pub fn is_literal_false(&self) -> bool {
512        Some(Value::Boolean(false)) == self.as_literal()
513    }
514
515    /// Return true if the expression is a literal null.
516    pub fn is_literal_null(&self) -> bool {
517        Some(Value::Null) == self.as_literal()
518    }
519
520    /// Build a literal null
521    pub fn literal_null() -> Self {
522        ScalarExpr::Literal(Value::Null, ConcreteDataType::null_datatype())
523    }
524
525    /// Build a literal from value and type
526    pub fn literal(res: Value, typ: ConcreteDataType) -> Self {
527        ScalarExpr::Literal(res, typ)
528    }
529
530    /// Build a literal false
531    pub fn literal_false() -> Self {
532        ScalarExpr::Literal(Value::Boolean(false), ConcreteDataType::boolean_datatype())
533    }
534
535    /// Build a literal true
536    pub fn literal_true() -> Self {
537        ScalarExpr::Literal(Value::Boolean(true), ConcreteDataType::boolean_datatype())
538    }
539}
540
541impl ScalarExpr {
542    /// visit post-order without stack call limit, but may cause stack overflow
543    fn visit_post_nolimit<F>(&self, f: &mut F) -> Result<(), EvalError>
544    where
545        F: FnMut(&Self) -> Result<(), EvalError>,
546    {
547        self.visit_children(|e| e.visit_post_nolimit(f))?;
548        f(self)
549    }
550
551    fn visit_children<F>(&self, mut f: F) -> Result<(), EvalError>
552    where
553        F: FnMut(&Self) -> Result<(), EvalError>,
554    {
555        match self {
556            ScalarExpr::Column(_)
557            | ScalarExpr::Literal(_, _)
558            | ScalarExpr::CallUnmaterializable(_) => Ok(()),
559            ScalarExpr::CallUnary { expr, .. } => f(expr),
560            ScalarExpr::CallBinary { expr1, expr2, .. } => {
561                f(expr1)?;
562                f(expr2)
563            }
564            ScalarExpr::CallVariadic { exprs, .. } => {
565                for expr in exprs {
566                    f(expr)?;
567                }
568                Ok(())
569            }
570            ScalarExpr::If { cond, then, els } => {
571                f(cond)?;
572                f(then)?;
573                f(els)
574            }
575            ScalarExpr::CallDf {
576                df_scalar_fn: _,
577                exprs,
578            } => {
579                for expr in exprs {
580                    f(expr)?;
581                }
582                Ok(())
583            }
584            ScalarExpr::InList { expr, list } => {
585                f(expr)?;
586                for item in list {
587                    f(item)?;
588                }
589                Ok(())
590            }
591        }
592    }
593
594    fn visit_mut_post_nolimit<F>(&mut self, f: &mut F) -> Result<(), Error>
595    where
596        F: FnMut(&mut Self) -> Result<(), Error>,
597    {
598        self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f))?;
599        f(self)
600    }
601
602    fn visit_mut_children<F>(&mut self, mut f: F) -> Result<(), Error>
603    where
604        F: FnMut(&mut Self) -> Result<(), Error>,
605    {
606        match self {
607            ScalarExpr::Column(_)
608            | ScalarExpr::Literal(_, _)
609            | ScalarExpr::CallUnmaterializable(_) => Ok(()),
610            ScalarExpr::CallUnary { expr, .. } => f(expr),
611            ScalarExpr::CallBinary { expr1, expr2, .. } => {
612                f(expr1)?;
613                f(expr2)
614            }
615            ScalarExpr::CallVariadic { exprs, .. } => {
616                for expr in exprs {
617                    f(expr)?;
618                }
619                Ok(())
620            }
621            ScalarExpr::If { cond, then, els } => {
622                f(cond)?;
623                f(then)?;
624                f(els)
625            }
626            ScalarExpr::CallDf {
627                df_scalar_fn: _,
628                exprs,
629            } => {
630                for expr in exprs {
631                    f(expr)?;
632                }
633                Ok(())
634            }
635            ScalarExpr::InList { expr, list } => {
636                f(expr)?;
637                for item in list {
638                    f(item)?;
639                }
640                Ok(())
641            }
642        }
643    }
644}
645
646impl ScalarExpr {
647    /// if expr contains function `Now`
648    pub fn contains_temporal(&self) -> bool {
649        let mut contains = false;
650        self.visit_post_nolimit(&mut |e| {
651            if let ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now) = e {
652                contains = true;
653            }
654            Ok(())
655        })
656        .unwrap();
657        contains
658    }
659
660    /// extract lower or upper bound of `Now` for expr, where `lower bound <= expr < upper bound`
661    ///
662    /// returned bool indicates whether the bound is upper bound:
663    ///
664    /// false for lower bound, true for upper bound
665    /// TODO(discord9): allow simple transform like `now() + a < b` to `now() < b - a`
666    pub fn extract_bound(&self) -> Result<(Option<Self>, Option<Self>), Error> {
667        let unsupported_err = |msg: &str| {
668            UnsupportedTemporalFilterSnafu {
669                reason: msg.to_string(),
670            }
671            .fail()
672        };
673
674        let Self::CallBinary {
675            mut func,
676            mut expr1,
677            mut expr2,
678        } = self.clone()
679        else {
680            return unsupported_err("Not a binary expression");
681        };
682
683        // TODO(discord9): support simple transform like `now() + a < b` to `now() < b - a`
684
685        let expr1_is_now = *expr1 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
686        let expr2_is_now = *expr2 == ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now);
687
688        if !(expr1_is_now ^ expr2_is_now) {
689            return unsupported_err("None of the sides of the comparison is `now()`");
690        }
691
692        if expr2_is_now {
693            std::mem::swap(&mut expr1, &mut expr2);
694            func = BinaryFunc::reverse_compare(&func)?;
695        }
696
697        let step = |expr: ScalarExpr| expr.call_unary(UnaryFunc::StepTimestamp);
698        match func {
699            // now == expr2 -> now <= expr2 && now < expr2 + 1
700            BinaryFunc::Eq => Ok((Some(*expr2.clone()), Some(step(*expr2)))),
701            // now < expr2 -> now < expr2
702            BinaryFunc::Lt => Ok((None, Some(*expr2))),
703            // now <= expr2 -> now < expr2 + 1
704            BinaryFunc::Lte => Ok((None, Some(step(*expr2)))),
705            // now > expr2 -> now >= expr2 + 1
706            BinaryFunc::Gt => Ok((Some(step(*expr2)), None)),
707            // now >= expr2 -> now >= expr2
708            BinaryFunc::Gte => Ok((Some(*expr2), None)),
709            _ => unreachable!("Already checked"),
710        }
711    }
712}
713
#[cfg(test)]
mod test {
    use datatypes::vectors::{Int32Vector, Vector};
    use pretty_assertions::assert_eq;

    use super::*;

    /// `now()` as an expression.
    fn now_expr() -> ScalarExpr {
        ScalarExpr::CallUnmaterializable(UnmaterializableFunc::Now)
    }

    /// `now() <func> col(0)`.
    fn now_cmp_col0(func: BinaryFunc) -> ScalarExpr {
        now_expr().call_binary(ScalarExpr::Column(0), func)
    }

    /// `col(0)` advanced by one timestamp step.
    fn step_col0() -> ScalarExpr {
        ScalarExpr::Column(0).call_unary(UnaryFunc::StepTimestamp)
    }

    #[test]
    fn test_extract_bound() {
        let cases: [(ScalarExpr, Result<_, EvalError>); 5] = [
            // now == col(0) -> col(0) <= now < col(0) + 1
            (
                now_cmp_col0(BinaryFunc::Eq),
                Ok((Some(ScalarExpr::Column(0)), Some(step_col0()))),
            ),
            // now < col(0)
            (
                now_cmp_col0(BinaryFunc::Lt),
                Ok((None, Some(ScalarExpr::Column(0)))),
            ),
            // now <= col(0) -> now < col(0) + 1
            (now_cmp_col0(BinaryFunc::Lte), Ok((None, Some(step_col0())))),
            // now > col(0) -> now >= col(0) + 1
            (now_cmp_col0(BinaryFunc::Gt), Ok((Some(step_col0()), None))),
            // now >= col(0)
            (
                now_cmp_col0(BinaryFunc::Gte),
                Ok((Some(ScalarExpr::Column(0)), None)),
            ),
        ];

        for (expr, expected) in cases {
            // EvalError is not Eq, so only successful results are compared structurally
            match (expr.extract_bound(), expected) {
                (Ok(actual), Ok(expected)) => assert_eq!(actual, expected),
                (actual, expected) => panic!("expected: {:?}, actual: {:?}", expected, actual),
            }
        }
    }

    #[test]
    fn test_bad_permute() {
        // column 4 has no slot in a 3-entry permutation
        let mut out_of_range = ScalarExpr::Column(4);
        let result = out_of_range.permute(&[1, 2, 3]);
        assert!(matches!(result, Err(Error::InvalidQuery { .. })));

        // column 0 is not a key of the permutation map
        let mut missing_key = ScalarExpr::Column(0);
        let mapping = BTreeMap::from([(1, 2), (3, 4)]);
        let result = missing_key.permute_map(&mapping);
        assert!(matches!(result, Err(Error::InvalidQuery { .. })));
    }

    #[test]
    fn test_eval_batch_if_then() {
        // TODO(discord9): add more tests
        let int32_lit =
            |v: i32| ScalarExpr::literal(Value::from(v), ConcreteDataType::int32_datatype());

        // if col(0) == 0 then 42 else 37
        let expr = ScalarExpr::If {
            cond: Box::new(ScalarExpr::Column(0).call_binary(int32_lit(0), BinaryFunc::Eq)),
            then: Box::new(int32_lit(42)),
            els: Box::new(int32_lit(37)),
        };

        let check = |input: Vec<Option<i32>>, output: Vec<Option<i32>>| {
            let row_count = input.len();
            let columns = vec![Int32Vector::from(input).slice(0, row_count)];
            let batch = Batch::try_new(columns, row_count).unwrap();
            let expected = Int32Vector::from(output).slice(0, row_count);
            assert_eq!(expr.eval_batch(&batch).unwrap(), expected);
        };

        // null conditions yield null rows
        check(
            vec![
                None,
                Some(0),
                Some(1),
                None,
                None,
                Some(0),
                Some(0),
                Some(1),
                Some(1),
            ],
            vec![
                None,
                Some(42),
                Some(37),
                None,
                None,
                Some(42),
                Some(42),
                Some(37),
                Some(37),
            ],
        );
        // single row
        check(vec![Some(0)], vec![Some(42)]);
        // empty batch
        check(vec![], vec![]);
    }
}