// common_function/scalars/udf.rs
use std::sync::Arc;
use common_query::error::FromScalarValueSnafu;
use common_query::prelude::{
ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUdf,
};
use datatypes::error::Error as DataTypeError;
use datatypes::prelude::*;
use datatypes::vectors::Helper;
use session::context::QueryContextRef;
use snafu::ResultExt;
use crate::function::{FunctionContext, FunctionRef};
use crate::state::FunctionState;
/// Adapts a [`FunctionRef`] into a [`ScalarUdf`].
///
/// The returned UDF takes its name and signature from `func` and carries two
/// closures:
///
/// * a return-type resolver that forwards the input types to
///   `func.return_type` and wraps the answer in an `Arc`;
/// * an evaluator that turns every argument into a vector, builds a
///   [`FunctionContext`] from clones of `query_ctx` and `state`, calls
///   `func.eval`, and wraps the produced vector in `ColumnarValue::Vector`.
///
/// The row count handed to the scalar conversion is the length of the last
/// `ColumnarValue::Vector` argument, or `1` when every argument is a scalar.
///
/// # Errors
///
/// The evaluator closure fails when a scalar argument cannot be converted
/// (reported through the `FromScalarValueSnafu` context) or when `func.eval`
/// itself returns an error.
pub fn create_udf(
    func: FunctionRef,
    query_ctx: QueryContextRef,
    state: Arc<FunctionState>,
) -> ScalarUdf {
    let type_resolver = func.clone();
    let return_type: ReturnTypeFunction = Arc::new(move |input_types: &[ConcreteDataType]| {
        let data_type = type_resolver.return_type(input_types)?;
        Ok(Arc::new(data_type))
    });

    let evaluator = func.clone();
    let fun: ScalarFunctionImplementation = Arc::new(move |args: &[ColumnarValue]| {
        // Walk the arguments from the back so the length of the *last* vector
        // argument wins; scalars carry no length of their own.
        let rows = args
            .iter()
            .rev()
            .find_map(|arg| match arg {
                ColumnarValue::Vector(v) => Some(v.len()),
                ColumnarValue::Scalar(_) => None,
            })
            .unwrap_or(1);

        // Scalars go through the vector helper with `rows`; vectors are
        // passed along as-is (cloning the shared reference only).
        let columns = args
            .iter()
            .map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => {
                    Helper::try_from_scalar_value(scalar.clone(), rows)
                }
                ColumnarValue::Vector(vector) => Ok(vector.clone()),
            })
            .collect::<Result<Vec<_>, DataTypeError>>()
            .context(FromScalarValueSnafu)?;

        let func_ctx = FunctionContext {
            query_ctx: query_ctx.clone(),
            state: state.clone(),
        };
        let output = evaluator.eval(func_ctx, &columns)?;
        Ok(ColumnarValue::Vector(output))
    });

    ScalarUdf::new(func.name(), &func.signature(), &return_type, &fun)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::{ColumnarValue, ScalarValue};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::{ScalarVector, Vector, VectorRef};
    use datatypes::value::Value;
    use datatypes::vectors::{BooleanVector, ConstantVector};
    use session::context::QueryContextBuilder;

    use super::*;
    use crate::function::Function;
    use crate::scalars::test::TestAndFunction;

    /// Checks `TestAndFunction` twice: once by calling `eval` directly, and
    /// once through the UDF produced by `create_udf`, where a scalar argument
    /// is mixed with a four-row vector argument.
    #[test]
    fn test_create_udf() {
        let and_fn = Arc::new(TestAndFunction);
        let query_ctx = QueryContextBuilder::default().build().into();

        // Direct evaluation: a constant `true` of length 3 against
        // [true, false, true].
        let direct_args: Vec<VectorRef> = vec![
            Arc::new(ConstantVector::new(
                Arc::new(BooleanVector::from(vec![true])),
                3,
            )),
            Arc::new(BooleanVector::from(vec![true, false, true])),
        ];
        let direct = and_fn
            .eval(FunctionContext::default(), &direct_args)
            .unwrap();
        assert_eq!(3, direct.len());
        for (i, expected) in [true, false, true].iter().copied().enumerate() {
            assert!(matches!(direct.get(i), Value::Boolean(b) if b == expected));
        }

        // The UDF wrapper must expose the function's name, signature and
        // return type unchanged.
        let udf = create_udf(
            and_fn.clone(),
            query_ctx,
            Arc::new(FunctionState::default()),
        );
        assert_eq!("test_and", udf.name);
        assert_eq!(and_fn.signature(), udf.signature);
        let resolved_type = (udf.return_type)(&[]).unwrap();
        assert_eq!(
            Arc::new(ConcreteDataType::boolean_datatype()),
            resolved_type
        );

        // Evaluation through the UDF: scalar `true` combined with a four-row
        // vector must yield four rows.
        let udf_args = vec![
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
            ColumnarValue::Vector(Arc::new(BooleanVector::from(vec![
                true, false, false, true,
            ]))),
        ];
        let output = match (udf.fun)(&udf_args).unwrap() {
            ColumnarValue::Vector(vector) => vector,
            _ => unreachable!(),
        };
        let booleans = output.as_any().downcast_ref::<BooleanVector>().unwrap();
        assert_eq!(4, booleans.len());
        for (i, expected) in [true, false, false, true].iter().copied().enumerate() {
            assert_eq!(expected, booleans.get_data(i).unwrap(), "Failed at {i}",);
        }
    }
}