common_function/aggrs/approximate/
hll.rs1use std::sync::Arc;
26
27use common_query::prelude::*;
28use common_telemetry::trace;
29use datafusion::arrow::array::ArrayRef;
30use datafusion::common::cast::{as_binary_array, as_string_array};
31use datafusion::common::not_impl_err;
32use datafusion::error::{DataFusionError, Result as DfResult};
33use datafusion::logical_expr::function::AccumulatorArgs;
34use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF};
35use datafusion::prelude::create_udaf;
36use datafusion_expr::Volatility;
37use datatypes::arrow::datatypes::DataType;
38use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
39
40use crate::utils::FixedRandomState;
41
/// Name given to the sketch-building aggregate created by `HllState::state_udf_impl`.
pub const HLL_NAME: &str = "hll";
/// Name given to the sketch-merging aggregate created by `HllState::merge_udf_impl`.
pub const HLL_MERGE_NAME: &str = "hll_merge";

/// HyperLogLog++ precision passed to `HllStateType::new` for every sketch built here.
const DEFAULT_PRECISION: u8 = 14;
46
47pub(crate) type HllStateType = HyperLogLogPlus<String, FixedRandomState>;
48
49pub struct HllState {
50 hll: HllStateType,
51}
52
53impl std::fmt::Debug for HllState {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 write!(f, "HllState<Opaque>")
56 }
57}
58
59impl Default for HllState {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl HllState {
66 pub fn new() -> Self {
67 Self {
68 hll: HllStateType::new(DEFAULT_PRECISION, FixedRandomState::new()).unwrap(),
70 }
71 }
72
73 pub fn state_udf_impl() -> AggregateUDF {
78 create_udaf(
79 HLL_NAME,
80 vec![DataType::Utf8],
81 Arc::new(DataType::Binary),
82 Volatility::Immutable,
83 Arc::new(Self::create_accumulator),
84 Arc::new(vec![DataType::Binary]),
85 )
86 }
87
88 pub fn merge_udf_impl() -> AggregateUDF {
93 create_udaf(
94 HLL_MERGE_NAME,
95 vec![DataType::Binary],
96 Arc::new(DataType::Binary),
97 Volatility::Immutable,
98 Arc::new(Self::create_merge_accumulator),
99 Arc::new(vec![DataType::Binary]),
100 )
101 }
102
103 fn update(&mut self, value: &str) {
104 self.hll.insert(value);
105 }
106
107 fn merge(&mut self, raw: &[u8]) {
108 if let Ok(serialized) = bincode::deserialize::<HllStateType>(raw)
109 && let Ok(()) = self.hll.merge(&serialized)
110 {
111 return;
112 }
113 trace!("Warning: Failed to merge HyperLogLog from {:?}", raw);
114 }
115
116 fn create_accumulator(acc_args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
117 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
118
119 match data_type {
120 DataType::Utf8 => Ok(Box::new(HllState::new())),
121 other => not_impl_err!("{HLL_NAME} does not support data type: {other}"),
122 }
123 }
124
125 fn create_merge_accumulator(acc_args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
126 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
127
128 match data_type {
129 DataType::Binary => Ok(Box::new(HllState::new())),
130 other => not_impl_err!("{HLL_MERGE_NAME} does not support data type: {other}"),
131 }
132 }
133}
134
135impl DfAccumulator for HllState {
136 fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
137 let array = &values[0];
138
139 match array.data_type() {
140 DataType::Utf8 => {
141 let string_array = as_string_array(array)?;
142 for value in string_array.iter().flatten() {
143 self.update(value);
144 }
145 }
146 DataType::Binary => {
147 let binary_array = as_binary_array(array)?;
148 for v in binary_array.iter().flatten() {
149 self.merge(v);
150 }
151 }
152 _ => {
153 return not_impl_err!(
154 "HLL functions do not support data type: {}",
155 array.data_type()
156 );
157 }
158 }
159
160 Ok(())
161 }
162
163 fn evaluate(&mut self) -> DfResult<ScalarValue> {
164 Ok(ScalarValue::Binary(Some(
165 bincode::serialize(&self.hll).map_err(|e| {
166 DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e))
167 })?,
168 )))
169 }
170
171 fn size(&self) -> usize {
172 std::mem::size_of_val(&self.hll)
173 }
174
175 fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
176 Ok(vec![ScalarValue::Binary(Some(
177 bincode::serialize(&self.hll).map_err(|e| {
178 DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e))
179 })?,
180 ))])
181 }
182
183 fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
184 let array = &states[0];
185 let binary_array = as_binary_array(array)?;
186 for v in binary_array.iter().flatten() {
187 self.merge(v);
188 }
189
190 Ok(())
191 }
192}
193
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{BinaryArray, StringArray};

    use super::*;

    /// Builds a fresh sketch from `values` and returns its serialized (`evaluate`) form.
    fn sketch_of(values: &[&str]) -> ScalarValue {
        let mut state = HllState::new();
        for value in values {
            state.update(value);
        }
        state.evaluate().unwrap()
    }

    /// Decodes a serialized sketch and returns its estimate truncated to an integer.
    fn decoded_count(value: ScalarValue) -> u32 {
        let ScalarValue::Binary(Some(bytes)) = value else {
            panic!("Expected binary scalar value");
        };
        let mut hll: HllStateType = bincode::deserialize(&bytes).unwrap();
        hll.count().trunc() as u32
    }

    /// Packs two serialized sketches into a two-row `BinaryArray`.
    fn binary_array_of(first: &ScalarValue, second: &ScalarValue) -> ArrayRef {
        let (ScalarValue::Binary(Some(a)), ScalarValue::Binary(Some(b))) = (first, second) else {
            panic!("Expected binary scalar values");
        };
        Arc::new(BinaryArray::from(vec![a.as_slice(), b.as_slice()]))
    }

    #[test]
    fn test_hll_basic() {
        assert_eq!(decoded_count(sketch_of(&["1", "2", "3"])), 3);
    }

    #[test]
    fn test_hll_roundtrip() {
        let serialized = sketch_of(&["1", "2"]);
        let ScalarValue::Binary(Some(bytes)) = &serialized else {
            panic!("Expected binary scalar value");
        };

        let mut new_state = HllState::new();
        new_state.merge(bytes);
        let ScalarValue::Binary(Some(new_bytes)) = new_state.evaluate().unwrap() else {
            panic!("Expected binary scalar value");
        };

        let mut original: HllStateType = bincode::deserialize(bytes).unwrap();
        let mut merged: HllStateType = bincode::deserialize(&new_bytes).unwrap();
        assert_eq!(original.count(), merged.count());
    }

    #[test]
    fn test_hll_batch_update() {
        let letters: ArrayRef = Arc::new(StringArray::from(vec![
            "a", "b", "c", "d", "e", "f", "g", "h", "i",
        ]));

        let mut state = HllState::new();
        state.update_batch(&[letters]).unwrap();

        assert_eq!(decoded_count(state.evaluate().unwrap()), 9);
    }

    #[test]
    fn test_hll_merge_batch() {
        let first = sketch_of(&["1"]);
        let second = sketch_of(&["2"]);
        let batch = binary_array_of(&first, &second);

        let mut merged_state = HllState::new();
        merged_state.merge_batch(&[batch]).unwrap();

        assert_eq!(decoded_count(merged_state.evaluate().unwrap()), 2);
    }

    #[test]
    fn test_hll_merge_function() {
        // The sketches overlap on "2", so the union should estimate three distinct values.
        let first = sketch_of(&["1", "2"]);
        let second = sketch_of(&["2", "3"]);
        let batch = binary_array_of(&first, &second);

        // `hll_merge` receives serialized sketches through `update_batch`, not `merge_batch`.
        let mut merge_state = HllState::new();
        merge_state.update_batch(&[batch]).unwrap();

        assert_eq!(decoded_count(merge_state.evaluate().unwrap()), 3);
    }
}