common_function/aggr/
hll.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Two UDAFs are implemented for HyperLogLog:
//!
//! - `hll`: Accepts a string column and aggregates the values into a
//!   HyperLogLog state.
//! - `hll_merge`: Accepts a binary column of states generated by `hll`
//!   and merges them into a single state.
//!
//! The states can then be used to estimate the cardinality of the
//! values in the column with the `hll_count` UDF.
24
25use std::sync::Arc;
26
27use common_query::prelude::*;
28use common_telemetry::trace;
29use datafusion::arrow::array::ArrayRef;
30use datafusion::common::cast::{as_binary_array, as_string_array};
31use datafusion::common::not_impl_err;
32use datafusion::error::{DataFusionError, Result as DfResult};
33use datafusion::logical_expr::function::AccumulatorArgs;
34use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF};
35use datafusion::prelude::create_udaf;
36use datatypes::arrow::datatypes::DataType;
37use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
38
39use crate::utils::FixedRandomState;
40
/// Name under which the `hll` aggregate function is created.
pub const HLL_NAME: &str = "hll";
/// Name under which the `hll_merge` aggregate function is created.
pub const HLL_MERGE_NAME: &str = "hll_merge";

/// Precision handed to the sketch constructor for every freshly created state.
const DEFAULT_PRECISION: u8 = 14;

/// Concrete sketch type used by this module: a `HyperLogLogPlus` over `String`
/// values, parameterized with [`FixedRandomState`].
// NOTE(review): presumably `FixedRandomState` makes hashing deterministic so
// that states serialized elsewhere can be merged — confirm in `crate::utils`.
pub(crate) type HllStateType = HyperLogLogPlus<String, FixedRandomState>;
47
/// Accumulator state for the HyperLogLog aggregate functions.
///
/// Wraps a single [`HllStateType`] sketch; the same type backs both the
/// `hll` and the `hll_merge` accumulators created in this file.
pub struct HllState {
    /// The underlying sketch that values are inserted into and states are merged into.
    hll: HllStateType,
}
51
52impl std::fmt::Debug for HllState {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "HllState<Opaque>")
55    }
56}
57
58impl Default for HllState {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl HllState {
65    pub fn new() -> Self {
66        Self {
67            // Safety: the DEFAULT_PRECISION is fixed and valid
68            hll: HllStateType::new(DEFAULT_PRECISION, FixedRandomState::new()).unwrap(),
69        }
70    }
71
72    /// Create a UDF for the `hll` function.
73    ///
74    /// `hll` accepts a string column and aggregates the
75    /// values into a HyperLogLog state.
76    pub fn state_udf_impl() -> AggregateUDF {
77        create_udaf(
78            HLL_NAME,
79            vec![DataType::Utf8],
80            Arc::new(DataType::Binary),
81            Volatility::Immutable,
82            Arc::new(Self::create_accumulator),
83            Arc::new(vec![DataType::Binary]),
84        )
85    }
86
87    /// Create a UDF for the `hll_merge` function.
88    ///
89    /// `hll_merge` accepts a binary column of states generated by `hll`
90    /// and merges them into a single state.
91    pub fn merge_udf_impl() -> AggregateUDF {
92        create_udaf(
93            HLL_MERGE_NAME,
94            vec![DataType::Binary],
95            Arc::new(DataType::Binary),
96            Volatility::Immutable,
97            Arc::new(Self::create_merge_accumulator),
98            Arc::new(vec![DataType::Binary]),
99        )
100    }
101
102    fn update(&mut self, value: &str) {
103        self.hll.insert(value);
104    }
105
106    fn merge(&mut self, raw: &[u8]) {
107        if let Ok(serialized) = bincode::deserialize::<HllStateType>(raw) {
108            if let Ok(()) = self.hll.merge(&serialized) {
109                return;
110            }
111        }
112        trace!("Warning: Failed to merge HyperLogLog from {:?}", raw);
113    }
114
115    fn create_accumulator(acc_args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
116        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
117
118        match data_type {
119            DataType::Utf8 => Ok(Box::new(HllState::new())),
120            other => not_impl_err!("{HLL_NAME} does not support data type: {other}"),
121        }
122    }
123
124    fn create_merge_accumulator(acc_args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
125        let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
126
127        match data_type {
128            DataType::Binary => Ok(Box::new(HllState::new())),
129            other => not_impl_err!("{HLL_MERGE_NAME} does not support data type: {other}"),
130        }
131    }
132}
133
134impl DfAccumulator for HllState {
135    fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
136        let array = &values[0];
137
138        match array.data_type() {
139            DataType::Utf8 => {
140                let string_array = as_string_array(array)?;
141                for value in string_array.iter().flatten() {
142                    self.update(value);
143                }
144            }
145            DataType::Binary => {
146                let binary_array = as_binary_array(array)?;
147                for v in binary_array.iter().flatten() {
148                    self.merge(v);
149                }
150            }
151            _ => {
152                return not_impl_err!(
153                    "HLL functions do not support data type: {}",
154                    array.data_type()
155                )
156            }
157        }
158
159        Ok(())
160    }
161
162    fn evaluate(&mut self) -> DfResult<ScalarValue> {
163        Ok(ScalarValue::Binary(Some(
164            bincode::serialize(&self.hll).map_err(|e| {
165                DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e))
166            })?,
167        )))
168    }
169
170    fn size(&self) -> usize {
171        std::mem::size_of_val(&self.hll)
172    }
173
174    fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
175        Ok(vec![ScalarValue::Binary(Some(
176            bincode::serialize(&self.hll).map_err(|e| {
177                DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e))
178            })?,
179        ))])
180    }
181
182    fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
183        let array = &states[0];
184        let binary_array = as_binary_array(array)?;
185        for v in binary_array.iter().flatten() {
186            self.merge(v);
187        }
188
189        Ok(())
190    }
191}
192
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{BinaryArray, StringArray};

    use super::*;

    /// Extracts the bytes of a non-null binary scalar, panicking otherwise.
    fn binary_payload(value: ScalarValue) -> Vec<u8> {
        match value {
            ScalarValue::Binary(Some(bytes)) => bytes,
            _ => panic!("Expected binary scalar value"),
        }
    }

    /// Deserializes a sketch from bytes produced by `evaluate`.
    fn decode(bytes: &[u8]) -> HllStateType {
        bincode::deserialize(bytes).unwrap()
    }

    /// Evaluates `state`, decodes the result and returns the truncated estimate.
    fn truncated_count(state: &mut HllState) -> u32 {
        let bytes = binary_payload(state.evaluate().unwrap());
        decode(&bytes).count().trunc() as u32
    }

    /// Packs two serialized states into a two-row binary column.
    fn states_column(first: &ScalarValue, second: &ScalarValue) -> ArrayRef {
        match (first, second) {
            (ScalarValue::Binary(Some(a)), ScalarValue::Binary(Some(b))) => {
                Arc::new(BinaryArray::from(vec![a.as_slice(), b.as_slice()])) as ArrayRef
            }
            _ => panic!("Expected binary scalar values"),
        }
    }

    #[test]
    fn test_hll_basic() {
        let mut state = HllState::new();
        for value in ["1", "2", "3"] {
            state.update(value);
        }

        assert_eq!(truncated_count(&mut state), 3);
    }

    #[test]
    fn test_hll_roundtrip() {
        let mut state = HllState::new();
        state.update("1");
        state.update("2");

        // Serialize
        let bytes = binary_payload(state.evaluate().unwrap());

        // Create new state and merge the serialized data
        let mut new_state = HllState::new();
        new_state.merge(&bytes);

        // Verify the merged state matches original
        let new_bytes = binary_payload(new_state.evaluate().unwrap());
        let mut original = decode(&bytes);
        let mut merged = decode(&new_bytes);
        assert_eq!(original.count(), merged.count());
    }

    #[test]
    fn test_hll_batch_update() {
        let mut state = HllState::new();

        // Test string values
        let str_values = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i"];
        let str_array = Arc::new(StringArray::from(str_values)) as ArrayRef;
        state.update_batch(&[str_array]).unwrap();

        assert_eq!(truncated_count(&mut state), 9);
    }

    #[test]
    fn test_hll_merge_batch() {
        let mut state1 = HllState::new();
        state1.update("1");
        let state1_binary = state1.evaluate().unwrap();

        let mut state2 = HllState::new();
        state2.update("2");
        let state2_binary = state2.evaluate().unwrap();

        let mut merged_state = HllState::new();
        let binary_array = states_column(&state1_binary, &state2_binary);
        merged_state.merge_batch(&[binary_array]).unwrap();

        assert_eq!(truncated_count(&mut merged_state), 2);
    }

    #[test]
    fn test_hll_merge_function() {
        // Create two HLL states with different values
        let mut state1 = HllState::new();
        state1.update("1");
        state1.update("2");
        let state1_binary = state1.evaluate().unwrap();

        let mut state2 = HllState::new();
        state2.update("2");
        state2.update("3");
        let state2_binary = state2.evaluate().unwrap();

        // Create a merge state and merge both states
        let mut merge_state = HllState::new();
        let binary_array = states_column(&state1_binary, &state2_binary);
        merge_state.update_batch(&[binary_array]).unwrap();

        // Should have 3 unique values: "1", "2", "3"
        assert_eq!(truncated_count(&mut merge_state), 3);
    }
}