common_function/aggr/uddsketch_state.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Implementation of the `uddsketch_state` UDAF that generates the state of
//! a UDDSketch for a given set of values.
//!
//! The generated state can be used to compute approximate quantiles with the
//! `uddsketch_calc` UDF.
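//!
//! An illustrative query (hypothetical; `uddsketch_state`'s argument order
//! follows the signature defined below, and the `uddsketch_calc` side is an
//! assumption, as that UDF is defined elsewhere):
//!
//! ```sql
//! -- 128 buckets, 1% error rate, aggregating column `value` of table `t`
//! SELECT uddsketch_calc(0.95, uddsketch_state(128, 0.01, value)) FROM t;
//! ```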

use std::sync::Arc;

use common_query::prelude::*;
use common_telemetry::trace;
use datafusion::common::cast::{as_binary_array, as_primitive_array};
use datafusion::common::not_impl_err;
use datafusion::error::{DataFusionError, Result as DfResult};
use datafusion::logical_expr::function::AccumulatorArgs;
use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF};
use datafusion::physical_plan::expressions::Literal;
use datafusion::prelude::create_udaf;
use datatypes::arrow::array::ArrayRef;
use datatypes::arrow::datatypes::{DataType, Float64Type};
use serde::{Deserialize, Serialize};
use uddsketch::{SketchHashKey, UDDSketch};

pub const UDDSKETCH_STATE_NAME: &str = "uddsketch_state";

pub const UDDSKETCH_MERGE_NAME: &str = "uddsketch_merge";

#[derive(Debug, Serialize, Deserialize)]
pub struct UddSketchState {
    uddsketch: UDDSketch,
    error_rate: f64,
}

impl UddSketchState {
    pub fn new(bucket_size: u64, error_rate: f64) -> Self {
        Self {
            uddsketch: UDDSketch::new(bucket_size, error_rate),
            error_rate,
        }
    }

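    /// Create the `uddsketch_state` UDAF.
    ///
    /// Signature: `uddsketch_state(bucket_size Int64, error_rate Float64, value Float64) -> Binary`.
    /// The first two arguments must be literals (see [`downcast_accumulator_args`]).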
    pub fn state_udf_impl() -> AggregateUDF {
        create_udaf(
            UDDSKETCH_STATE_NAME,
            vec![DataType::Int64, DataType::Float64, DataType::Float64],
            Arc::new(DataType::Binary),
            Volatility::Immutable,
            Arc::new(|args| {
                let (bucket_size, error_rate) = downcast_accumulator_args(args)?;
                Ok(Box::new(UddSketchState::new(bucket_size, error_rate)))
            }),
            Arc::new(vec![DataType::Binary]),
        )
    }

    /// Create the `uddsketch_merge` UDAF.
    ///
    /// `uddsketch_merge` accepts a bucket size, an error rate, and a binary column of
    /// states generated by `uddsketch_state`, and merges them into a single state.
    ///
    /// The bucket size and error rate must match those of the original states.
    pub fn merge_udf_impl() -> AggregateUDF {
        create_udaf(
            UDDSKETCH_MERGE_NAME,
            vec![DataType::Int64, DataType::Float64, DataType::Binary],
            Arc::new(DataType::Binary),
            Volatility::Immutable,
            Arc::new(|args| {
                let (bucket_size, error_rate) = downcast_accumulator_args(args)?;
                Ok(Box::new(UddSketchState::new(bucket_size, error_rate)))
            }),
            Arc::new(vec![DataType::Binary]),
        )
    }

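    /// Add a single value to the underlying sketch.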
    fn update(&mut self, value: f64) {
        self.uddsketch.add_value(value);
    }

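    /// Merge a serialized `UddSketchState` into this one.
    ///
    /// Empty sketches are skipped. Merging fails if the serialized sketch was
    /// built with a different bucket size or error rate.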
    fn merge(&mut self, raw: &[u8]) -> DfResult<()> {
        if let Ok(uddsketch) = bincode::deserialize::<Self>(raw) {
            if uddsketch.uddsketch.count() != 0 {
                if self.uddsketch.max_allowed_buckets() != uddsketch.uddsketch.max_allowed_buckets()
                    || (self.error_rate - uddsketch.error_rate).abs() >= 1e-9
                {
                    return Err(DataFusionError::Plan(format!(
                        "Merging UDDSketch with different parameters: arguments={:?} vs actual input={:?}",
                        (
                            self.uddsketch.max_allowed_buckets(),
                            self.error_rate
                        ),
                        (uddsketch.uddsketch.max_allowed_buckets(), uddsketch.error_rate)
                    )));
                }
                self.uddsketch.merge_sketch(&uddsketch.uddsketch);
            }
        } else {
            trace!("Warning: Failed to deserialize UDDSketch from {:?}", raw);
            return Err(DataFusionError::Plan(
                "Failed to deserialize UDDSketch from binary".to_string(),
            ));
        }

        Ok(())
    }
}

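/// Extract the literal `bucket_size` and `error_rate` from the first two
/// arguments of the aggregate call. Non-literal arguments are rejected with a
/// "not implemented" error.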
fn downcast_accumulator_args(args: AccumulatorArgs) -> DfResult<(u64, f64)> {
    let bucket_size = match args.exprs[0]
        .as_any()
        .downcast_ref::<Literal>()
        .map(|lit| lit.value())
    {
        Some(ScalarValue::Int64(Some(value))) => *value as u64,
        _ => {
            return not_impl_err!(
                "{} not supported for bucket size: {}",
                UDDSKETCH_STATE_NAME,
                &args.exprs[0]
            )
        }
    };

    let error_rate = match args.exprs[1]
        .as_any()
        .downcast_ref::<Literal>()
        .map(|lit| lit.value())
    {
        Some(ScalarValue::Float64(Some(value))) => *value,
        _ => {
            return not_impl_err!(
                "{} not supported for error rate: {}",
                UDDSKETCH_STATE_NAME,
                &args.exprs[1]
            )
        }
    };

    Ok((bucket_size, error_rate))
}

impl DfAccumulator for UddSketchState {
    fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
        let array = &values[2]; // the third column is the data value
        match array.data_type() {
            DataType::Float64 => {
                let f64_array = as_primitive_array::<Float64Type>(array)?;
                for v in f64_array.iter().flatten() {
                    self.update(v);
                }
            }
            // A Binary column means this accumulator was instantiated as
            // `uddsketch_merge`, so treat the inputs as serialized states.
            DataType::Binary => self.merge_batch(&[array.clone()])?,
            _ => {
                return not_impl_err!(
                    "UDDSketch functions do not support data type: {}",
                    array.data_type()
                )
            }
        }

        Ok(())
    }

    fn evaluate(&mut self) -> DfResult<ScalarValue> {
        Ok(ScalarValue::Binary(Some(
            bincode::serialize(&self).map_err(|e| {
                DataFusionError::Internal(format!("Failed to serialize UDDSketch: {}", e))
            })?,
        )))
    }

    fn size(&self) -> usize {
        // Base size of UDDSketch struct fields
        let mut total_size = std::mem::size_of::<f64>() * 3 + // alpha, gamma, values_sum
                            std::mem::size_of::<u32>() +      // compactions
                            std::mem::size_of::<u64>() * 2; // max_buckets, num_values

        // Size of buckets (SketchHashMap)
        // Each bucket entry contains:
        // - SketchHashKey (enum with i64/Zero/Invalid variants)
        // - SketchHashEntry (count: u64, next: SketchHashKey)
        let bucket_entry_size = std::mem::size_of::<SketchHashKey>() + // key
                               std::mem::size_of::<u64>() +            // count
                               std::mem::size_of::<SketchHashKey>(); // next

        total_size += self.uddsketch.current_buckets_count() * bucket_entry_size;

        total_size
    }

    fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
        Ok(vec![ScalarValue::Binary(Some(
            bincode::serialize(&self).map_err(|e| {
                DataFusionError::Internal(format!("Failed to serialize UDDSketch: {}", e))
            })?,
        ))])
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
        let array = &states[0];
        let binary_array = as_binary_array(array)?;
        for v in binary_array.iter().flatten() {
            self.merge(v)?;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{BinaryArray, Float64Array};

    use super::*;

    #[test]
    fn test_uddsketch_state_basic() {
        let mut state = UddSketchState::new(10, 0.01);
        state.update(1.0);
        state.update(2.0);
        state.update(3.0);

        let result = state.evaluate().unwrap();
        if let ScalarValue::Binary(Some(bytes)) = result {
            let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap();
            assert_eq!(deserialized.uddsketch.count(), 3);
        } else {
            panic!("Expected binary scalar value");
        }
    }

    #[test]
    fn test_uddsketch_state_roundtrip() {
        let mut state = UddSketchState::new(10, 0.01);
        state.update(1.0);
        state.update(2.0);

        // Serialize
        let serialized = state.evaluate().unwrap();

        // Create a new state and merge the serialized data
        let mut new_state = UddSketchState::new(10, 0.01);
        if let ScalarValue::Binary(Some(bytes)) = &serialized {
            new_state.merge(bytes).unwrap();

            // Verify the merged state matches the original by comparing deserialized values
            let original_sketch: UddSketchState = bincode::deserialize(bytes).unwrap();
            let original_sketch = original_sketch.uddsketch;
            let new_result = new_state.evaluate().unwrap();
            if let ScalarValue::Binary(Some(new_bytes)) = new_result {
                let new_sketch: UddSketchState = bincode::deserialize(&new_bytes).unwrap();
                let new_sketch = new_sketch.uddsketch;
                assert_eq!(original_sketch.count(), new_sketch.count());
                assert_eq!(original_sketch.sum(), new_sketch.sum());
                assert_eq!(original_sketch.mean(), new_sketch.mean());
                assert_eq!(original_sketch.max_error(), new_sketch.max_error());
                // Compare a few quantiles to ensure statistical equivalence
                for q in [0.1, 0.5, 0.9].iter() {
                    assert!(
                        (original_sketch.estimate_quantile(*q) - new_sketch.estimate_quantile(*q))
                            .abs()
                            < 1e-10,
                        "Quantile {} mismatch: original={}, new={}",
                        q,
                        original_sketch.estimate_quantile(*q),
                        new_sketch.estimate_quantile(*q)
                    );
                }
            } else {
                panic!("Expected binary scalar value");
            }
        } else {
            panic!("Expected binary scalar value");
        }
    }

    #[test]
    fn test_uddsketch_state_batch_update() {
        let mut state = UddSketchState::new(10, 0.01);
        let values = vec![1.0f64, 2.0, 3.0];
        let array = Arc::new(Float64Array::from(values)) as ArrayRef;

        // `update_batch` only reads the third column; the first two slots
        // (bucket size and error rate) are ignored here, so reusing the same
        // array for all three columns is fine in this test.
        state
            .update_batch(&[array.clone(), array.clone(), array])
            .unwrap();

        let result = state.evaluate().unwrap();
        if let ScalarValue::Binary(Some(bytes)) = result {
            let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap();
            let deserialized = deserialized.uddsketch;
            assert_eq!(deserialized.count(), 3);
        } else {
            panic!("Expected binary scalar value");
        }
    }

    #[test]
    fn test_uddsketch_state_merge_batch() {
        let mut state1 = UddSketchState::new(10, 0.01);
        state1.update(1.0);
        let state1_binary = state1.evaluate().unwrap();

        let mut state2 = UddSketchState::new(10, 0.01);
        state2.update(2.0);
        let state2_binary = state2.evaluate().unwrap();

        let mut merged_state = UddSketchState::new(10, 0.01);
        if let (ScalarValue::Binary(Some(bytes1)), ScalarValue::Binary(Some(bytes2))) =
            (&state1_binary, &state2_binary)
        {
            let binary_array = Arc::new(BinaryArray::from(vec![
                bytes1.as_slice(),
                bytes2.as_slice(),
            ])) as ArrayRef;
            merged_state.merge_batch(&[binary_array]).unwrap();

            let result = merged_state.evaluate().unwrap();
            if let ScalarValue::Binary(Some(bytes)) = result {
                let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap();
                let deserialized = deserialized.uddsketch;
                assert_eq!(deserialized.count(), 2);
            } else {
                panic!("Expected binary scalar value");
            }
        } else {
            panic!("Expected binary scalar values");
        }
    }

    #[test]
    fn test_uddsketch_state_size() {
        let mut state = UddSketchState::new(10, 0.01);
        let initial_size = state.size();

        // Add some values to create buckets
        state.update(1.0);
        state.update(2.0);
        state.update(3.0);

        let size_with_values = state.size();
        assert!(
            size_with_values > initial_size,
            "Size should increase after adding values: initial={}, with_values={}",
            initial_size,
            size_with_values
        );

        // Verify size increases with more buckets
        state.update(10.0); // This should create a new bucket
        assert!(
            state.size() > size_with_values,
            "Size should increase after adding new bucket: prev={}, new={}",
            size_with_values,
            state.size()
        );
    }
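
    // An illustrative extra check (a sketch, not part of the original suite):
    // `merge` rejects states whose bucket size or error rate differs from the
    // accumulator's own parameters.
    #[test]
    fn test_uddsketch_state_merge_mismatched_params() {
        let mut state = UddSketchState::new(10, 0.01);
        state.update(1.0);
        let serialized = state.evaluate().unwrap();

        // A different bucket size (20 vs 10) must be rejected by the guard.
        let mut other = UddSketchState::new(20, 0.01);
        if let ScalarValue::Binary(Some(bytes)) = &serialized {
            assert!(other.merge(bytes).is_err());
        } else {
            panic!("Expected binary scalar value");
        }
    }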
}