common_function/scalars/expression/
binary.rs1use std::iter;
16
17use common_query::error::Result;
18use datatypes::prelude::*;
19use datatypes::vectors::{ConstantVector, Helper};
20
21use crate::scalars::expression::ctx::EvalContext;
22
23pub fn scalar_binary_op<L: Scalar, R: Scalar, O: Scalar, F>(
24 l: &VectorRef,
25 r: &VectorRef,
26 f: F,
27 ctx: &mut EvalContext,
28) -> Result<<O as Scalar>::VectorType>
29where
30 F: Fn(Option<L::RefType<'_>>, Option<R::RefType<'_>>, &mut EvalContext) -> Option<O>,
31{
32 debug_assert!(
33 l.len() == r.len(),
34 "Size of vectors must match to apply binary expression"
35 );
36
37 let result = match (l.is_const(), r.is_const()) {
38 (false, true) => {
39 let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(l) };
40 let right: &ConstantVector = unsafe { Helper::static_cast(r) };
41 let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(right.inner()) };
42 let b = right.get_data(0);
43
44 let it = left.iter_data().map(|a| f(a, b, ctx));
45 <O as Scalar>::VectorType::from_owned_iterator(it)
46 }
47
48 (false, false) => {
49 let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(l) };
50 let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(r) };
51
52 let it = left
53 .iter_data()
54 .zip(right.iter_data())
55 .map(|(a, b)| f(a, b, ctx));
56 <O as Scalar>::VectorType::from_owned_iterator(it)
57 }
58
59 (true, false) => {
60 let left: &ConstantVector = unsafe { Helper::static_cast(l) };
61 let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(left.inner()) };
62 let a = left.get_data(0);
63
64 let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(r) };
65 let it = right.iter_data().map(|b| f(a, b, ctx));
66 <O as Scalar>::VectorType::from_owned_iterator(it)
67 }
68
69 (true, true) => {
70 let left: &ConstantVector = unsafe { Helper::static_cast(l) };
71 let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(left.inner()) };
72 let a = left.get_data(0);
73
74 let right: &ConstantVector = unsafe { Helper::static_cast(r) };
75 let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(right.inner()) };
76 let b = right.get_data(0);
77
78 let it = iter::repeat(a)
79 .zip(iter::repeat(b))
80 .map(|(a, b)| f(a, b, ctx))
81 .take(left.len());
82 <O as Scalar>::VectorType::from_owned_iterator(it)
83 }
84 };
85
86 if let Some(error) = ctx.error.take() {
87 return Err(error);
88 }
89 Ok(result)
90}