common_function/scalars/
hll_count.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Implementation of the scalar function `hll_count`.
16
17use std::fmt;
18use std::fmt::Display;
19
20use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result};
21use common_query::prelude::{Signature, Volatility};
22use datatypes::data_type::ConcreteDataType;
23use datatypes::prelude::Vector;
24use datatypes::scalars::{ScalarVector, ScalarVectorBuilder};
25use datatypes::vectors::{BinaryVector, MutableVector, UInt64VectorBuilder, VectorRef};
26use hyperloglogplus::HyperLogLog;
27use snafu::OptionExt;
28
29use crate::aggrs::approximate::hll::HllStateType;
30use crate::function::{Function, FunctionContext};
31use crate::function_registry::FunctionRegistry;
32
/// Name of the scalar function, as returned by `Function::name` for `HllCalcFunction`.
const NAME: &str = "hll_count";
34
/// HllCalcFunction implements the scalar function `hll_count`.
///
/// It accepts one argument:
/// 1. The serialized HyperLogLogPlus state (binary, bincode-encoded), as produced by the
///    aggregator.
///
/// For each row, it deserializes the sketch and returns the estimated cardinality,
/// rounded to the nearest integer, as a `u64`. A null input row, or a row whose bytes
/// cannot be deserialized, produces a null output row instead of an error.
#[derive(Debug, Default)]
pub struct HllCalcFunction;
43
44impl HllCalcFunction {
45    pub fn register(registry: &FunctionRegistry) {
46        registry.register_scalar(HllCalcFunction);
47    }
48}
49
50impl Display for HllCalcFunction {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        write!(f, "{}", NAME.to_ascii_uppercase())
53    }
54}
55
56impl Function for HllCalcFunction {
57    fn name(&self) -> &str {
58        NAME
59    }
60
61    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
62        Ok(ConcreteDataType::uint64_datatype())
63    }
64
65    fn signature(&self) -> Signature {
66        // Only argument: HyperLogLogPlus state (binary)
67        Signature::exact(
68            vec![ConcreteDataType::binary_datatype()],
69            Volatility::Immutable,
70        )
71    }
72
73    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
74        if columns.len() != 1 {
75            return InvalidFuncArgsSnafu {
76                err_msg: format!("hll_count expects 1 argument, got {}", columns.len()),
77            }
78            .fail();
79        }
80
81        let hll_vec = columns[0]
82            .as_any()
83            .downcast_ref::<BinaryVector>()
84            .with_context(|| DowncastVectorSnafu {
85                err_msg: format!("expect BinaryVector, got {}", columns[0].vector_type_name()),
86            })?;
87        let len = hll_vec.len();
88        let mut builder = UInt64VectorBuilder::with_capacity(len);
89
90        for i in 0..len {
91            let hll_opt = hll_vec.get_data(i);
92
93            if hll_opt.is_none() {
94                builder.push_null();
95                continue;
96            }
97
98            let hll_bytes = hll_opt.unwrap();
99
100            // Deserialize the HyperLogLogPlus from its bincode representation
101            let mut hll: HllStateType = match bincode::deserialize(hll_bytes) {
102                Ok(h) => h,
103                Err(e) => {
104                    common_telemetry::trace!("Failed to deserialize HyperLogLogPlus: {}", e);
105                    builder.push_null();
106                    continue;
107                }
108            };
109
110            builder.push(Some(hll.count().round() as u64));
111        }
112
113        Ok(builder.to_vector())
114    }
115}
116
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::value::Value;
    use datatypes::vectors::BinaryVector;

    use super::*;
    use crate::utils::FixedRandomState;

    /// Checks metadata and the happy path, including a null input row.
    #[test]
    fn test_hll_count_function() {
        let function = HllCalcFunction;
        assert_eq!("hll_count", function.name());
        // Use the input type the signature actually declares.
        assert_eq!(
            ConcreteDataType::uint64_datatype(),
            function
                .return_type(&[ConcreteDataType::binary_datatype()])
                .unwrap()
        );

        // Create a test HLL containing ten distinct values.
        let mut hll = HllStateType::new(14, FixedRandomState::new()).unwrap();
        for i in 1..=10 {
            hll.insert(&i.to_string());
        }

        let serialized_bytes = bincode::serialize(&hll).unwrap();
        let args: Vec<VectorRef> = vec![Arc::new(BinaryVector::from(vec![
            Some(serialized_bytes),
            None,
        ]))];

        let result = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(result.len(), 2);

        // Test cardinality estimate
        if let Value::UInt64(v) = result.get(0) {
            assert_eq!(v, 10);
        } else {
            panic!("Expected uint64 value");
        }

        // A null input row yields a null output row.
        assert!(matches!(result.get(1), Value::Null));
    }

    /// Checks argument-count validation and the tolerance of undecodable sketches.
    #[test]
    fn test_hll_count_function_errors() {
        let function = HllCalcFunction;

        // Test with too few arguments
        let args: Vec<VectorRef> = vec![];
        let result = function.eval(&FunctionContext::default(), &args);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("hll_count expects 1 argument"));

        // Test with too many arguments
        let column: VectorRef = Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])]));
        let args: Vec<VectorRef> = vec![column.clone(), column];
        let result = function.eval(&FunctionContext::default(), &args);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("hll_count expects 1 argument, got 2"));

        // Test with invalid binary data: the row becomes null rather than an error
        let args: Vec<VectorRef> = vec![Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])]))];
        let result = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(result.len(), 1);
        assert!(matches!(result.get(0), Value::Null));
    }
}