common_function/scalars/
uddsketch_calc.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of the scalar function `uddsketch_calc`.
16
17use std::fmt;
18use std::fmt::Display;
19use std::sync::Arc;
20
21use datafusion_common::DataFusionError;
22use datafusion_common::arrow::array::{Array, AsArray, Float64Builder};
23use datafusion_common::arrow::datatypes::{DataType, Float64Type};
24use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
25use uddsketch::UDDSketch;
26
27use crate::function::{Function, extract_args};
28use crate::function_registry::FunctionRegistry;
29
/// Function name: returned by `Function::name` and rendered upper-cased by `Display`.
const NAME: &str = "uddsketch_calc";
31
32/// UddSketchCalcFunction implements the scalar function `uddsketch_calc`.
33///
34/// It accepts two arguments:
35/// 1. A percentile (as f64) for which to compute the estimated quantile (e.g. 0.95 for p95).
36/// 2. The serialized UDDSketch state, as produced by the aggregator (binary).
37///
38/// For each row, it deserializes the sketch and returns the computed quantile value.
39#[derive(Debug)]
40pub(crate) struct UddSketchCalcFunction {
41    signature: Signature,
42}
43
44impl UddSketchCalcFunction {
45    pub fn register(registry: &FunctionRegistry) {
46        registry.register_scalar(UddSketchCalcFunction::default());
47    }
48}
49
50impl Default for UddSketchCalcFunction {
51    fn default() -> Self {
52        Self {
53            // First argument: percentile (float64)
54            // Second argument: UDDSketch state (binary)
55            signature: Signature::exact(
56                vec![DataType::Float64, DataType::Binary],
57                Volatility::Immutable,
58            ),
59        }
60    }
61}
62
63impl Display for UddSketchCalcFunction {
64    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
65        write!(f, "{}", NAME.to_ascii_uppercase())
66    }
67}
68
69impl Function for UddSketchCalcFunction {
70    fn name(&self) -> &str {
71        NAME
72    }
73
74    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
75        Ok(DataType::Float64)
76    }
77
78    fn signature(&self) -> &Signature {
79        &self.signature
80    }
81
82    fn invoke_with_args(
83        &self,
84        args: ScalarFunctionArgs,
85    ) -> datafusion_common::Result<ColumnarValue> {
86        let [arg0, arg1] = extract_args(self.name(), &args)?;
87
88        let Some(percentages) = arg0.as_primitive_opt::<Float64Type>() else {
89            return Err(DataFusionError::Execution(format!(
90                "'{}' expects 1st argument to be Float64 datatype, got {}",
91                self.name(),
92                arg0.data_type()
93            )));
94        };
95        let Some(sketch_vec) = arg1.as_binary_opt::<i32>() else {
96            return Err(DataFusionError::Execution(format!(
97                "'{}' expects 2nd argument to be Binary datatype, got {}",
98                self.name(),
99                arg1.data_type()
100            )));
101        };
102        let len = sketch_vec.len();
103        let mut builder = Float64Builder::with_capacity(len);
104
105        for i in 0..len {
106            let perc_opt = percentages.is_valid(i).then(|| percentages.value(i));
107            let sketch_opt = sketch_vec.is_valid(i).then(|| sketch_vec.value(i));
108
109            if sketch_opt.is_none() || perc_opt.is_none() {
110                builder.append_null();
111                continue;
112            }
113
114            let sketch_bytes = sketch_opt.unwrap();
115            let perc = perc_opt.unwrap();
116
117            // Deserialize the UDDSketch from its bincode representation
118            let sketch: UDDSketch = match bincode::deserialize(sketch_bytes) {
119                Ok(s) => s,
120                Err(e) => {
121                    common_telemetry::trace!("Failed to deserialize UDDSketch: {}", e);
122                    builder.append_null();
123                    continue;
124                }
125            };
126
127            // Check if the sketch is empty, if so, return null
128            // This is important to avoid panics when calling estimate_quantile on an empty sketch
129            // In practice, this will happen if input is all null
130            if sketch.bucket_iter().count() == 0 {
131                builder.append_null();
132                continue;
133            }
134            // Compute the estimated quantile from the sketch
135            let result = sketch.estimate_quantile(perc);
136            builder.append_value(result);
137        }
138
139        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{BinaryArray, Float64Array};

    use super::*;

    /// Packs `args` into the `ScalarFunctionArgs` shape shared by every call in this module.
    fn call_args(args: Vec<ColumnarValue>, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float64, false)),
            config_options: Arc::new(Default::default()),
        }
    }

    /// Quantiles computed through the scalar function must match the ones
    /// obtained directly from the sketch before serialization.
    #[test]
    fn test_uddsketch_calc_function() {
        let function = UddSketchCalcFunction::default();
        assert_eq!("uddsketch_calc", function.name());
        assert_eq!(
            DataType::Float64,
            function.return_type(&[DataType::Float64]).unwrap()
        );

        // Build a sketch over ten evenly spaced samples.
        let mut sketch = UDDSketch::new(128, 0.01);
        for sample in [
            10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0,
        ] {
            sketch.add_value(sample);
        }

        // Reference answers (p50, p90, p95) straight from the in-memory sketch.
        let percentiles = vec![0.5, 0.9, 0.95];
        let expected: Vec<f64> = percentiles
            .iter()
            .map(|&p| sketch.estimate_quantile(p))
            .collect();

        let serialized = bincode::serialize(&sketch).unwrap();
        let percentile_column = Float64Array::from(percentiles.clone());
        let sketch_column = BinaryArray::from_iter_values(vec![serialized; 3]);

        let output = function
            .invoke_with_args(call_args(
                vec![
                    ColumnarValue::Array(Arc::new(percentile_column)),
                    ColumnarValue::Array(Arc::new(sketch_column)),
                ],
                3,
            ))
            .unwrap();
        let ColumnarValue::Array(output) = output else {
            unreachable!()
        };
        let output = output.as_primitive::<Float64Type>();
        assert_eq!(output.len(), 3);

        for (row, want) in expected.iter().enumerate() {
            assert!((output.value(row) - *want).abs() < 1e-10);
        }
    }

    /// A wrong argument count is an error; undecodable sketch bytes give a NULL row.
    #[test]
    fn test_uddsketch_calc_function_errors() {
        let function = UddSketchCalcFunction::default();

        // Only one argument supplied.
        let only_percentile = vec![ColumnarValue::Array(Arc::new(Float64Array::from(vec![
            0.95,
        ])))];
        let outcome = function.invoke_with_args(call_args(only_percentile, 0));
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("Execution error: uddsketch_calc function requires 2 arguments, got 1")
        );

        // Bytes that are not a valid serialized sketch.
        let garbage_state = vec![
            ColumnarValue::Array(Arc::new(Float64Array::from(vec![0.95]))),
            ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![Some(vec![1, 2, 3])]))),
        ];
        let output = function
            .invoke_with_args(call_args(garbage_state, 0))
            .unwrap();
        let ColumnarValue::Array(output) = output else {
            unreachable!()
        };
        let output = output.as_primitive::<Float64Type>();
        assert_eq!(output.len(), 1);
        assert!(output.is_null(0));
    }
}