// common_function/scalars/expression/binary.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::iter;
16
17use common_query::error::Result;
18use datatypes::prelude::*;
19use datatypes::vectors::{ConstantVector, Helper};
20
21use crate::scalars::expression::ctx::EvalContext;
22
23pub fn scalar_binary_op<L: Scalar, R: Scalar, O: Scalar, F>(
24    l: &VectorRef,
25    r: &VectorRef,
26    f: F,
27    ctx: &mut EvalContext,
28) -> Result<<O as Scalar>::VectorType>
29where
30    F: Fn(Option<L::RefType<'_>>, Option<R::RefType<'_>>, &mut EvalContext) -> Option<O>,
31{
32    debug_assert!(
33        l.len() == r.len(),
34        "Size of vectors must match to apply binary expression"
35    );
36
37    let result = match (l.is_const(), r.is_const()) {
38        (false, true) => {
39            let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(l) };
40            let right: &ConstantVector = unsafe { Helper::static_cast(r) };
41            let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(right.inner()) };
42            let b = right.get_data(0);
43
44            let it = left.iter_data().map(|a| f(a, b, ctx));
45            <O as Scalar>::VectorType::from_owned_iterator(it)
46        }
47
48        (false, false) => {
49            let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(l) };
50            let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(r) };
51
52            let it = left
53                .iter_data()
54                .zip(right.iter_data())
55                .map(|(a, b)| f(a, b, ctx));
56            <O as Scalar>::VectorType::from_owned_iterator(it)
57        }
58
59        (true, false) => {
60            let left: &ConstantVector = unsafe { Helper::static_cast(l) };
61            let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(left.inner()) };
62            let a = left.get_data(0);
63
64            let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(r) };
65            let it = right.iter_data().map(|b| f(a, b, ctx));
66            <O as Scalar>::VectorType::from_owned_iterator(it)
67        }
68
69        (true, true) => {
70            let left: &ConstantVector = unsafe { Helper::static_cast(l) };
71            let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(left.inner()) };
72            let a = left.get_data(0);
73
74            let right: &ConstantVector = unsafe { Helper::static_cast(r) };
75            let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(right.inner()) };
76            let b = right.get_data(0);
77
78            let it = iter::repeat(a)
79                .zip(iter::repeat(b))
80                .map(|(a, b)| f(a, b, ctx))
81                .take(left.len());
82            <O as Scalar>::VectorType::from_owned_iterator(it)
83        }
84    };
85
86    if let Some(error) = ctx.error.take() {
87        return Err(error);
88    }
89    Ok(result)
90}