common_function/scalars/uddsketch_calc.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of the scalar function `uddsketch_calc`.
16
17use std::fmt;
18use std::fmt::Display;
19
20use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result};
21use datafusion_expr::{Signature, Volatility};
22use datatypes::arrow::datatypes::DataType;
23use datatypes::prelude::Vector;
24use datatypes::scalars::{ScalarVector, ScalarVectorBuilder};
25use datatypes::vectors::{BinaryVector, Float64VectorBuilder, MutableVector, VectorRef};
26use snafu::OptionExt;
27use uddsketch::UDDSketch;
28
29use crate::function::{Function, FunctionContext};
30use crate::function_registry::FunctionRegistry;
31
/// Name of this scalar function, as returned by `Function::name`.
const NAME: &str = "uddsketch_calc";
33
/// UddSketchCalcFunction implements the scalar function `uddsketch_calc`.
///
/// It accepts two arguments:
/// 1. A percentile (as f64) for which to compute the estimated quantile (e.g. 0.95 for p95).
/// 2. The serialized UDDSketch state, as produced by the aggregator (binary).
///
/// For each row, it deserializes the sketch (bincode) and returns the estimated quantile
/// as a `Float64`. A row yields NULL when either argument is NULL, when the state bytes
/// cannot be deserialized, or when the deserialized sketch contains no buckets.
#[derive(Debug, Default)]
pub struct UddSketchCalcFunction;
43
44impl UddSketchCalcFunction {
45    pub fn register(registry: &FunctionRegistry) {
46        registry.register_scalar(UddSketchCalcFunction);
47    }
48}
49
50impl Display for UddSketchCalcFunction {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        write!(f, "{}", NAME.to_ascii_uppercase())
53    }
54}
55
56impl Function for UddSketchCalcFunction {
57    fn name(&self) -> &str {
58        NAME
59    }
60
61    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
62        Ok(DataType::Float64)
63    }
64
65    fn signature(&self) -> Signature {
66        // First argument: percentile (float64)
67        // Second argument: UDDSketch state (binary)
68        Signature::exact(
69            vec![DataType::Float64, DataType::Binary],
70            Volatility::Immutable,
71        )
72    }
73
74    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
75        if columns.len() != 2 {
76            return InvalidFuncArgsSnafu {
77                err_msg: format!("uddsketch_calc expects 2 arguments, got {}", columns.len()),
78            }
79            .fail();
80        }
81
82        let perc_vec = &columns[0];
83        let sketch_vec = columns[1]
84            .as_any()
85            .downcast_ref::<BinaryVector>()
86            .with_context(|| DowncastVectorSnafu {
87                err_msg: format!("expect BinaryVector, got {}", columns[1].vector_type_name()),
88            })?;
89        let len = sketch_vec.len();
90        let mut builder = Float64VectorBuilder::with_capacity(len);
91
92        for i in 0..len {
93            let perc_opt = perc_vec.get(i).as_f64_lossy();
94            let sketch_opt = sketch_vec.get_data(i);
95
96            if sketch_opt.is_none() || perc_opt.is_none() {
97                builder.push_null();
98                continue;
99            }
100
101            let sketch_bytes = sketch_opt.unwrap();
102            let perc = perc_opt.unwrap();
103
104            // Deserialize the UDDSketch from its bincode representation
105            let sketch: UDDSketch = match bincode::deserialize(sketch_bytes) {
106                Ok(s) => s,
107                Err(e) => {
108                    common_telemetry::trace!("Failed to deserialize UDDSketch: {}", e);
109                    builder.push_null();
110                    continue;
111                }
112            };
113
114            // Check if the sketch is empty, if so, return null
115            // This is important to avoid panics when calling estimate_quantile on an empty sketch
116            // In practice, this will happen if input is all null
117            if sketch.bucket_iter().count() == 0 {
118                builder.push_null();
119                continue;
120            }
121            // Compute the estimated quantile from the sketch
122            let result = sketch.estimate_quantile(perc);
123            builder.push(Some(result));
124        }
125
126        Ok(builder.to_vector())
127    }
128}
129
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::vectors::{BinaryVector, Float64Vector};

    use super::*;

    #[test]
    fn test_uddsketch_calc_function() {
        let function = UddSketchCalcFunction;
        assert_eq!("uddsketch_calc", function.name());
        assert_eq!(
            DataType::Float64,
            function.return_type(&[DataType::Float64]).unwrap()
        );

        // Create a test sketch holding 10, 20, ..., 100
        let mut sketch = UDDSketch::new(128, 0.01);
        for v in 1..=10 {
            sketch.add_value(f64::from(v) * 10.0);
        }

        // Get expected values directly from the sketch
        let expected_p50 = sketch.estimate_quantile(0.5);
        let expected_p90 = sketch.estimate_quantile(0.9);
        let expected_p95 = sketch.estimate_quantile(0.95);

        let serialized = bincode::serialize(&sketch).unwrap();
        let percentiles = vec![0.5, 0.9, 0.95];

        let args: Vec<VectorRef> = vec![
            Arc::new(Float64Vector::from_vec(percentiles)),
            Arc::new(BinaryVector::from(vec![Some(serialized); 3])),
        ];

        let result = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(result.len(), 3);

        // Test median (p50)
        assert!(
            matches!(result.get(0), datatypes::value::Value::Float64(v) if (v - expected_p50).abs() < 1e-10)
        );
        // Test p90
        assert!(
            matches!(result.get(1), datatypes::value::Value::Float64(v) if (v - expected_p90).abs() < 1e-10)
        );
        // Test p95
        assert!(
            matches!(result.get(2), datatypes::value::Value::Float64(v) if (v - expected_p95).abs() < 1e-10)
        );
    }

    #[test]
    fn test_uddsketch_calc_function_errors() {
        let function = UddSketchCalcFunction;

        // Test with invalid number of arguments
        let args: Vec<VectorRef> = vec![Arc::new(Float64Vector::from_vec(vec![0.95]))];
        let result = function.eval(&FunctionContext::default(), &args);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("uddsketch_calc expects 2 arguments")
        );

        // Test with invalid binary data
        let args: Vec<VectorRef> = vec![
            Arc::new(Float64Vector::from_vec(vec![0.95])),
            Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])])), // Invalid binary data
        ];
        let result = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(result.len(), 1);
        assert!(matches!(result.get(0), datatypes::value::Value::Null));
    }

    /// NULL inputs and empty sketches produce NULL rows without affecting valid rows.
    #[test]
    fn test_uddsketch_calc_function_nulls_and_empty_sketch() {
        let function = UddSketchCalcFunction;

        let mut sketch = UDDSketch::new(128, 0.01);
        sketch.add_value(42.0);
        let expected = sketch.estimate_quantile(0.5);
        let serialized = bincode::serialize(&sketch).unwrap();
        // A sketch that never saw a value has no buckets.
        let empty = bincode::serialize(&UDDSketch::new(128, 0.01)).unwrap();

        let args: Vec<VectorRef> = vec![
            Arc::new(Float64Vector::from(vec![None, Some(0.5), Some(0.5), Some(0.5)])),
            Arc::new(BinaryVector::from(vec![
                Some(serialized.clone()),
                None,
                Some(empty),
                Some(serialized),
            ])),
        ];

        let result = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(result.len(), 4);
        // Row 0: NULL percentile; row 1: NULL state; row 2: empty sketch.
        for i in 0..3 {
            assert!(matches!(result.get(i), datatypes::value::Value::Null));
        }
        // Row 3: valid inputs still produce the estimated quantile.
        assert!(
            matches!(result.get(3), datatypes::value::Value::Float64(v) if (v - expected).abs() < 1e-10)
        );
    }
}