common_function/aggr/
hll.rs1use std::sync::Arc;
26
27use common_query::prelude::*;
28use common_telemetry::trace;
29use datafusion::arrow::array::ArrayRef;
30use datafusion::common::cast::{as_binary_array, as_string_array};
31use datafusion::common::not_impl_err;
32use datafusion::error::{DataFusionError, Result as DfResult};
33use datafusion::logical_expr::function::AccumulatorArgs;
34use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF};
35use datafusion::prelude::create_udaf;
36use datatypes::arrow::datatypes::DataType;
37use hyperloglogplus::{HyperLogLog, HyperLogLogPlus};
38
39use crate::utils::FixedRandomState;
40
/// Name under which the state-building aggregate (string input) is registered.
pub const HLL_NAME: &str = "hll";
/// Name under which the state-merging aggregate (binary input) is registered.
pub const HLL_MERGE_NAME: &str = "hll_merge";

/// Precision handed to the sketch constructor by `HllState::new`.
const DEFAULT_PRECISION: u8 = 14;
45
46pub(crate) type HllStateType = HyperLogLogPlus<String, FixedRandomState>;
47
48pub struct HllState {
49 hll: HllStateType,
50}
51
52impl std::fmt::Debug for HllState {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "HllState<Opaque>")
55 }
56}
57
58impl Default for HllState {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl HllState {
65 pub fn new() -> Self {
66 Self {
67 hll: HllStateType::new(DEFAULT_PRECISION, FixedRandomState::new()).unwrap(),
69 }
70 }
71
72 pub fn state_udf_impl() -> AggregateUDF {
77 create_udaf(
78 HLL_NAME,
79 vec![DataType::Utf8],
80 Arc::new(DataType::Binary),
81 Volatility::Immutable,
82 Arc::new(Self::create_accumulator),
83 Arc::new(vec![DataType::Binary]),
84 )
85 }
86
87 pub fn merge_udf_impl() -> AggregateUDF {
92 create_udaf(
93 HLL_MERGE_NAME,
94 vec![DataType::Binary],
95 Arc::new(DataType::Binary),
96 Volatility::Immutable,
97 Arc::new(Self::create_merge_accumulator),
98 Arc::new(vec![DataType::Binary]),
99 )
100 }
101
102 fn update(&mut self, value: &str) {
103 self.hll.insert(value);
104 }
105
106 fn merge(&mut self, raw: &[u8]) {
107 if let Ok(serialized) = bincode::deserialize::<HllStateType>(raw) {
108 if let Ok(()) = self.hll.merge(&serialized) {
109 return;
110 }
111 }
112 trace!("Warning: Failed to merge HyperLogLog from {:?}", raw);
113 }
114
115 fn create_accumulator(acc_args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
116 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
117
118 match data_type {
119 DataType::Utf8 => Ok(Box::new(HllState::new())),
120 other => not_impl_err!("{HLL_NAME} does not support data type: {other}"),
121 }
122 }
123
124 fn create_merge_accumulator(acc_args: AccumulatorArgs) -> DfResult<Box<dyn DfAccumulator>> {
125 let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
126
127 match data_type {
128 DataType::Binary => Ok(Box::new(HllState::new())),
129 other => not_impl_err!("{HLL_MERGE_NAME} does not support data type: {other}"),
130 }
131 }
132}
133
134impl DfAccumulator for HllState {
135 fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
136 let array = &values[0];
137
138 match array.data_type() {
139 DataType::Utf8 => {
140 let string_array = as_string_array(array)?;
141 for value in string_array.iter().flatten() {
142 self.update(value);
143 }
144 }
145 DataType::Binary => {
146 let binary_array = as_binary_array(array)?;
147 for v in binary_array.iter().flatten() {
148 self.merge(v);
149 }
150 }
151 _ => {
152 return not_impl_err!(
153 "HLL functions do not support data type: {}",
154 array.data_type()
155 )
156 }
157 }
158
159 Ok(())
160 }
161
162 fn evaluate(&mut self) -> DfResult<ScalarValue> {
163 Ok(ScalarValue::Binary(Some(
164 bincode::serialize(&self.hll).map_err(|e| {
165 DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e))
166 })?,
167 )))
168 }
169
170 fn size(&self) -> usize {
171 std::mem::size_of_val(&self.hll)
172 }
173
174 fn state(&mut self) -> DfResult<Vec<ScalarValue>> {
175 Ok(vec![ScalarValue::Binary(Some(
176 bincode::serialize(&self.hll).map_err(|e| {
177 DataFusionError::Internal(format!("Failed to serialize HyperLogLog: {}", e))
178 })?,
179 ))])
180 }
181
182 fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
183 let array = &states[0];
184 let binary_array = as_binary_array(array)?;
185 for v in binary_array.iter().flatten() {
186 self.merge(v);
187 }
188
189 Ok(())
190 }
191}
192
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{BinaryArray, StringArray};

    use super::*;

    /// Extracts the payload of a non-null binary scalar, panicking otherwise.
    fn binary_payload(value: ScalarValue) -> Vec<u8> {
        match value {
            ScalarValue::Binary(Some(bytes)) => bytes,
            _ => panic!("Expected binary scalar value"),
        }
    }

    /// Extracts the payloads of two non-null binary scalars, panicking otherwise.
    fn binary_pair(first: ScalarValue, second: ScalarValue) -> (Vec<u8>, Vec<u8>) {
        match (first, second) {
            (ScalarValue::Binary(Some(a)), ScalarValue::Binary(Some(b))) => (a, b),
            _ => panic!("Expected binary scalar values"),
        }
    }

    /// Deserializes a sketch and returns its cardinality estimate.
    fn estimate(bytes: &[u8]) -> f64 {
        let mut hll: HllStateType = bincode::deserialize(bytes).unwrap();
        hll.count()
    }

    /// Packs two serialized sketches into a two-row binary array.
    fn binary_column(first: &[u8], second: &[u8]) -> ArrayRef {
        Arc::new(BinaryArray::from(vec![first, second])) as ArrayRef
    }

    #[test]
    fn test_hll_basic() {
        let mut state = HllState::new();
        for item in ["1", "2", "3"] {
            state.update(item);
        }

        let bytes = binary_payload(state.evaluate().unwrap());
        assert_eq!(estimate(&bytes).trunc() as u32, 3);
    }

    #[test]
    fn test_hll_roundtrip() {
        let mut state = HllState::new();
        state.update("1");
        state.update("2");
        let bytes = binary_payload(state.evaluate().unwrap());

        // Merging the serialized sketch into a fresh state must preserve the estimate.
        let mut new_state = HllState::new();
        new_state.merge(&bytes);
        let new_bytes = binary_payload(new_state.evaluate().unwrap());

        assert_eq!(estimate(&bytes), estimate(&new_bytes));
    }

    #[test]
    fn test_hll_batch_update() {
        let mut state = HllState::new();

        let str_values = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i"];
        let str_array = Arc::new(StringArray::from(str_values)) as ArrayRef;
        state.update_batch(&[str_array]).unwrap();

        let bytes = binary_payload(state.evaluate().unwrap());
        assert_eq!(estimate(&bytes).trunc() as u32, 9);
    }

    #[test]
    fn test_hll_merge_batch() {
        let mut state1 = HllState::new();
        state1.update("1");
        let state1_binary = state1.evaluate().unwrap();

        let mut state2 = HllState::new();
        state2.update("2");
        let state2_binary = state2.evaluate().unwrap();

        let (bytes1, bytes2) = binary_pair(state1_binary, state2_binary);

        let mut merged_state = HllState::new();
        merged_state
            .merge_batch(&[binary_column(&bytes1, &bytes2)])
            .unwrap();

        let merged_bytes = binary_payload(merged_state.evaluate().unwrap());
        assert_eq!(estimate(&merged_bytes).trunc() as u32, 2);
    }

    #[test]
    fn test_hll_merge_function() {
        let mut state1 = HllState::new();
        state1.update("1");
        state1.update("2");
        let state1_binary = state1.evaluate().unwrap();

        let mut state2 = HllState::new();
        state2.update("2");
        state2.update("3");
        let state2_binary = state2.evaluate().unwrap();

        let (bytes1, bytes2) = binary_pair(state1_binary, state2_binary);

        // `update_batch` with a binary column takes the merge path.
        let mut merge_state = HllState::new();
        merge_state
            .update_batch(&[binary_column(&bytes1, &bytes2)])
            .unwrap();

        let merged_bytes = binary_payload(merge_state.evaluate().unwrap());
        assert_eq!(estimate(&merged_bytes).trunc() as u32, 3);
    }
}