common_function/scalars/
uddsketch_calc.rs1use std::fmt;
18use std::fmt::Display;
19use std::sync::Arc;
20
21use datafusion_common::DataFusionError;
22use datafusion_common::arrow::array::{Array, AsArray, Float64Builder};
23use datafusion_common::arrow::datatypes::{DataType, Float64Type};
24use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
25use uddsketch::UDDSketch;
26
27use crate::function::{Function, extract_args};
28use crate::function_registry::FunctionRegistry;
29
30const NAME: &str = "uddsketch_calc";
31
32#[derive(Debug)]
40pub(crate) struct UddSketchCalcFunction {
41 signature: Signature,
42}
43
44impl UddSketchCalcFunction {
45 pub fn register(registry: &FunctionRegistry) {
46 registry.register_scalar(UddSketchCalcFunction::default());
47 }
48}
49
50impl Default for UddSketchCalcFunction {
51 fn default() -> Self {
52 Self {
53 signature: Signature::exact(
56 vec![DataType::Float64, DataType::Binary],
57 Volatility::Immutable,
58 ),
59 }
60 }
61}
62
63impl Display for UddSketchCalcFunction {
64 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 write!(f, "{}", NAME.to_ascii_uppercase())
66 }
67}
68
69impl Function for UddSketchCalcFunction {
70 fn name(&self) -> &str {
71 NAME
72 }
73
74 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
75 Ok(DataType::Float64)
76 }
77
78 fn signature(&self) -> &Signature {
79 &self.signature
80 }
81
82 fn invoke_with_args(
83 &self,
84 args: ScalarFunctionArgs,
85 ) -> datafusion_common::Result<ColumnarValue> {
86 let [arg0, arg1] = extract_args(self.name(), &args)?;
87
88 let Some(percentages) = arg0.as_primitive_opt::<Float64Type>() else {
89 return Err(DataFusionError::Execution(format!(
90 "'{}' expects 1st argument to be Float64 datatype, got {}",
91 self.name(),
92 arg0.data_type()
93 )));
94 };
95 let Some(sketch_vec) = arg1.as_binary_opt::<i32>() else {
96 return Err(DataFusionError::Execution(format!(
97 "'{}' expects 2nd argument to be Binary datatype, got {}",
98 self.name(),
99 arg1.data_type()
100 )));
101 };
102 let len = sketch_vec.len();
103 let mut builder = Float64Builder::with_capacity(len);
104
105 for i in 0..len {
106 let perc_opt = percentages.is_valid(i).then(|| percentages.value(i));
107 let sketch_opt = sketch_vec.is_valid(i).then(|| sketch_vec.value(i));
108
109 if sketch_opt.is_none() || perc_opt.is_none() {
110 builder.append_null();
111 continue;
112 }
113
114 let sketch_bytes = sketch_opt.unwrap();
115 let perc = perc_opt.unwrap();
116
117 let sketch: UDDSketch = match bincode::deserialize(sketch_bytes) {
119 Ok(s) => s,
120 Err(e) => {
121 common_telemetry::trace!("Failed to deserialize UDDSketch: {}", e);
122 builder.append_null();
123 continue;
124 }
125 };
126
127 if sketch.bucket_iter().count() == 0 {
131 builder.append_null();
132 continue;
133 }
134 let result = sketch.estimate_quantile(perc);
136 builder.append_value(result);
137 }
138
139 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{BinaryArray, Float64Array};

    use super::*;

    /// Calls `invoke_with_args` on `function` with the given arguments and
    /// row count, supplying the boilerplate fields shared by every test.
    fn invoke(
        function: &UddSketchCalcFunction,
        args: Vec<ColumnarValue>,
        number_rows: usize,
    ) -> datafusion_common::Result<ColumnarValue> {
        function.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float64, false)),
            config_options: Arc::new(Default::default()),
        })
    }

    /// A serialized sketch evaluated through the function must match the
    /// estimates taken directly from the in-memory sketch.
    #[test]
    fn test_uddsketch_calc_function() {
        let function = UddSketchCalcFunction::default();
        assert_eq!("uddsketch_calc", function.name());
        assert_eq!(
            DataType::Float64,
            function
                .return_type(&[DataType::Float64, DataType::Binary])
                .unwrap()
        );

        let mut sketch = UDDSketch::new(128, 0.01);
        for value in [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] {
            sketch.add_value(value);
        }

        let percentiles = vec![0.5, 0.9, 0.95];
        let expected: Vec<f64> = percentiles
            .iter()
            .map(|&p| sketch.estimate_quantile(p))
            .collect();

        let serialized = bincode::serialize(&sketch).unwrap();
        let number_rows = percentiles.len();

        // One copy of the same sketch per requested percentile.
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(percentiles))),
            ColumnarValue::Array(Arc::new(BinaryArray::from_iter_values(vec![
                serialized;
                number_rows
            ]))),
        ];

        let result = invoke(&function, args, number_rows).unwrap();
        let ColumnarValue::Array(result) = result else {
            unreachable!()
        };
        let result = result.as_primitive::<Float64Type>();
        assert_eq!(result.len(), expected.len());

        for (row, &want) in expected.iter().enumerate() {
            assert!((result.value(row) - want).abs() < 1e-10);
        }
    }

    /// Wrong arity is an error; undecodable sketch bytes produce a null row.
    #[test]
    fn test_uddsketch_calc_function_errors() {
        let function = UddSketchCalcFunction::default();

        // Only one argument instead of the required two.
        let args = vec![ColumnarValue::Array(Arc::new(Float64Array::from(vec![
            0.95,
        ])))];
        let result = invoke(&function, args, 1);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Execution error: uddsketch_calc function requires 2 arguments, got 1")
        );

        // Bytes that are not a serialized sketch: the row becomes null rather
        // than failing the whole call.
        let args = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![0.95]))),
            ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![Some(vec![1, 2, 3])]))),
        ];
        let result = invoke(&function, args, 1).unwrap();
        let ColumnarValue::Array(result) = result else {
            unreachable!()
        };
        let result = result.as_primitive::<Float64Type>();
        assert_eq!(result.len(), 1);
        assert!(result.is_null(0));
    }
}