common_function/aggrs/count_hash.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! `CountHash` / `count_hash` is a hash-based approximate distinct count function.
//!
//! It is a variant of `CountDistinct` that uses a hash function to approximate the
//! distinct count.
//! It is designed to be more efficient than `CountDistinct` for large datasets,
//! but it is not exact, since different values may collide to the same hash.
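//!
//! Exposed in SQL as an aggregate function; for example (illustrative, with a
//! hypothetical table and column), `SELECT count_hash(user_id) FROM access_log`
//! returns the approximate number of distinct `user_id` values.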

use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;

use ahash::RandomState;
use datafusion_common::cast::as_list_array;
use datafusion_common::error::Result;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::SingleRowListArrayBuilder;
use datafusion_common::{internal_err, not_impl_err, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
use datafusion_expr::{
    Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF,
    SetMonotonicity, Signature, TypeSignature, Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datatypes::arrow;
use datatypes::arrow::array::{
    Array, ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, UInt64Array,
};
use datatypes::arrow::buffer::{OffsetBuffer, ScalarBuffer};
use datatypes::arrow::datatypes::{DataType, Field};

use crate::function_registry::FunctionRegistry;

type HashValueType = u64;

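// The seeds below are fixed constants (rather than per-instance random values)
// so that every accumulator, across partitions and aggregation stages, hashes a
// given value identically and the intermediate hash sets can be merged.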
// read from /dev/urandom 4047821dc6144e4b2abddf23ad4171126a52eeecd26eff2191cf673b965a7875
const RANDOM_SEED_0: u64 = 0x4047821dc6144e4b;
const RANDOM_SEED_1: u64 = 0x2abddf23ad417112;
const RANDOM_SEED_2: u64 = 0x6a52eeecd26eff21;
const RANDOM_SEED_3: u64 = 0x91cf673b965a7875;

impl CountHash {
    /// Registers the `count_hash` aggregate function in the given registry.
    pub fn register(registry: &FunctionRegistry) {
        registry.register_aggr(CountHash::udf_impl());
    }

    /// Builds the `AggregateUDF` implementation for `count_hash`.
    pub fn udf_impl() -> AggregateUDF {
        AggregateUDF::new_from_impl(CountHash {
            signature: Signature::one_of(
                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
                Volatility::Immutable,
            ),
        })
    }
}

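/// Hash-based approximate `COUNT(DISTINCT)` aggregate: counts distinct 64-bit
/// hashes of the input values rather than the values themselves.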
#[derive(Debug, Clone)]
pub struct CountHash {
    signature: Signature,
}

impl AggregateUDFImpl for CountHash {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "count_hash"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    fn is_nullable(&self) -> bool {
        false
    }

    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
        Ok(vec![Field::new_list(
            format_state_name(args.name, "count_hash"),
            Field::new_list_field(DataType::UInt64, true),
            // For the count_hash accumulator, a null list item stands for an
            // empty value set (i.e., only NULL values so far for that group).
            true,
        )])
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        if acc_args.exprs.len() > 1 {
            return not_impl_err!("count_hash with multiple arguments");
        }

        Ok(Box::new(CountHashAccumulator {
            values: HashSet::default(),
            random_state: RandomState::with_seeds(
                RANDOM_SEED_0,
                RANDOM_SEED_1,
                RANDOM_SEED_2,
                RANDOM_SEED_3,
            ),
            batch_hashes: vec![],
        }))
    }

    fn aliases(&self) -> &[String] {
        &[]
    }

    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
        true
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        if args.exprs.len() > 1 {
            return not_impl_err!("count_hash with multiple arguments");
        }

        Ok(Box::new(CountHashGroupAccumulator::new()))
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::Identical
    }

    fn order_sensitivity(&self) -> AggregateOrderSensitivity {
        AggregateOrderSensitivity::Insensitive
    }

    fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(0)))
    }

    fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {
        SetMonotonicity::Increasing
    }
}

/// GroupsAccumulator for the `count_hash` aggregate function
#[derive(Debug)]
pub struct CountHashGroupAccumulator {
    /// One HashSet per group to track its distinct hash values
    distinct_sets: Vec<HashSet<HashValueType, RandomState>>,
    /// Fixed-seed hash state shared by all accumulators so hashes are comparable
    random_state: RandomState,
    /// Scratch buffer reused across batches to avoid reallocation
    batch_hashes: Vec<HashValueType>,
}

impl Default for CountHashGroupAccumulator {
    fn default() -> Self {
        Self::new()
    }
}

impl CountHashGroupAccumulator {
    pub fn new() -> Self {
        Self {
            distinct_sets: vec![],
            random_state: RandomState::with_seeds(
                RANDOM_SEED_0,
                RANDOM_SEED_1,
                RANDOM_SEED_2,
                RANDOM_SEED_3,
            ),
            batch_hashes: vec![],
        }
    }

    /// Grows the per-group set vector so that indices up to `total_num_groups` are valid.
    fn ensure_sets(&mut self, total_num_groups: usize) {
        if self.distinct_sets.len() < total_num_groups {
            self.distinct_sets
                .resize_with(total_num_groups, HashSet::default);
        }
    }
}

impl GroupsAccumulator for CountHashGroupAccumulator {
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(values.len(), 1, "count_hash expects a single argument");
        self.ensure_sets(total_num_groups);

        let array = &values[0];
        self.batch_hashes.clear();
        self.batch_hashes.resize(array.len(), 0);
        let hashes = create_hashes(
            &[ArrayRef::clone(array)],
            &self.random_state,
            &mut self.batch_hashes,
        )?;

        // Use a pattern similar to accumulate_indices to process rows
        // that are not null and pass the filter
        let nulls = array.logical_nulls();

        match (nulls.as_ref(), opt_filter) {
            (None, None) => {
                // No nulls, no filter - process all rows
                for (row_idx, &group_idx) in group_indices.iter().enumerate() {
                    self.distinct_sets[group_idx].insert(hashes[row_idx]);
                }
            }
            (Some(nulls), None) => {
                // Has nulls, no filter
                for (row_idx, (&group_idx, is_valid)) in
                    group_indices.iter().zip(nulls.iter()).enumerate()
                {
                    if is_valid {
                        self.distinct_sets[group_idx].insert(hashes[row_idx]);
                    }
                }
            }
            (None, Some(filter)) => {
                // No nulls, has filter
                for (row_idx, (&group_idx, filter_value)) in
                    group_indices.iter().zip(filter.iter()).enumerate()
                {
                    if let Some(true) = filter_value {
                        self.distinct_sets[group_idx].insert(hashes[row_idx]);
                    }
                }
            }
            (Some(nulls), Some(filter)) => {
                // Has nulls and filter
                let iter = filter
                    .iter()
                    .zip(group_indices.iter())
                    .zip(nulls.iter())
                    .enumerate();

                for (row_idx, ((filter_value, &group_idx), is_valid)) in iter {
                    if is_valid && filter_value == Some(true) {
                        self.distinct_sets[group_idx].insert(hashes[row_idx]);
                    }
                }
            }
        }

        Ok(())
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        let distinct_sets: Vec<HashSet<u64, RandomState>> =
            emit_to.take_needed(&mut self.distinct_sets);

        let counts = distinct_sets
            .iter()
            .map(|set| set.len() as i64)
            .collect::<Vec<_>>();
        Ok(Arc::new(Int64Array::from(counts)))
    }

    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        _opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        assert_eq!(
            values.len(),
            1,
            "count_hash merge expects a single state array"
        );
        self.ensure_sets(total_num_groups);

        let list_array = as_list_array(&values[0])?;

        // For each group in the incoming batch
        for (i, &group_idx) in group_indices.iter().enumerate() {
            // A null list item stands for an empty value set and contributes nothing.
            if i < list_array.len() && !list_array.is_null(i) {
                let inner_array = list_array.value(i);
                let inner_array = inner_array.as_any().downcast_ref::<UInt64Array>().unwrap();
                // Add each non-null hash to our set for this group
                for j in 0..inner_array.len() {
                    if !inner_array.is_null(j) {
                        self.distinct_sets[group_idx].insert(inner_array.value(j));
                    }
                }
            }
        }

        Ok(())
    }

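    /// Serializes the per-group hash sets into a single `ListArray` with one
    /// list entry per group. For example (illustrative; set iteration order is
    /// arbitrary), hash sets {10, 20} for group 0 and {30} for group 1 are
    /// flattened to values [10, 20, 30] with offsets [0, 2, 3].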
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let distinct_sets: Vec<HashSet<u64, RandomState>> =
            emit_to.take_needed(&mut self.distinct_sets);

        let mut offsets = Vec::with_capacity(distinct_sets.len() + 1);
        offsets.push(0);
        let mut curr_len = 0i32;

        let mut value_iter = distinct_sets
            .into_iter()
            .flat_map(|set| {
                // build offset
                curr_len += set.len() as i32;
                offsets.push(curr_len);
                // convert into iter
                set.into_iter()
            })
            .peekable();
        let data_array: ArrayRef = if value_iter.peek().is_none() {
            arrow::array::new_empty_array(&DataType::UInt64) as _
        } else {
            Arc::new(UInt64Array::from_iter_values(value_iter))
        };
        let offset_buffer = OffsetBuffer::new(ScalarBuffer::from(offsets));

        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(DataType::UInt64, true)),
            offset_buffer,
            data_array,
            None,
        );

        Ok(vec![Arc::new(list_array) as _])
    }

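    /// Converts raw input rows directly into the intermediate state layout:
    /// each row becomes a single-element list holding that row's hash, while
    /// null or filtered-out rows are marked null so they contribute nothing
    /// when the states are merged.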
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> Result<Vec<ArrayRef>> {
        // Hash each input row and emit a single-element list holding that hash
        assert_eq!(values.len(), 1, "count_hash expects a single argument");
        let values = ArrayRef::clone(&values[0]);

        let mut batch_hashes = vec![0_u64; values.len()];
        create_hashes(&[ArrayRef::clone(&values)], &self.random_state, &mut batch_hashes)?;
        // Null or filtered-out rows become null items so that they are skipped
        // when the intermediate states are merged.
        let nulls = filtered_null_mask(opt_filter, &values);
        let hashes: ArrayRef =
            Arc::new(UInt64Array::new(ScalarBuffer::from(batch_hashes), nulls.clone()));

        let offsets = OffsetBuffer::new(ScalarBuffer::from_iter(0..values.len() as i32 + 1));
        let list_array = ListArray::new(
            Arc::new(Field::new_list_field(DataType::UInt64, true)),
            offsets,
            hashes,
            nulls,
        );

        Ok(vec![Arc::new(list_array)])
    }

    fn supports_convert_to_state(&self) -> bool {
        true
    }

    fn size(&self) -> usize {
        // Base size of the struct
        let mut size = size_of::<Self>();

        // Size of the vector holding the HashSets
        size += size_of::<Vec<HashSet<HashValueType, RandomState>>>()
            + self.distinct_sets.capacity() * size_of::<HashSet<HashValueType, RandomState>>();

        // Estimate the contents of each HashSet from its capacity instead of
        // iterating over all stored values, which would be expensive
        for set in &self.distinct_sets {
            // Element storage of the HashSet
            size += set.capacity() * size_of::<HashValueType>();
        }

        size
    }
}

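/// Accumulator for the non-grouped (single group) path: keeps one set of
/// distinct hash values for all input rows seen so far.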
#[derive(Debug)]
struct CountHashAccumulator {
    values: HashSet<HashValueType, RandomState>,
    random_state: RandomState,
    batch_hashes: Vec<HashValueType>,
}

impl CountHashAccumulator {
    // Estimated size: the accumulator itself plus the capacity of its hash set.
    fn fixed_size(&self) -> usize {
        size_of_val(self) + (size_of::<HashValueType>() * self.values.capacity())
    }
}

impl Accumulator for CountHashAccumulator {
    /// Returns the distinct hash values seen so far as a one-element `ListArray`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let values = self.values.iter().cloned().collect::<Vec<_>>();
        let arr = Arc::new(UInt64Array::from(values)) as _;
        let list_scalar = SingleRowListArrayBuilder::new(arr).build_list_scalar();
        Ok(vec![list_scalar])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = &values[0];
        if arr.data_type() == &DataType::Null {
            return Ok(());
        }

        self.batch_hashes.clear();
        self.batch_hashes.resize(arr.len(), 0);
        let hashes = create_hashes(
            &[ArrayRef::clone(arr)],
            &self.random_state,
            &mut self.batch_hashes,
        )?;
        for hash in hashes.as_slice() {
            self.values.insert(*hash);
        }
        Ok(())
    }

    /// Merges multiple sets of distinct hash values into the current set.
    ///
    /// The input to this function is a `ListArray` with **multiple** rows,
    /// where each row contains the hash values from a partial aggregation phase
    /// (e.g. the result of calling `Self::state` on multiple accumulators).
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(states.len(), 1, "count_hash states must be singleton!");
        let array = &states[0];
        let list_array = array.as_list::<i32>();
        for inner_array in list_array.iter() {
            let Some(inner_array) = inner_array else {
                return internal_err!(
                    "Intermediate results of count_hash should always be non-null"
                );
            };
            let hash_array = inner_array.as_any().downcast_ref::<UInt64Array>().unwrap();
            for i in 0..hash_array.len() {
                self.values.insert(hash_array.value(i));
            }
        }
        Ok(())
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
    }

    fn size(&self) -> usize {
        self.fixed_size()
    }
}

#[cfg(test)]
mod tests {
    use datatypes::arrow::array::{Array, BooleanArray, Int32Array, Int64Array};

    use super::*;

    fn create_test_accumulator() -> CountHashAccumulator {
        CountHashAccumulator {
            values: HashSet::default(),
            random_state: RandomState::with_seeds(
                RANDOM_SEED_0,
                RANDOM_SEED_1,
                RANDOM_SEED_2,
                RANDOM_SEED_3,
            ),
            batch_hashes: vec![],
        }
    }

    #[test]
    fn test_count_hash_accumulator() -> Result<()> {
        let mut acc = create_test_accumulator();

        // Test with some data: {1, 2, 3} plus the NULL row, which also
        // contributes one hash, giving 4 distinct hashes
        let array = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(1),
            Some(2),
            None,
        ])) as ArrayRef;
        acc.update_batch(&[array])?;
        let result = acc.evaluate()?;
        assert_eq!(result, ScalarValue::Int64(Some(4)));

        // Test with empty data
        let mut acc = create_test_accumulator();
        let array = Arc::new(Int32Array::from(vec![] as Vec<Option<i32>>)) as ArrayRef;
        acc.update_batch(&[array])?;
        let result = acc.evaluate()?;
        assert_eq!(result, ScalarValue::Int64(Some(0)));

        // Test with only nulls: all NULL rows map to the same hash, so the count is 1
        let mut acc = create_test_accumulator();
        let array = Arc::new(Int32Array::from(vec![None, None, None])) as ArrayRef;
        acc.update_batch(&[array])?;
        let result = acc.evaluate()?;
        assert_eq!(result, ScalarValue::Int64(Some(1)));

        Ok(())
    }

    #[test]
    fn test_count_hash_accumulator_merge() -> Result<()> {
        // Accumulator 1
        let mut acc1 = create_test_accumulator();
        let array1 = Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])) as ArrayRef;
        acc1.update_batch(&[array1])?;
        let state1 = acc1.state()?;

        // Accumulator 2
        let mut acc2 = create_test_accumulator();
        let array2 = Arc::new(Int32Array::from(vec![Some(3), Some(4), Some(5)])) as ArrayRef;
        acc2.update_batch(&[array2])?;
        let state2 = acc2.state()?;

        // Merge state1 and state2 into a new accumulator
        let mut acc_merged = create_test_accumulator();
        let state_array1 = state1[0].to_array()?;
        let state_array2 = state2[0].to_array()?;

        acc_merged.merge_batch(&[state_array1])?;
        acc_merged.merge_batch(&[state_array2])?;

        let result = acc_merged.evaluate()?;
        // Distinct values are {1, 2, 3, 4, 5}, so count is 5
        assert_eq!(result, ScalarValue::Int64(Some(5)));

        Ok(())
    }

    fn create_test_group_accumulator() -> CountHashGroupAccumulator {
        CountHashGroupAccumulator::new()
    }

    #[test]
    fn test_count_hash_group_accumulator() -> Result<()> {
        let mut acc = create_test_group_accumulator();
        let values = Arc::new(Int32Array::from(vec![1, 2, 1, 3, 2, 4, 5])) as ArrayRef;
        let group_indices = vec![0, 1, 0, 0, 1, 2, 0];
        let total_num_groups = 3;

        acc.update_batch(&[values], &group_indices, None, total_num_groups)?;

        let result_array = acc.evaluate(EmitTo::All)?;
        let result = result_array.as_any().downcast_ref::<Int64Array>().unwrap();

        // Group 0: {1, 3, 5} -> 3
        // Group 1: {2} -> 1
        // Group 2: {4} -> 1
        assert_eq!(result.value(0), 3);
        assert_eq!(result.value(1), 1);
        assert_eq!(result.value(2), 1);

        Ok(())
    }

    #[test]
    fn test_count_hash_group_accumulator_with_filter() -> Result<()> {
        let mut acc = create_test_group_accumulator();
        let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])) as ArrayRef;
        let group_indices = vec![0, 0, 1, 1, 2, 2];
        let filter = BooleanArray::from(vec![true, false, true, true, false, true]);
        let total_num_groups = 3;

        acc.update_batch(&[values], &group_indices, Some(&filter), total_num_groups)?;

        let result_array = acc.evaluate(EmitTo::All)?;
        let result = result_array.as_any().downcast_ref::<Int64Array>().unwrap();

        // Group 0: {1} (2 is filtered out) -> 1
        // Group 1: {3, 4} -> 2
        // Group 2: {6} (5 is filtered out) -> 1
        assert_eq!(result.value(0), 1);
        assert_eq!(result.value(1), 2);
        assert_eq!(result.value(2), 1);

        Ok(())
    }

    #[test]
    fn test_count_hash_group_accumulator_merge() -> Result<()> {
        // Accumulator 1
        let mut acc1 = create_test_group_accumulator();
        let values1 = Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as ArrayRef;
        let group_indices1 = vec![0, 0, 1, 1];
        acc1.update_batch(&[values1], &group_indices1, None, 2)?;
        // acc1 state: group 0 -> {1, 2}, group 1 -> {3, 4}
        let state1 = acc1.state(EmitTo::All)?;

        // Accumulator 2
        let mut acc2 = create_test_group_accumulator();
        let values2 = Arc::new(Int32Array::from(vec![5, 6, 1, 3])) as ArrayRef;
        // Merge into different group indices
        let group_indices2 = vec![2, 2, 0, 1];
        acc2.update_batch(&[values2], &group_indices2, None, 3)?;
        // acc2 state: group 0 -> {1}, group 1 -> {3}, group 2 -> {5, 6}

        // Merge state from acc1 into acc2
        // We will merge acc1's group 0 into acc2's group 0
        // and acc1's group 1 into acc2's group 2
        let merge_group_indices = vec![0, 2];
        acc2.merge_batch(&state1, &merge_group_indices, None, 3)?;

        let result_array = acc2.evaluate(EmitTo::All)?;
        let result = result_array.as_any().downcast_ref::<Int64Array>().unwrap();

        // Final state of acc2:
        // Group 0: {1} U {1, 2} -> {1, 2}, count = 2
        // Group 1: {3}, count = 1
        // Group 2: {5, 6} U {3, 4} -> {3, 4, 5, 6}, count = 4
        assert_eq!(result.value(0), 2);
        assert_eq!(result.value(1), 1);
        assert_eq!(result.value(2), 4);

        Ok(())
    }

    #[test]
    fn test_size() {
        let acc = create_test_group_accumulator();
        // Just test it doesn't crash and returns a value.
        assert!(acc.size() > 0);
    }
}