common_function/scalars/
hll_count.rs1use std::fmt;
18use std::fmt::Display;
19use std::sync::Arc;
20
21use datafusion_common::DataFusionError;
22use datafusion_common::arrow::array::{Array, AsArray, UInt64Builder};
23use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
24use datatypes::arrow::datatypes::DataType;
25use hyperloglogplus::HyperLogLog;
26
27use crate::aggrs::approximate::hll::HllStateType;
28use crate::function::{Function, extract_args};
29use crate::function_registry::FunctionRegistry;
30
/// Name of the scalar function, as returned by `Function::name`; the `Display`
/// implementation prints the same name in ASCII upper case.
const NAME: &str = "hll_count";
32
/// Scalar function `hll_count`.
///
/// It takes a single `Binary` argument whose rows are `bincode`-serialized
/// [`HllStateType`] values and yields, per row, the estimated number of
/// distinct elements as a `UInt64`. Null rows and rows that cannot be
/// deserialized produce a null result.
#[derive(Debug)]
pub(crate) struct HllCalcFunction {
    /// Argument signature handed back by `Function::signature`.
    signature: Signature,
}
43
44impl HllCalcFunction {
45 pub fn register(registry: &FunctionRegistry) {
46 registry.register_scalar(HllCalcFunction::default());
47 }
48}
49
50impl Default for HllCalcFunction {
51 fn default() -> Self {
52 Self {
53 signature: Signature::exact(vec![DataType::Binary], Volatility::Immutable),
54 }
55 }
56}
57
58impl Display for HllCalcFunction {
59 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
60 write!(f, "{}", NAME.to_ascii_uppercase())
61 }
62}
63
64impl Function for HllCalcFunction {
65 fn name(&self) -> &str {
66 NAME
67 }
68
69 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
70 Ok(DataType::UInt64)
71 }
72
73 fn signature(&self) -> &Signature {
74 &self.signature
75 }
76
77 fn invoke_with_args(
78 &self,
79 args: ScalarFunctionArgs,
80 ) -> datafusion_common::Result<ColumnarValue> {
81 let [arg0] = extract_args(self.name(), &args)?;
82
83 let Some(hll_vec) = arg0.as_binary_opt::<i32>() else {
84 return Err(DataFusionError::Execution(format!(
85 "'{}' expects argument to be Binary datatype, got {}",
86 self.name(),
87 arg0.data_type()
88 )));
89 };
90 let len = hll_vec.len();
91 let mut builder = UInt64Builder::with_capacity(len);
92
93 for i in 0..len {
94 let hll_opt = hll_vec.is_valid(i).then(|| hll_vec.value(i));
95
96 if hll_opt.is_none() {
97 builder.append_null();
98 continue;
99 }
100
101 let hll_bytes = hll_opt.unwrap();
102
103 let mut hll: HllStateType = match bincode::deserialize(hll_bytes) {
105 Ok(h) => h,
106 Err(e) => {
107 common_telemetry::trace!("Failed to deserialize HyperLogLogPlus: {}", e);
108 builder.append_null();
109 continue;
110 }
111 };
112
113 builder.append_value(hll.count().round() as u64);
114 }
115
116 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
117 }
118}
119
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::BinaryArray;
    use datafusion_common::arrow::datatypes::UInt64Type;

    use super::*;
    use crate::utils::FixedRandomState;

    /// Builds the invocation arguments used by the tests: no argument fields,
    /// a `UInt64` return field named `x`, and default config options.
    fn scalar_args(args: Vec<ColumnarValue>, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
            config_options: Arc::new(Default::default()),
        }
    }

    /// Wraps the given rows into a binary array argument.
    fn binary_column(rows: Vec<Option<Vec<u8>>>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(BinaryArray::from_iter(rows)))
    }

    /// Unwraps the array variant of a function result; any other variant is unexpected here.
    fn expect_array(value: ColumnarValue) -> Arc<dyn Array> {
        match value {
            ColumnarValue::Array(array) => array,
            _ => unreachable!(),
        }
    }

    /// A state holding ten distinct values must be counted as exactly 10.
    #[test]
    fn test_hll_count_function() {
        let function = HllCalcFunction::default();
        assert_eq!("hll_count", function.name());
        assert_eq!(
            DataType::UInt64,
            function.return_type(&[DataType::UInt64]).unwrap()
        );

        let mut sketch = HllStateType::new(14, FixedRandomState::new()).unwrap();
        for value in 1..=10 {
            sketch.insert(&value.to_string());
        }

        let encoded = bincode::serialize(&sketch).unwrap();
        let call_args = scalar_args(vec![binary_column(vec![Some(encoded)])], 1);

        let output = expect_array(function.invoke_with_args(call_args).unwrap());
        let counts = output.as_primitive::<UInt64Type>();
        assert_eq!(counts.len(), 1);

        assert_eq!(counts.value(0), 10);
    }

    /// A missing argument is an error; undecodable bytes yield a null row.
    #[test]
    fn test_hll_count_function_errors() {
        let function = HllCalcFunction::default();

        let outcome = function.invoke_with_args(scalar_args(vec![], 0));
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("Execution error: hll_count function requires 1 argument, got 0")
        );

        let garbage = binary_column(vec![Some(vec![1, 2, 3])]);
        let output = expect_array(
            function
                .invoke_with_args(scalar_args(vec![garbage], 0))
                .unwrap(),
        );
        let counts = output.as_primitive::<UInt64Type>();
        assert_eq!(counts.len(), 1);
        assert!(counts.is_null(0));
    }
}