common_function/aggrs/approximate/
hll.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Two UDAFs are implemented for HyperLogLog:
16//!
17//! - `hll`: Accepts a string column and aggregates the values into a
18//!   HyperLogLog state.
19//! - `hll_merge`: Accepts a binary column of states generated by `hll`
20//!   and merges them into a single state.
21//!
//! The states can then be used to estimate the cardinality of the
//! values in the column with the `hll_count` UDF.
24
25use std::sync::Arc;
26
27use common_query::prelude::*;
28use common_telemetry::trace;
29use datafusion::arrow::array::ArrayRef;
30use datafusion::common::cast::{as_binary_array, as_string_array};
31use datafusion::common::not_impl_err;
32use datafusion::error::{DataFusionError, Result as DfResult};
33use datafusion::logical_expr::function::AccumulatorArgs;
34use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF};
35use datafusion::prelude::create_udaf;
36use datafusion_expr::Volatility;
37use datatypes::arrow::datatypes::DataType;
38use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
39
40use crate::utils::FixedRandomState;
41
/// Name given to the UDAF built by [`HllState::state_udf_impl`].
pub const HLL_NAME: &str = "hll";
/// Name given to the UDAF built by [`HllState::merge_udf_impl`].
pub const HLL_MERGE_NAME: &str = "hll_merge";

/// Precision handed to the HyperLogLog++ constructor for every new state.
const DEFAULT_PRECISION: u8 = 14;

/// Concrete sketch type wrapped by [`HllState`]: a HyperLogLog++ over `String`
/// values that uses [`FixedRandomState`] as its hasher builder.
// NOTE(review): presumably the fixed hasher state is what makes independently
// built sketches mergeable after serialization — confirm against `FixedRandomState`.
pub(crate) type HllStateType = HyperLogLogPlus<String, FixedRandomState>;
48
/// Accumulator state shared by the `hll` and `hll_merge` UDAFs.
///
/// String values are inserted into the wrapped sketch, serialized sketches
/// are merged into it, and the sketch is emitted as a bincode-encoded binary
/// scalar.
pub struct HllState {
    /// The sketch that receives inserted values and merged states.
    hll: HllStateType,
}
52
53impl std::fmt::Debug for HllState {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        write!(f, "HllState<Opaque>")
56    }
57}
58
59impl Default for HllState {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl HllState {
66    pub fn new() -> Self {
67        Self {
68            // Safety: the DEFAULT_PRECISION is fixed and valid
69            hll: HllStateType::new(DEFAULT_PRECISION, FixedRandomState::new()).unwrap(),
70        }
71    }
72
73    /// Create a UDF for the `hll` function.
74    ///
75    /// `hll` accepts a string column and aggregates the
76    /// values into a HyperLogLog state.
77    pub fn state_udf_impl() -> AggregateUDF {
78        create_udaf(
79            HLL_NAME,
80            vec![DataType::Utf8],
81            Arc::new(DataType::Binary),
82            Volatility::Immutable,
83            Arc::new(Self::create_accumulator),
84            Arc::new(vec![DataType::Binary]),
85        )
86    }
87
88    /// Create a UDF for the `hll_merge` function.
89    ///
90    /// `hll_merge` accepts a binary column of states generated by `hll`
91    /// and merges them into a single state.
92    pub fn merge_udf_impl() -> AggregateUDF {
93        create_udaf(
94            HLL_MERGE_NAME,
95            vec![DataType::Binary],
96            Arc::new(DataType::Binary),
97            Volatility::Immutable,
98            Arc::new(Self::create_merge_accumulator),
99            Arc::new(vec![DataType::Binary]),
100        )
101    }
102
103    fn update(&mut self, value: &str) {
104        self.hll.insert(value);
105    }
106
107    fn merge(&mut self, raw: &[u8]) {
108        if let Ok(serialized) = bincode::deserialize::<HllStateType>(raw)
109            && let Ok(()) = self.hll.merge(&serialized)
110        {
111            return;
112        }
113        trace!("Warning: Failed to merge HyperLogLog from {:?}", raw);
114    }
115
116    fn create_accumulator(acc_args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
117        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
118
119        match data_type {
120            DataType::Utf8 => Ok(Box::new(HllState::new())),
121            other => not_impl_err!("{HLL_NAME} does not support data type: {other}"),
122        }
123    }
124
125    fn create_merge_accumulator(acc_args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
126        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
127
128        match data_type {
129            DataType::Binary => Ok(Box::new(HllState::new())),
130            other => not_impl_err!("{HLL_MERGE_NAME} does not support data type: {other}"),
131        }
132    }
133}
134
135impl DfAccumulator for HllState {
136    fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
137        let array = &values[0];
138
139        match array.data_type() {
140            DataType::Utf8 => {
141                let string_array = as_string_array(array)?;
142                for value in string_array.iter().flatten() {
143                    self.update(value);
144                }
145            }
146            DataType::Binary => {
147                let binary_array = as_binary_array(array)?;
148                for v in binary_array.iter().flatten() {
149                    self.merge(v);
150                }
151            }
152            _ => {
153                return not_impl_err!(
154                    "HLL functions do not support data type: {}",
155                    array.data_type()
156                );
157            }
158        }
159
160        Ok(())
161    }
162
163    fn evaluate(&mut self) -> DfResult<ScalarValue> {
164        Ok(ScalarValue::Binary(Some(
165            bincode::serialize(&self.hll).map_err(|e| {
166                DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e))
167            })?,
168        )))
169    }
170
171    fn size(&self) -> usize {
172        std::mem::size_of_val(&self.hll)
173    }
174
175    fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
176        Ok(vec![ScalarValue::Binary(Some(
177            bincode::serialize(&self.hll).map_err(|e| {
178                DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e))
179            })?,
180        ))])
181    }
182
183    fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
184        let array = &states[0];
185        let binary_array = as_binary_array(array)?;
186        for v in binary_array.iter().flatten() {
187            self.merge(v);
188        }
189
190        Ok(())
191    }
192}
193
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{BinaryArray, StringArray};

    use super::*;

    /// Deserializes a sketch from `bytes` and returns its cardinality estimate.
    fn estimate(bytes: &[u8]) -> f64 {
        let mut hll: HllStateType = bincode::deserialize(bytes).unwrap();
        hll.count()
    }

    /// Builds a single binary column out of serialized sketches.
    fn binary_column(chunks: Vec<&[u8]>) -> ArrayRef {
        Arc::new(BinaryArray::from(chunks)) as ArrayRef
    }

    /// Three distinct values inserted one by one are counted as three.
    #[test]
    fn test_hll_basic() {
        let mut state = HllState::new();
        for value in ["1", "2", "3"] {
            state.update(value);
        }

        let ScalarValue::Binary(Some(bytes)) = state.evaluate().unwrap() else {
            panic!("Expected binary scalar value");
        };
        assert_eq!(estimate(&bytes).trunc() as u32, 3);
    }

    /// A serialized state merged into a fresh state yields the same estimate.
    #[test]
    fn test_hll_roundtrip() {
        let mut state = HllState::new();
        state.update("1");
        state.update("2");

        // Serialize the source state.
        let serialized = state.evaluate().unwrap();
        let ScalarValue::Binary(Some(bytes)) = &serialized else {
            panic!("Expected binary scalar value");
        };

        // Merge the serialized data into a brand-new state.
        let mut new_state = HllState::new();
        new_state.merge(bytes);

        // The merged state must report the same estimate as the original.
        let ScalarValue::Binary(Some(new_bytes)) = new_state.evaluate().unwrap() else {
            panic!("Expected binary scalar value");
        };
        assert_eq!(estimate(bytes), estimate(&new_bytes));
    }

    /// A batch of nine distinct strings is counted as nine.
    #[test]
    fn test_hll_batch_update() {
        let mut state = HllState::new();

        // Feed string values through the batch entry point.
        let str_values = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i"];
        let str_array = Arc::new(StringArray::from(str_values)) as ArrayRef;
        state.update_batch(&[str_array]).unwrap();

        let ScalarValue::Binary(Some(bytes)) = state.evaluate().unwrap() else {
            panic!("Expected binary scalar value");
        };
        assert_eq!(estimate(&bytes).trunc() as u32, 9);
    }

    /// `merge_batch` combines two single-value states into a count of two.
    #[test]
    fn test_hll_merge_batch() {
        let mut state1 = HllState::new();
        state1.update("1");
        let state1_binary = state1.evaluate().unwrap();

        let mut state2 = HllState::new();
        state2.update("2");
        let state2_binary = state2.evaluate().unwrap();

        let mut merged_state = HllState::new();
        let (ScalarValue::Binary(Some(bytes1)), ScalarValue::Binary(Some(bytes2))) =
            (&state1_binary, &state2_binary)
        else {
            panic!("Expected binary scalar values");
        };

        let states = binary_column(vec![bytes1.as_slice(), bytes2.as_slice()]);
        merged_state.merge_batch(&[states]).unwrap();

        let ScalarValue::Binary(Some(bytes)) = merged_state.evaluate().unwrap() else {
            panic!("Expected binary scalar value");
        };
        assert_eq!(estimate(&bytes).trunc() as u32, 2);
    }

    /// `update_batch` on a binary column merges overlapping states correctly.
    #[test]
    fn test_hll_merge_function() {
        // Two states whose value sets overlap on "2".
        let mut state1 = HllState::new();
        state1.update("1");
        state1.update("2");
        let state1_binary = state1.evaluate().unwrap();

        let mut state2 = HllState::new();
        state2.update("2");
        state2.update("3");
        let state2_binary = state2.evaluate().unwrap();

        // Merge both states through the `hll_merge` input path.
        let mut merge_state = HllState::new();
        let (ScalarValue::Binary(Some(bytes1)), ScalarValue::Binary(Some(bytes2))) =
            (&state1_binary, &state2_binary)
        else {
            panic!("Expected binary scalar values");
        };

        let states = binary_column(vec![bytes1.as_slice(), bytes2.as_slice()]);
        merge_state.update_batch(&[states]).unwrap();

        let ScalarValue::Binary(Some(bytes)) = merge_state.evaluate().unwrap() else {
            panic!("Expected binary scalar value");
        };
        // Should have 3 unique values: "1", "2", "3"
        assert_eq!(estimate(&bytes).trunc() as u32, 3);
    }
}