use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;

use ahash::RandomState;
use datafusion_common::cast::as_list_array;
use datafusion_common::error::Result;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::SingleRowListArrayBuilder;
use datafusion_common::{internal_err, not_impl_err, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
    Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF,
    SetMonotonicity, Signature, TypeSignature, Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datatypes::arrow;
use datatypes::arrow::array::{
    Array, ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, UInt64Array,
};
use datatypes::arrow::buffer::{OffsetBuffer, ScalarBuffer};
use datatypes::arrow::datatypes::{DataType, Field};

use crate::function_registry::FunctionRegistry;

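/// Hash value stored per distinct input row; only these 64-bit hashes are
/// kept, never the original values.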
type HashValueType = u64;

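// Fixed seeds so every accumulator (including ones on other partitions or
// nodes) hashes identical values to identical u64s, which is required for the
// partial states to merge correctly.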
const RANDOM_SEED_0: u64 = 0x4047821dc6144e4b;
const RANDOM_SEED_1: u64 = 0x2abddf23ad417112;
const RANDOM_SEED_2: u64 = 0x6a52eeecd26eff21;
const RANDOM_SEED_3: u64 = 0x91cf673b965a7875;

impl CountHash {
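    /// Registers the `count_hash` aggregate with the function registry.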
    pub fn register(registry: &FunctionRegistry) {
        registry.register_aggr(CountHash::udf_impl());
    }

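    /// Builds the `count_hash` [`AggregateUDF`]. The signature accepts any
    /// arguments, but the accumulators currently support at most one.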
    pub fn udf_impl() -> AggregateUDF {
        AggregateUDF::new_from_impl(CountHash {
            signature: Signature::one_of(
                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
                Volatility::Immutable,
            ),
        })
    }
}

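/// `count_hash` counts the number of distinct 64-bit hash values of its
/// argument, e.g. `count_hash(col)`. Because only hashes are stored, the
/// result can undercount when distinct values collide, but the state stays
/// compact and cheap to merge.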
#[derive(Debug, Clone)]
pub struct CountHash {
    signature: Signature,
}

impl AggregateUDFImpl for CountHash {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "count_hash"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    fn is_nullable(&self) -> bool {
        false
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        Ok(vec![Field::new_list(
            format_state_name(args.name, "count_hash"),
            Field::new_list_field(DataType::UInt64, true),
            true,
        )])
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if acc_args.exprs.len() > 1 {
            return not_impl_err!("count_hash with multiple arguments");
        }

        Ok(Box::new(CountHashAccumulator {
            values: HashSet::default(),
            random_state: RandomState::with_seeds(
                RANDOM_SEED_0,
                RANDOM_SEED_1,
                RANDOM_SEED_2,
                RANDOM_SEED_3,
            ),
            batch_hashes: vec![],
        }))
    }

    fn aliases(&self) -> &[String] {
        &[]
    }

    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        true
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        if args.exprs.len() > 1 {
            return not_impl_err!("count_hash with multiple arguments");
        }

        Ok(Box::new(CountHashGroupAccumulator::new()))
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Insensitive
    }

    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(0)))
    }

    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        SetMonotonicity::Increasing
    }
}

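/// [`GroupsAccumulator`] for `count_hash`: keeps one set of distinct hashes
/// per group, so evaluation is just the length of each set.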
#[derive(Debug)]
pub struct CountHashGroupAccumulator {
    distinct_sets: Vec<HashSet<HashValueType, RandomState>>,
    random_state: RandomState,
    batch_hashes: Vec<HashValueType>,
}

impl Default for CountHashGroupAccumulator {
    fn default() -> Self {
        Self::new()
    }
}

impl CountHashGroupAccumulator {
    pub fn new() -> Self {
        Self {
            distinct_sets: vec![],
            random_state: RandomState::with_seeds(
                RANDOM_SEED_0,
                RANDOM_SEED_1,
                RANDOM_SEED_2,
                RANDOM_SEED_3,
            ),
            batch_hashes: vec![],
        }
    }

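    /// Grows `distinct_sets` so that every group index up to
    /// `total_num_groups` has a set.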
    fn ensure_sets(&mut self, total_num_groups: usize) {
        if self.distinct_sets.len() < total_num_groups {
            self.distinct_sets
                .resize_with(total_num_groups, HashSet::default);
        }
    }
}

impl GroupsAccumulator for CountHashGroupAccumulator {
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "count_hash expects a single argument");
        self.ensure_sets(total_num_groups);

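        // Hash the whole input batch in one pass, then route each row's hash
        // to its group's set below.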
        let array = &values[0];
        self.batch_hashes.clear();
        self.batch_hashes.resize(array.len(), 0);
        let hashes = create_hashes(
            &[ArrayRef::clone(array)],
            &self.random_state,
            &mut self.batch_hashes,
        )?;

        let nulls = array.logical_nulls();

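        // Specialize on whether the input carries nulls and/or a filter so the
        // common no-null, no-filter case stays a tight loop.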
        match (nulls.as_ref(), opt_filter) {
            (None, None) => {
                for (row_idx, &group_idx) in group_indices.iter().enumerate() {
                    self.distinct_sets[group_idx].insert(hashes[row_idx]);
                }
            }
            (Some(nulls), None) => {
                for (row_idx, (&group_idx, is_valid)) in
                    group_indices.iter().zip(nulls.iter()).enumerate()
                {
                    if is_valid {
                        self.distinct_sets[group_idx].insert(hashes[row_idx]);
                    }
                }
            }
            (None, Some(filter)) => {
                for (row_idx, (&group_idx, filter_value)) in
                    group_indices.iter().zip(filter.iter()).enumerate()
                {
                    if let Some(true) = filter_value {
                        self.distinct_sets[group_idx].insert(hashes[row_idx]);
                    }
                }
            }
            (Some(nulls), Some(filter)) => {
                let iter = filter
                    .iter()
                    .zip(group_indices.iter())
                    .zip(nulls.iter())
                    .enumerate();

                for (row_idx, ((filter_value, &group_idx), is_valid)) in iter {
                    if is_valid && filter_value == Some(true) {
                        self.distinct_sets[group_idx].insert(hashes[row_idx]);
                    }
                }
            }
        }

        Ok(())
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let distinct_sets: Vec<HashSet<u64, RandomState>> =
            emit_to.take_needed(&mut self.distinct_sets);

        let counts = distinct_sets
            .iter()
            .map(|set| set.len() as i64)
            .collect::<Vec<_>>();
        Ok(Arc::new(Int64Array::from(counts)))
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(
            values.len(),
            1,
            "count_hash merge expects a single state array"
        );
        self.ensure_sets(total_num_groups);

        let list_array = as_list_array(&values[0])?;

        for (i, &group_idx) in group_indices.iter().enumerate() {
            // Skip rows beyond the end of the state array as well as entries
            // the state marks as null (e.g. rows masked out by `convert_to_state`).
            if i < list_array.len() && list_array.is_valid(i) {
                let inner_array = list_array.value(i);
                let inner_array = inner_array.as_any().downcast_ref::<UInt64Array>().unwrap();
                for j in 0..inner_array.len() {
                    if !inner_array.is_null(j) {
                        self.distinct_sets[group_idx].insert(inner_array.value(j));
                    }
                }
            }
        }

        Ok(())
    }

    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let distinct_sets: Vec<HashSet<u64, RandomState>> =
            emit_to.take_needed(&mut self.distinct_sets);

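        // Flatten every set into one values array while pushing each set's end
        // offset as a side effect of `flat_map`; the offsets are complete once
        // the iterator has been fully consumed, either by `peek` finding
        // nothing or by `from_iter_values` below.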
        let mut offsets = Vec::with_capacity(distinct_sets.len() + 1);
        offsets.push(0);
        let mut curr_len = 0i32;

        let mut value_iter = distinct_sets
            .into_iter()
            .flat_map(|set| {
                curr_len += set.len() as i32;
                offsets.push(curr_len);
                set.into_iter()
            })
            .peekable();
        let data_array: ArrayRef = if value_iter.peek().is_none() {
            arrow::array::new_empty_array(&DataType::UInt64) as _
        } else {
            Arc::new(UInt64Array::from_iter_values(value_iter))
        };
        let offset_buffer = OffsetBuffer::new(ScalarBuffer::from(offsets));

        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(DataType::UInt64, true)),
            offset_buffer,
            data_array,
            None,
        );

        Ok(vec![Arc::new(list_array) as _])
    }

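    /// Converts raw input rows straight into the intermediate state layout
    /// (one single-element list holding each row's hash), used when the
    /// planner skips the partial aggregation step.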
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        assert_eq!(values.len(), 1, "count_hash expects a single argument");
        let values = ArrayRef::clone(&values[0]);

        // Hash the rows first so the emitted lists hold `UInt64` hashes, the
        // same representation produced by `state`; rows that are null or
        // removed by the filter are masked via the list's validity buffer and
        // skipped when merged.
        let mut hashes = vec![0; values.len()];
        create_hashes(&[ArrayRef::clone(&values)], &self.random_state, &mut hashes)?;
        let hash_array: ArrayRef = Arc::new(UInt64Array::from(hashes));

        let offsets = OffsetBuffer::new(ScalarBuffer::from_iter(0..values.len() as i32 + 1));
        let nulls = filtered_null_mask(opt_filter, &values);
        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(DataType::UInt64, true)),
            offsets,
            hash_array,
            nulls,
        );

        Ok(vec![Arc::new(list_array)])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    fn size(&self) -> usize {
        let mut size = size_of::<Self>();

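        // This is a lower bound: it counts the outer Vec and each set's
        // element storage, but not the hash sets' internal bucket overhead.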
        size += size_of::<Vec<HashSet<HashValueType, RandomState>>>()
            + self.distinct_sets.capacity() * size_of::<HashSet<HashValueType, RandomState>>();

        for set in &self.distinct_sets {
            size += set.capacity() * size_of::<HashValueType>();
        }

        size
    }
}

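/// Single-group [`Accumulator`] for `count_hash`: keeps one set of distinct
/// hashes and reports its length.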
#[derive(Debug)]
struct CountHashAccumulator {
    values: HashSet<HashValueType, RandomState>,
    random_state: RandomState,
    batch_hashes: Vec<HashValueType>,
}

impl CountHashAccumulator {
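    /// Approximate in-memory size: the accumulator itself plus the capacity of
    /// the hash set.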
    fn fixed_size(&self) -> usize {
        size_of_val(self) + (size_of::<HashValueType>() * self.values.capacity())
    }
}

impl Accumulator for CountHashAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let values = self.values.iter().cloned().collect::<Vec<_>>();
        let arr = Arc::new(UInt64Array::from(values)) as _;
        let list_scalar = SingleRowListArrayBuilder::new(arr).build_list_scalar();
        Ok(vec![list_scalar])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = &values[0];
        if arr.data_type() == &DataType::Null {
            return Ok(());
        }

        self.batch_hashes.clear();
        self.batch_hashes.resize(arr.len(), 0);
        let hashes = create_hashes(
            &[ArrayRef::clone(arr)],
            &self.random_state,
            &mut self.batch_hashes,
        )?;
        for hash in hashes.as_slice() {
            self.values.insert(*hash);
        }
        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(states.len(), 1, "count_hash states must be singleton!");
        let array = &states[0];
        let list_array = array.as_list::<i32>();
        for inner_array in list_array.iter() {
            let Some(inner_array) = inner_array else {
                return internal_err!(
                    "Intermediate results of count_hash should always be non null"
                );
            };
            let hash_array = inner_array.as_any().downcast_ref::<UInt64Array>().unwrap();
            for i in 0..hash_array.len() {
                self.values.insert(hash_array.value(i));
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
    }

    fn size(&self) -> usize {
        self.fixed_size()
    }
}

#[cfg(test)]
mod tests {
    use datatypes::arrow::array::{Array, BooleanArray, Int32Array, Int64Array};

    use super::*;

    fn create_test_accumulator() -> CountHashAccumulator {
        CountHashAccumulator {
            values: HashSet::default(),
            random_state: RandomState::with_seeds(
                RANDOM_SEED_0,
                RANDOM_SEED_1,
                RANDOM_SEED_2,
                RANDOM_SEED_3,
            ),
            batch_hashes: vec![],
        }
    }

    #[test]
    fn test_count_hash_accumulator() -> Result<()> {
        let mut acc = create_test_accumulator();

        let array = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(1),
            Some(2),
            None,
        ])) as ArrayRef;
        acc.update_batch(&[array])?;
        let result = acc.evaluate()?;
        assert_eq!(result, ScalarValue::Int64(Some(4)));

        let mut acc = create_test_accumulator();
        let array = Arc::new(Int32Array::from(vec![] as Vec<Option<i32>>)) as ArrayRef;
        acc.update_batch(&[array])?;
        let result = acc.evaluate()?;
        assert_eq!(result, ScalarValue::Int64(Some(0)));

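        // An all-null batch: the plain accumulator does not skip nulls, so
        // they all collapse into a single distinct hash.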
        let mut acc = create_test_accumulator();
        let array = Arc::new(Int32Array::from(vec![None, None, None])) as ArrayRef;
        acc.update_batch(&[array])?;
        let result = acc.evaluate()?;
        assert_eq!(result, ScalarValue::Int64(Some(1)));

        Ok(())
    }

    #[test]
    fn test_count_hash_accumulator_merge() -> Result<()> {
        let mut acc1 = create_test_accumulator();
        let array1 = Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])) as ArrayRef;
        acc1.update_batch(&[array1])?;
        let state1 = acc1.state()?;

        let mut acc2 = create_test_accumulator();
        let array2 = Arc::new(Int32Array::from(vec![Some(3), Some(4), Some(5)])) as ArrayRef;
        acc2.update_batch(&[array2])?;
        let state2 = acc2.state()?;

        let mut acc_merged = create_test_accumulator();
        let state_array1 = state1[0].to_array()?;
        let state_array2 = state2[0].to_array()?;

        acc_merged.merge_batch(&[state_array1])?;
        acc_merged.merge_batch(&[state_array2])?;

        let result = acc_merged.evaluate()?;
        assert_eq!(result, ScalarValue::Int64(Some(5)));

        Ok(())
    }

    fn create_test_group_accumulator() -> CountHashGroupAccumulator {
        CountHashGroupAccumulator::new()
    }

    #[test]
    fn test_count_hash_group_accumulator() -> Result<()> {
        let mut acc = create_test_group_accumulator();
        let values = Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 4, 5])) as ArrayRef;
        let group_indices = vec![0, 1, 0, 0, 1, 2, 0];
        let total_num_groups = 3;
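        // Group 0 receives 1, 1, 3, 5; group 1 receives 2, 2; group 2 receives 4.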

        acc.update_batch(&[values], &group_indices, None, total_num_groups)?;

        let result_array = acc.evaluate(EmitTo::All)?;
        let result = result_array.as_any().downcast_ref::<Int64Array>().unwrap();

        assert_eq!(result.value(0), 3);
        assert_eq!(result.value(1), 1);
        assert_eq!(result.value(2), 1);

        Ok(())
    }

    #[test]
    fn test_count_hash_group_accumulator_with_filter() -> Result<()> {
        let mut acc = create_test_group_accumulator();
        let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])) as ArrayRef;
        let group_indices = vec![0, 0, 1, 1, 2, 2];
        let filter = BooleanArray::from(vec![true, false, true, true, false, true]);
        let total_num_groups = 3;
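        // After the filter, group 0 keeps only 1, group 1 keeps 3 and 4, and
        // group 2 keeps only 6.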

        acc.update_batch(&[values], &group_indices, Some(&filter), total_num_groups)?;

        let result_array = acc.evaluate(EmitTo::All)?;
        let result = result_array.as_any().downcast_ref::<Int64Array>().unwrap();

        assert_eq!(result.value(0), 1);
        assert_eq!(result.value(1), 2);
        assert_eq!(result.value(2), 1);

        Ok(())
    }

    #[test]
    fn test_count_hash_group_accumulator_merge() -> Result<()> {
        let mut acc1 = create_test_group_accumulator();
        let values1 = Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as ArrayRef;
        let group_indices1 = vec![0, 0, 1, 1];
        acc1.update_batch(&[values1], &group_indices1, None, 2)?;
        let state1 = acc1.state(EmitTo::All)?;

        let mut acc2 = create_test_group_accumulator();
        let values2 = Arc::new(Int32Array::from(vec![5, 6, 1, 3])) as ArrayRef;
        let group_indices2 = vec![2, 2, 0, 1];
        acc2.update_batch(&[values2], &group_indices2, None, 3)?;
        let merge_group_indices = vec![0, 2];
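        // Merge acc1's group 0 ({1, 2}) into acc2's group 0 and acc1's group 1
        // ({3, 4}) into acc2's group 2.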
        acc2.merge_batch(&state1, &merge_group_indices, None, 3)?;

        let result_array = acc2.evaluate(EmitTo::All)?;
        let result = result_array.as_any().downcast_ref::<Int64Array>().unwrap();

        assert_eq!(result.value(0), 2);
        assert_eq!(result.value(1), 1);
        assert_eq!(result.value(2), 4);

        Ok(())
    }

    #[test]
    fn test_size() {
        let acc = create_test_group_accumulator();
        assert!(acc.size() > 0);
    }
}