1use std::sync::Arc;
16
17use datafusion::arrow::array::{ArrayRef, AsArray};
18use datafusion::common::cast::{as_list_array, as_primitive_array, as_struct_array};
19use datafusion::error::{DataFusionError, Result as DfResult};
20use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility};
21use datafusion::physical_plan::expressions::Literal;
22use datafusion::prelude::create_udaf;
23use datafusion_common::ScalarValue;
24use datafusion_expr::function::AccumulatorArgs;
25use datatypes::arrow::array::{ListArray, StructArray};
26use datatypes::arrow::datatypes::{DataType, Field, Float64Type};
27
28use crate::functions::quantile::quantile_impl;
29
/// Name under which the quantile aggregate function is registered.
pub const QUANTILE_NAME: &str = "quantile";

/// Name of the single field inside the accumulator's struct-typed state.
const VALUES_FIELD_NAME: &str = "values";
/// Name of the element field of the list stored in the state struct.
const DEFAULT_LIST_FIELD_NAME: &str = "item";
34
/// Accumulator backing the `quantile` aggregate function.
///
/// It buffers every input value it is given and computes the quantile only
/// when the aggregate is evaluated.
#[derive(Debug, Default)]
pub struct QuantileAccumulator {
    /// Quantile to compute, taken from the function's first argument.
    q: f64,
    /// Every value collected so far; `None` marks a null input.
    values: Vec<Option<f64>>,
}
40
41pub fn quantile_udaf() -> Arc<AggregateUDF> {
44 Arc::new(create_udaf(
45 QUANTILE_NAME,
46 vec![DataType::Float64, DataType::Float64],
48 Arc::new(DataType::Float64),
50 Volatility::Volatile,
51 Arc::new(QuantileAccumulator::from_args),
53 Arc::new(vec![DataType::Struct(
55 vec![Field::new(
56 VALUES_FIELD_NAME,
57 DataType::List(Arc::new(Field::new(
58 DEFAULT_LIST_FIELD_NAME,
59 DataType::Float64,
60 true,
61 ))),
62 false,
63 )]
64 .into(),
65 )]),
66 ))
67}
68
69impl QuantileAccumulator {
70 fn new(q: f64) -> Self {
71 Self {
72 q,
73 ..Default::default()
74 }
75 }
76
77 pub fn from_args(args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
78 if args.exprs.len() != 2 {
79 return Err(DataFusionError::Plan(
80 "Quantile function should have 2 inputs".to_string(),
81 ));
82 }
83
84 let q = match &args.exprs[0]
85 .as_any()
86 .downcast_ref::<Literal>()
87 .map(|lit| lit.value())
88 {
89 Some(ScalarValue::Float64(Some(q))) => *q,
90 _ => {
91 return Err(DataFusionError::Internal(
92 "Invalid quantile value".to_string(),
93 ))
94 }
95 };
96
97 Ok(Box::new(Self::new(q)))
98 }
99}
100
101impl DfAccumulator for QuantileAccumulator {
102 fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
103 let f64_array = values[1].as_primitive::<Float64Type>();
104
105 self.values.extend(f64_array);
106
107 Ok(())
108 }
109
110 fn evaluate(&mut self) -> DfResult<ScalarValue> {
111 let values: Vec<_> = self.values.iter().map(|v| v.unwrap_or(0.0)).collect();
112
113 let result = quantile_impl(&values, self.q);
114
115 ScalarValue::new_primitive::<Float64Type>(result, &DataType::Float64)
116 }
117
118 fn size(&self) -> usize {
119 std::mem::size_of::<Self>() + self.values.capacity() * std::mem::size_of::<Option<f64>>()
120 }
121
122 fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
123 let values_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
124 Some(self.values.clone()),
125 ]));
126
127 let state_struct = StructArray::new(
128 vec![Field::new(
129 VALUES_FIELD_NAME,
130 DataType::List(Arc::new(Field::new(
131 DEFAULT_LIST_FIELD_NAME,
132 DataType::Float64,
133 true,
134 ))),
135 false,
136 )]
137 .into(),
138 vec![values_array],
139 None,
140 );
141
142 Ok(vec![ScalarValue::Struct(Arc::new(state_struct))])
143 }
144
145 fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
146 if states.is_empty() {
147 return Ok(());
148 }
149
150 for state in states {
151 let state = as_struct_array(state)?;
152
153 for list in as_list_array(state.column(0))?.iter().flatten() {
154 let f64_array = as_primitive_array::<Float64Type>(&list)?.clone();
155 self.values.extend(&f64_array);
156 }
157 }
158
159 Ok(())
160 }
161}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::array::{ArrayRef, Float64Array};
    use datafusion_common::ScalarValue;

    use super::*;

    /// Wraps optional values into a `Float64Array` column.
    fn float_column(values: Vec<Option<f64>>) -> ArrayRef {
        Arc::new(Float64Array::from(values))
    }

    /// Builds a column that contains only non-null values.
    fn dense_column(values: &[f64]) -> ArrayRef {
        float_column(values.iter().copied().map(Some).collect())
    }

    /// Creates an accumulator for quantile `q` and feeds it `batches` in order,
    /// pairing every batch with a one-row column holding `q`.
    fn accumulate(q: f64, batches: Vec<ArrayRef>) -> QuantileAccumulator {
        let mut acc = QuantileAccumulator::new(q);
        for batch in batches {
            acc.update_batch(&[dense_column(&[q]), batch]).unwrap();
        }
        acc
    }

    #[test]
    fn test_quantile_accumulator_empty() {
        let mut acc = QuantileAccumulator::new(0.5);

        let result = acc.evaluate().unwrap();

        assert!(
            matches!(result, ScalarValue::Float64(_)),
            "Expected Float64 scalar value"
        );
    }

    #[test]
    fn test_quantile_accumulator_single_value() {
        let mut acc = accumulate(0.5, vec![dense_column(&[10.0])]);

        assert_eq!(acc.evaluate().unwrap(), ScalarValue::Float64(Some(10.0)));
    }

    #[test]
    fn test_quantile_accumulator_multiple_values() {
        let mut acc = accumulate(0.5, vec![dense_column(&[1.0, 2.0, 3.0, 4.0, 5.0])]);

        assert_eq!(acc.evaluate().unwrap(), ScalarValue::Float64(Some(3.0)));
    }

    #[test]
    fn test_quantile_accumulator_with_nulls() {
        let input = float_column(vec![Some(1.0), None, Some(3.0), Some(4.0), Some(5.0)]);
        let mut acc = accumulate(0.5, vec![input]);

        assert_eq!(acc.evaluate().unwrap(), ScalarValue::Float64(Some(3.0)));
    }

    #[test]
    fn test_quantile_accumulator_multiple_batches() {
        let first = dense_column(&[1.0, 2.0]);
        let second = dense_column(&[3.0, 4.0, 5.0]);
        let mut acc = accumulate(0.5, vec![first, second]);

        assert_eq!(acc.evaluate().unwrap(), ScalarValue::Float64(Some(3.0)));
    }

    #[test]
    fn test_quantile_accumulator_different_quantiles() {
        let input = dense_column(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        // (quantile, expected result) over the same five values.
        let cases = [(0.0, 1.0), (0.25, 2.0), (0.75, 4.0), (1.0, 5.0)];
        for (q, expected) in cases {
            let mut acc = accumulate(q, vec![input.clone()]);
            assert_eq!(
                acc.evaluate().unwrap(),
                ScalarValue::Float64(Some(expected)),
                "quantile {q}"
            );
        }
    }

    #[test]
    fn test_quantile_accumulator_size() {
        let mut acc = QuantileAccumulator::new(0.5);
        let size_before = acc.size();

        acc.update_batch(&[dense_column(&[0.5]), dense_column(&[1.0, 2.0, 3.0])])
            .unwrap();
        let size_after = acc.size();

        assert!(size_after >= size_before);
    }

    #[test]
    fn test_quantile_accumulator_state_and_merge() -> DfResult<()> {
        let mut source = QuantileAccumulator::new(0.5);
        source.update_batch(&[dense_column(&[0.5]), dense_column(&[1.0, 2.0])])?;
        let source_state = source.state()?;

        let mut target = QuantileAccumulator::new(0.5);
        target.update_batch(&[dense_column(&[0.5]), dense_column(&[3.0, 4.0, 5.0])])?;

        // Turn the serialized scalars back into arrays, as merge_batch expects.
        let state_arrays: Vec<ArrayRef> = source_state
            .iter()
            .filter_map(|scalar| match scalar {
                ScalarValue::Struct(array) => Some(Arc::clone(array) as ArrayRef),
                _ => None,
            })
            .collect();

        target.merge_batch(&state_arrays)?;

        assert_eq!(target.evaluate()?, ScalarValue::Float64(Some(3.0)));

        Ok(())
    }

    #[test]
    fn test_quantile_accumulator_with_extreme_values() {
        let mut acc = accumulate(0.5, vec![dense_column(&[f64::MAX, f64::MIN, 0.0])]);

        // Only checks that evaluation succeeds on extreme magnitudes.
        let _result = acc.evaluate().unwrap();
    }

    #[test]
    fn test_quantile_udaf_creation() {
        let udaf = quantile_udaf();

        assert_eq!(udaf.name(), QUANTILE_NAME);
        assert_eq!(udaf.return_type(&[]).unwrap(), DataType::Float64);
    }
}