1use std::sync::Arc;
22
23use common_query::prelude::*;
24use common_telemetry::trace;
25use datafusion::common::cast::{as_binary_array, as_primitive_array};
26use datafusion::common::not_impl_err;
27use datafusion::error::{DataFusionError, Result as DfResult};
28use datafusion::logical_expr::function::AccumulatorArgs;
29use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF};
30use datafusion::physical_plan::expressions::Literal;
31use datafusion::prelude::create_udaf;
32use datatypes::arrow::array::ArrayRef;
33use datatypes::arrow::datatypes::{DataType, Float64Type};
34use serde::{Deserialize, Serialize};
35use uddsketch::{SketchHashKey, UDDSketch};
36
/// Name of the UDAF that folds raw `Float64` values into a UDDSketch state.
pub const UDDSKETCH_STATE_NAME: &str = "uddsketch_state";

/// Name of the UDAF that merges previously serialized (binary) UDDSketch states.
pub const UDDSKETCH_MERGE_NAME: &str = "uddsketch_merge";
40
/// Accumulator state for the UDDSketch UDAFs: the sketch itself plus the
/// error rate it was configured with. Serialized with `bincode` when the
/// state is exchanged between execution stages.
#[derive(Debug, Serialize, Deserialize)]
pub struct UddSketchState {
    // The underlying sketch holding the bucketed value distribution.
    uddsketch: UDDSketch,
    // Configured relative error rate; compared (within 1e-9) against incoming
    // states in `merge` to reject incompatible sketches.
    error_rate: f64,
}
46
47impl UddSketchState {
48 pub fn new(bucket_size: u64, error_rate: f64) -> Self {
49 Self {
50 uddsketch: UDDSketch::new(bucket_size, error_rate),
51 error_rate,
52 }
53 }
54
55 pub fn state_udf_impl() -> AggregateUDF {
56 create_udaf(
57 UDDSKETCH_STATE_NAME,
58 vec![DataType::Int64, DataType::Float64, DataType::Float64],
59 Arc::new(DataType::Binary),
60 Volatility::Immutable,
61 Arc::new(|args| {
62 let (bucket_size, error_rate) = downcast_accumulator_args(args)?;
63 Ok(Box::new(UddSketchState::new(bucket_size, error_rate)))
64 }),
65 Arc::new(vec![DataType::Binary]),
66 )
67 }
68
69 pub fn merge_udf_impl() -> AggregateUDF {
76 create_udaf(
77 UDDSKETCH_MERGE_NAME,
78 vec![DataType::Int64, DataType::Float64, DataType::Binary],
79 Arc::new(DataType::Binary),
80 Volatility::Immutable,
81 Arc::new(|args| {
82 let (bucket_size, error_rate) = downcast_accumulator_args(args)?;
83 Ok(Box::new(UddSketchState::new(bucket_size, error_rate)))
84 }),
85 Arc::new(vec![DataType::Binary]),
86 )
87 }
88
89 fn update(&mut self, value: f64) {
90 self.uddsketch.add_value(value);
91 }
92
93 fn merge(&mut self, raw: &[u8]) -> DfResult<()> {
94 if let Ok(uddsketch) = bincode::deserialize::<Self>(raw) {
95 if uddsketch.uddsketch.count() != 0 {
96 if self.uddsketch.max_allowed_buckets() != uddsketch.uddsketch.max_allowed_buckets()
97 || (self.error_rate - uddsketch.error_rate).abs() >= 1e-9
98 {
99 return Err(DataFusionError::Plan(format!(
100 "Merging UDDSketch with different parameters: arguments={:?} vs actual input={:?}",
101 (
102 self.uddsketch.max_allowed_buckets(),
103 self.error_rate
104 ),
105 (uddsketch.uddsketch.max_allowed_buckets(), uddsketch.error_rate)
106 )));
107 }
108 self.uddsketch.merge_sketch(&uddsketch.uddsketch);
109 }
110 } else {
111 trace!("Warning: Failed to deserialize UDDSketch from {:?}", raw);
112 return Err(DataFusionError::Plan(
113 "Failed to deserialize UDDSketch from binary".to_string(),
114 ));
115 }
116
117 Ok(())
118 }
119}
120
121fn downcast_accumulator_args(args: AccumulatorArgs) -> DfResult<(u64, f64)> {
122 let bucket_size = match args.exprs[0]
123 .as_any()
124 .downcast_ref::<Literal>()
125 .map(|lit| lit.value())
126 {
127 Some(ScalarValue::Int64(Some(value))) => *value as u64,
128 _ => {
129 return not_impl_err!(
130 "{} not supported for bucket size: {}",
131 UDDSKETCH_STATE_NAME,
132 &args.exprs[0]
133 )
134 }
135 };
136
137 let error_rate = match args.exprs[1]
138 .as_any()
139 .downcast_ref::<Literal>()
140 .map(|lit| lit.value())
141 {
142 Some(ScalarValue::Float64(Some(value))) => *value,
143 _ => {
144 return not_impl_err!(
145 "{} not supported for error rate: {}",
146 UDDSKETCH_STATE_NAME,
147 &args.exprs[1]
148 )
149 }
150 };
151
152 Ok((bucket_size, error_rate))
153}
154
155impl DfAccumulator for UddSketchState {
156 fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
157 let array = &values[2]; match array.data_type() {
159 DataType::Float64 => {
160 let f64_array = as_primitive_array::<Float64Type>(array)?;
161 for v in f64_array.iter().flatten() {
162 self.update(v);
163 }
164 }
165 DataType::Binary => self.merge_batch(&[array.clone()])?,
167 _ => {
168 return not_impl_err!(
169 "UDDSketch functions do not support data type: {}",
170 array.data_type()
171 )
172 }
173 }
174
175 Ok(())
176 }
177
178 fn evaluate(&mut self) -> DfResult<ScalarValue> {
179 Ok(ScalarValue::Binary(Some(
180 bincode::serialize(&self).map_err(|e| {
181 DataFusionError::Internal(format!("Failed to serialize UDDSketch: {}", e))
182 })?,
183 )))
184 }
185
186 fn size(&self) -> usize {
187 let mut total_size = std::mem::size_of::<f64>() * 3 + std::mem::size_of::<u32>() + std::mem::size_of::<u64>() * 2; let bucket_entry_size = std::mem::size_of::<SketchHashKey>() + std::mem::size_of::<u64>() + std::mem::size_of::<SketchHashKey>(); total_size += self.uddsketch.current_buckets_count() * bucket_entry_size;
201
202 total_size
203 }
204
205 fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
206 Ok(vec![ScalarValue::Binary(Some(
207 bincode::serialize(&self).map_err(|e| {
208 DataFusionError::Internal(format!("Failed to serialize UDDSketch: {}", e))
209 })?,
210 ))])
211 }
212
213 fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
214 let array = &states[0];
215 let binary_array = as_binary_array(array)?;
216 for v in binary_array.iter().flatten() {
217 self.merge(v)?;
218 }
219
220 Ok(())
221 }
222}
223
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{BinaryArray, Float64Array};

    use super::*;

    // Updating a sketch and evaluating should produce a binary payload that
    // deserializes back into a state with the same observation count.
    #[test]
    fn test_uddsketch_state_basic() {
        let mut state = UddSketchState::new(10, 0.01);
        state.update(1.0);
        state.update(2.0);
        state.update(3.0);

        let result = state.evaluate().unwrap();
        if let ScalarValue::Binary(Some(bytes)) = result {
            let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap();
            assert_eq!(deserialized.uddsketch.count(), 3);
        } else {
            panic!("Expected binary scalar value");
        }
    }

    // serialize -> merge -> serialize must preserve count, sum, mean, error
    // bound, and quantile estimates (within a tight tolerance).
    #[test]
    fn test_uddsketch_state_roundtrip() {
        let mut state = UddSketchState::new(10, 0.01);
        state.update(1.0);
        state.update(2.0);

        let serialized = state.evaluate().unwrap();

        // Merge the serialized state into a fresh sketch with identical
        // parameters (required by the compatibility check in `merge`).
        let mut new_state = UddSketchState::new(10, 0.01);
        if let ScalarValue::Binary(Some(bytes)) = &serialized {
            new_state.merge(bytes).unwrap();

            let original_sketch: UddSketchState = bincode::deserialize(bytes).unwrap();
            let original_sketch = original_sketch.uddsketch;
            let new_result = new_state.evaluate().unwrap();
            if let ScalarValue::Binary(Some(new_bytes)) = new_result {
                let new_sketch: UddSketchState = bincode::deserialize(&new_bytes).unwrap();
                let new_sketch = new_sketch.uddsketch;
                assert_eq!(original_sketch.count(), new_sketch.count());
                assert_eq!(original_sketch.sum(), new_sketch.sum());
                assert_eq!(original_sketch.mean(), new_sketch.mean());
                assert_eq!(original_sketch.max_error(), new_sketch.max_error());
                for q in [0.1, 0.5, 0.9].iter() {
                    assert!(
                        (original_sketch.estimate_quantile(*q) - new_sketch.estimate_quantile(*q))
                            .abs()
                            < 1e-10,
                        "Quantile {} mismatch: original={}, new={}",
                        q,
                        original_sketch.estimate_quantile(*q),
                        new_sketch.estimate_quantile(*q)
                    );
                }
            } else {
                panic!("Expected binary scalar value");
            }
        } else {
            panic!("Expected binary scalar value");
        }
    }

    // update_batch only reads the array at index 2 (the value column); the
    // first two arrays stand in for the literal parameter columns.
    #[test]
    fn test_uddsketch_state_batch_update() {
        let mut state = UddSketchState::new(10, 0.01);
        let values = vec![1.0f64, 2.0, 3.0];
        let array = Arc::new(Float64Array::from(values)) as ArrayRef;

        state
            .update_batch(&[array.clone(), array.clone(), array])
            .unwrap();

        let result = state.evaluate().unwrap();
        if let ScalarValue::Binary(Some(bytes)) = result {
            let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap();
            let deserialized = deserialized.uddsketch;
            assert_eq!(deserialized.count(), 3);
        } else {
            panic!("Expected binary scalar value");
        }
    }

    // Merging a batch of two serialized single-observation states should
    // yield a combined count of 2.
    #[test]
    fn test_uddsketch_state_merge_batch() {
        let mut state1 = UddSketchState::new(10, 0.01);
        state1.update(1.0);
        let state1_binary = state1.evaluate().unwrap();

        let mut state2 = UddSketchState::new(10, 0.01);
        state2.update(2.0);
        let state2_binary = state2.evaluate().unwrap();

        let mut merged_state = UddSketchState::new(10, 0.01);
        if let (ScalarValue::Binary(Some(bytes1)), ScalarValue::Binary(Some(bytes2))) =
            (&state1_binary, &state2_binary)
        {
            let binary_array = Arc::new(BinaryArray::from(vec![
                bytes1.as_slice(),
                bytes2.as_slice(),
            ])) as ArrayRef;
            merged_state.merge_batch(&[binary_array]).unwrap();

            let result = merged_state.evaluate().unwrap();
            if let ScalarValue::Binary(Some(bytes)) = result {
                let deserialized: UddSketchState = bincode::deserialize(&bytes).unwrap();
                let deserialized = deserialized.uddsketch;
                assert_eq!(deserialized.count(), 2);
            } else {
                panic!("Expected binary scalar value");
            }
        } else {
            panic!("Expected binary scalar values");
        }
    }

    // size() must grow as buckets fill: first when any values are added,
    // and again when a value lands in a new bucket.
    #[test]
    fn test_uddsketch_state_size() {
        let mut state = UddSketchState::new(10, 0.01);
        let initial_size = state.size();

        state.update(1.0);
        state.update(2.0);
        state.update(3.0);

        let size_with_values = state.size();
        assert!(
            size_with_values > initial_size,
            "Size should increase after adding values: initial={}, with_values={}",
            initial_size,
            size_with_values
        );

        // 10.0 is far enough from earlier values to occupy a fresh bucket.
        state.update(10.0);
        assert!(
            state.size() > size_with_values,
            "Size should increase after adding new bucket: prev={}, new={}",
            size_with_values,
            state.size()
        );
    }
}