common_function/scalars/hll_count.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of the scalar function `hll_count`.
16
17use std::fmt;
18use std::fmt::Display;
19use std::sync::Arc;
20
21use datafusion_common::DataFusionError;
22use datafusion_common::arrow::array::{Array, AsArray, UInt64Builder};
23use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
24use datatypes::arrow::datatypes::DataType;
25use hyperloglogplus::HyperLogLog;
26
27use crate::aggrs::approximate::hll::HllStateType;
28use crate::function::{Function, extract_args};
29use crate::function_registry::FunctionRegistry;
30
31const NAME: &str = "hll_count";
32
33/// HllCalcFunction implements the scalar function `hll_count`.
34///
35/// It accepts one argument:
36/// 1. The serialized HyperLogLogPlus state, as produced by the aggregator (binary).
37///
38/// For each row, it deserializes the sketch and returns the estimated cardinality.
39#[derive(Debug)]
40pub(crate) struct HllCalcFunction {
41    signature: Signature,
42}
43
44impl HllCalcFunction {
45    pub fn register(registry: &FunctionRegistry) {
46        registry.register_scalar(HllCalcFunction::default());
47    }
48}
49
50impl Default for HllCalcFunction {
51    fn default() -> Self {
52        Self {
53            signature: Signature::exact(vec![DataType::Binary], Volatility::Immutable),
54        }
55    }
56}
57
58impl Display for HllCalcFunction {
59    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
60        write!(f, "{}", NAME.to_ascii_uppercase())
61    }
62}
63
64impl Function for HllCalcFunction {
65    fn name(&self) -> &str {
66        NAME
67    }
68
69    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
70        Ok(DataType::UInt64)
71    }
72
73    fn signature(&self) -> &Signature {
74        &self.signature
75    }
76
77    fn invoke_with_args(
78        &self,
79        args: ScalarFunctionArgs,
80    ) -> datafusion_common::Result<ColumnarValue> {
81        let [arg0] = extract_args(self.name(), &args)?;
82
83        let Some(hll_vec) = arg0.as_binary_opt::<i32>() else {
84            return Err(DataFusionError::Execution(format!(
85                "'{}' expects argument to be Binary datatype, got {}",
86                self.name(),
87                arg0.data_type()
88            )));
89        };
90        let len = hll_vec.len();
91        let mut builder = UInt64Builder::with_capacity(len);
92
93        for i in 0..len {
94            let hll_opt = hll_vec.is_valid(i).then(|| hll_vec.value(i));
95
96            if hll_opt.is_none() {
97                builder.append_null();
98                continue;
99            }
100
101            let hll_bytes = hll_opt.unwrap();
102
103            // Deserialize the HyperLogLogPlus from its bincode representation
104            let mut hll: HllStateType = match bincode::deserialize(hll_bytes) {
105                Ok(h) => h,
106                Err(e) => {
107                    common_telemetry::trace!("Failed to deserialize HyperLogLogPlus: {}", e);
108                    builder.append_null();
109                    continue;
110                }
111            };
112
113            builder.append_value(hll.count().round() as u64);
114        }
115
116        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
117    }
118}
119
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::BinaryArray;
    use datafusion_common::arrow::datatypes::UInt64Type;

    use super::*;
    use crate::utils::FixedRandomState;

    /// Returns a bincode-serialized sketch containing the distinct values `1..=10`.
    fn serialized_sketch_of_ten() -> Vec<u8> {
        let mut hll = HllStateType::new(14, FixedRandomState::new()).unwrap();
        for i in 1..=10 {
            hll.insert(&i.to_string());
        }
        bincode::serialize(&hll).unwrap()
    }

    /// Builds invocation arguments over `args` with `number_rows` rows.
    ///
    /// The return field is nullable because `hll_count` emits NULL for NULL
    /// or undecodable input rows.
    fn make_args(args: Vec<ColumnarValue>, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::UInt64, true)),
            config_options: Arc::new(Default::default()),
        }
    }

    #[test]
    fn test_hll_count_function() {
        let function = HllCalcFunction::default();
        assert_eq!("hll_count", function.name());
        assert_eq!("HLL_COUNT", function.to_string());
        assert_eq!(
            DataType::UInt64,
            function.return_type(&[DataType::Binary]).unwrap()
        );

        let input = BinaryArray::from_iter(vec![Some(serialized_sketch_of_ten())]);
        let result = function
            .invoke_with_args(make_args(vec![ColumnarValue::Array(Arc::new(input))], 1))
            .unwrap();
        let ColumnarValue::Array(result) = result else {
            unreachable!()
        };
        let result = result.as_primitive::<UInt64Type>();
        assert_eq!(result.len(), 1);

        // Test cardinality estimate
        assert_eq!(result.value(0), 10);
    }

    #[test]
    fn test_hll_count_function_null_input() {
        let function = HllCalcFunction::default();

        // A NULL row must stay NULL without affecting neighbouring rows.
        let input = BinaryArray::from_iter(vec![None, Some(serialized_sketch_of_ten())]);
        let result = function
            .invoke_with_args(make_args(vec![ColumnarValue::Array(Arc::new(input))], 2))
            .unwrap();
        let ColumnarValue::Array(result) = result else {
            unreachable!()
        };
        let result = result.as_primitive::<UInt64Type>();
        assert_eq!(result.len(), 2);
        assert!(result.is_null(0));
        assert_eq!(result.value(1), 10);
    }

    #[test]
    fn test_hll_count_function_errors() {
        let function = HllCalcFunction::default();

        // Test with invalid number of arguments
        let result = function.invoke_with_args(make_args(vec![], 0));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Execution error: hll_count function requires 1 argument, got 0")
        );

        // Test with invalid binary data: decoding fails, so the row becomes NULL.
        let input = BinaryArray::from_iter(vec![Some(vec![1, 2, 3])]);
        let result = function
            .invoke_with_args(make_args(vec![ColumnarValue::Array(Arc::new(input))], 1))
            .unwrap();
        let ColumnarValue::Array(result) = result else {
            unreachable!()
        };
        let result = result.as_primitive::<UInt64Type>();
        assert_eq!(result.len(), 1);
        assert!(result.is_null(0));
    }
}