common_function/scalars/vector/
scalar_add.rs1use std::fmt::Display;
16
17use datafusion::arrow::datatypes::DataType;
18use datafusion::logical_expr::ColumnarValue;
19use datafusion_common::ScalarValue;
20use datafusion_expr::{ScalarFunctionArgs, Signature};
21use nalgebra::DVectorView;
22
23use crate::function::Function;
24use crate::helper;
25use crate::scalars::vector::VectorCalculator;
26use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
27
28const NAME: &str = "vec_scalar_add";
29
30#[derive(Debug, Clone)]
53pub(crate) struct ScalarAddFunction {
54 signature: Signature,
55}
56
57impl Default for ScalarAddFunction {
58 fn default() -> Self {
59 Self {
60 signature: helper::one_of_sigs2(
61 vec![DataType::Float64],
62 vec![
63 DataType::Utf8,
64 DataType::Utf8View,
65 DataType::Binary,
66 DataType::BinaryView,
67 ],
68 ),
69 }
70 }
71}
72
73impl Function for ScalarAddFunction {
74 fn name(&self) -> &str {
75 NAME
76 }
77
78 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
79 Ok(DataType::BinaryView)
80 }
81
82 fn signature(&self) -> &Signature {
83 &self.signature
84 }
85
86 fn invoke_with_args(
87 &self,
88 args: ScalarFunctionArgs,
89 ) -> datafusion_common::Result<ColumnarValue> {
90 let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
91 let ScalarValue::Float64(Some(v0)) = v0 else {
92 return Ok(ScalarValue::BinaryView(None));
93 };
94
95 let v1 = as_veclit(v1)?
96 .map(|v1| DVectorView::from_slice(&v1, v1.len()).add_scalar(*v0 as f32));
97 let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
98 Ok(ScalarValue::BinaryView(result))
99 };
100
101 let calculator = VectorCalculator {
102 name: self.name(),
103 func: body,
104 };
105 calculator.invoke_with_args(args)
106 }
107}
108
109impl Display for ScalarAddFunction {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 write!(f, "{}", NAME.to_ascii_uppercase())
112 }
113}
114
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{
        Array, ArrayRef, AsArray, BinaryViewArray, Float64Array, StringViewArray,
    };
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Invokes `vec_scalar_add` over a scalar column and a vector column of
    /// `rows` rows each, and materializes the result as an array.
    fn invoke(scalars: ArrayRef, vectors: ArrayRef, rows: usize) -> ArrayRef {
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(scalars), ColumnarValue::Array(vectors)],
            arg_fields: vec![],
            number_rows: rows,
            // The function emits null rows for null inputs, so the declared
            // return field must be nullable to match its actual output.
            return_field: Arc::new(Field::new("x", DataType::BinaryView, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };

        ScalarAddFunction::default()
            .invoke_with_args(args)
            .and_then(|x| x.to_array(rows))
            .unwrap()
    }

    #[test]
    fn test_scalar_add() {
        let input0 = Arc::new(Float64Array::from(vec![
            Some(1.0),
            Some(-1.0),
            None,
            Some(3.0),
        ]));
        let input1 = Arc::new(StringViewArray::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));

        let result = invoke(input0, input1, 4);
        let result = result.as_binary_view();
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.value(0),
            veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
        );
        // A negative scalar acts as subtraction.
        assert_eq!(
            result.value(1),
            veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice()
        );
        // Null scalar -> null row.
        assert!(result.is_null(2));
        // Null vector -> null row.
        assert!(result.is_null(3));
    }

    #[test]
    fn test_scalar_add_binary_vectors() {
        let scalars = Arc::new(Float64Array::from(vec![Some(0.5), Some(-2.0)]));
        let encoded = veclit_to_binlit(&[1.0, 2.0]);
        let vectors = Arc::new(BinaryViewArray::from(vec![
            Some(encoded.as_slice()),
            None,
        ]));

        let result = invoke(scalars, vectors, 2);
        let result = result.as_binary_view();
        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), veclit_to_binlit(&[1.5, 2.5]).as_slice());
        assert!(result.is_null(1));
    }
}