// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Implementation of the scalar function `hll_count`, which returns the estimated
//! cardinality of a serialized HyperLogLogPlus sketch.
16
17use std::fmt;
18use std::fmt::Display;
19use std::sync::Arc;
20
21use common_query::error::{DowncastVectorSnafu, InvalidFuncArgsSnafu, Result};
22use common_query::prelude::{Signature, Volatility};
23use datatypes::data_type::ConcreteDataType;
24use datatypes::prelude::Vector;
25use datatypes::scalars::{ScalarVector, ScalarVectorBuilder};
26use datatypes::vectors::{BinaryVector, MutableVector, UInt64VectorBuilder, VectorRef};
27use hyperloglogplus::HyperLogLog;
28use snafu::OptionExt;
29
30use crate::aggr::HllStateType;
31use crate::function::{Function, FunctionContext};
32use crate::function_registry::FunctionRegistry;
33
/// Registered name of the function; also returned by `Function::name`.
const NAME: &str = "hll_count";

/// Scalar function `hll_count`: estimates how many distinct values were recorded
/// in a serialized HyperLogLogPlus sketch.
///
/// Arguments:
/// 1. the serialized HyperLogLogPlus state (binary), as produced by the aggregator.
///
/// Produces one `UInt64` estimate per input row. Null rows, and rows whose bytes
/// cannot be decoded as a sketch, produce null.
#[derive(Default, Debug)]
pub struct HllCalcFunction;
44
45impl HllCalcFunction {
46    pub fn register(registry: &FunctionRegistry) {
47        registry.register(Arc::new(HllCalcFunction));
48    }
49}
50
51impl Display for HllCalcFunction {
52    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
53        write!(f, "{}", NAME.to_ascii_uppercase())
54    }
55}
56
57impl Function for HllCalcFunction {
58    fn name(&self) -> &str {
59        NAME
60    }
61
62    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
63        Ok(ConcreteDataType::uint64_datatype())
64    }
65
66    fn signature(&self) -> Signature {
67        // Only argument: HyperLogLogPlus state (binary)
68        Signature::exact(
69            vec![ConcreteDataType::binary_datatype()],
70            Volatility::Immutable,
71        )
72    }
73
74    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
75        if columns.len() != 1 {
76            return InvalidFuncArgsSnafu {
77                err_msg: format!("hll_count expects 1 argument, got {}", columns.len()),
78            }
79            .fail();
80        }
81
82        let hll_vec = columns[0]
83            .as_any()
84            .downcast_ref::<BinaryVector>()
85            .with_context(|| DowncastVectorSnafu {
86                err_msg: format!("expect BinaryVector, got {}", columns[0].vector_type_name()),
87            })?;
88        let len = hll_vec.len();
89        let mut builder = UInt64VectorBuilder::with_capacity(len);
90
91        for i in 0..len {
92            let hll_opt = hll_vec.get_data(i);
93
94            if hll_opt.is_none() {
95                builder.push_null();
96                continue;
97            }
98
99            let hll_bytes = hll_opt.unwrap();
100
101            // Deserialize the HyperLogLogPlus from its bincode representation
102            let mut hll: HllStateType = match bincode::deserialize(hll_bytes) {
103                Ok(h) => h,
104                Err(e) => {
105                    common_telemetry::trace!("Failed to deserialize HyperLogLogPlus: {}", e);
106                    builder.push_null();
107                    continue;
108                }
109            };
110
111            builder.push(Some(hll.count().round() as u64));
112        }
113
114        Ok(builder.to_vector())
115    }
116}
117
#[cfg(test)]
mod tests {
    use datatypes::value::Value;
    use datatypes::vectors::{BinaryVector, UInt64Vector};

    use super::*;
    use crate::utils::FixedRandomState;

    #[test]
    fn test_hll_count_function() {
        let function = HllCalcFunction;
        assert_eq!("hll_count", function.name());
        // The declared input type is binary; the return type is always UInt64.
        assert_eq!(
            ConcreteDataType::uint64_datatype(),
            function
                .return_type(&[ConcreteDataType::binary_datatype()])
                .unwrap()
        );

        // Build a sketch over 10 distinct values.
        let mut hll = HllStateType::new(14, FixedRandomState::new()).unwrap();
        for i in 1..=10 {
            hll.insert(&i.to_string());
        }
        let serialized_bytes = bincode::serialize(&hll).unwrap();

        // The second row is null and must remain null in the output.
        let args: Vec<VectorRef> = vec![Arc::new(BinaryVector::from(vec![
            Some(serialized_bytes),
            None,
        ]))];

        let result = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(result.len(), 2);
        // For this small input the estimate matches the true cardinality.
        assert_eq!(result.get(0), Value::UInt64(10));
        assert_eq!(result.get(1), Value::Null);
    }

    #[test]
    fn test_hll_count_function_errors() {
        let function = HllCalcFunction;

        // Wrong number of arguments.
        let args: Vec<VectorRef> = vec![];
        let result = function.eval(&FunctionContext::default(), &args);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("hll_count expects 1 argument"));

        // An argument that is not a BinaryVector fails the downcast.
        let args: Vec<VectorRef> = vec![Arc::new(UInt64Vector::from_slice([1u64]))];
        assert!(function.eval(&FunctionContext::default(), &args).is_err());

        // Bytes that are not a valid serialized sketch yield null rather than an error.
        let args: Vec<VectorRef> = vec![Arc::new(BinaryVector::from(vec![Some(vec![1, 2, 3])]))];
        let result = function.eval(&FunctionContext::default(), &args).unwrap();
        assert_eq!(result.len(), 1);
        assert!(matches!(result.get(0), Value::Null));
    }
}
175}