common_function/
function_registry.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! functions registry
16use std::collections::HashMap;
17use std::sync::{Arc, RwLock};
18
19use datafusion_expr::AggregateUDF;
20use once_cell::sync::Lazy;
21
22use crate::admin::AdminFunction;
23use crate::aggrs::aggr_wrapper::StateMergeHelper;
24use crate::aggrs::approximate::ApproximateFunction;
25use crate::aggrs::count_hash::CountHash;
26use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
27use crate::function::{Function, FunctionRef};
28use crate::function_factory::ScalarFunctionFactory;
29use crate::scalars::date::DateFunction;
30use crate::scalars::expression::ExpressionFunction;
31use crate::scalars::hll_count::HllCalcFunction;
32use crate::scalars::ip::IpFunctions;
33use crate::scalars::json::JsonFunction;
34use crate::scalars::matches::MatchesFunction;
35use crate::scalars::matches_term::MatchesTermFunction;
36use crate::scalars::math::MathFunction;
37use crate::scalars::timestamp::TimestampFunction;
38use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
39use crate::scalars::vector::VectorFunction as VectorScalarFunction;
40use crate::system::SystemFunction;
41
42#[derive(Default)]
43pub struct FunctionRegistry {
44    functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
45    aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
46}
47
48impl FunctionRegistry {
49    /// Register a function in the registry by converting it into a `ScalarFunctionFactory`.
50    ///
51    /// # Arguments
52    ///
53    /// * `func` - An object that can be converted into a `ScalarFunctionFactory`.
54    ///
55    /// The function is inserted into the internal function map, keyed by its name.
56    /// If a function with the same name already exists, it will be replaced.
57    pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
58        let func = func.into();
59        let _ = self
60            .functions
61            .write()
62            .unwrap()
63            .insert(func.name().to_string(), func);
64    }
65
66    /// Register a scalar function in the registry.
67    pub fn register_scalar(&self, func: impl Function + 'static) {
68        self.register(Arc::new(func) as FunctionRef);
69    }
70
71    /// Register an aggregate function in the registry.
72    pub fn register_aggr(&self, func: AggregateUDF) {
73        let _ = self
74            .aggregate_functions
75            .write()
76            .unwrap()
77            .insert(func.name().to_string(), func);
78    }
79
80    pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
81        self.functions.read().unwrap().get(name).cloned()
82    }
83
84    /// Returns a list of all scalar functions registered in the registry.
85    pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
86        self.functions.read().unwrap().values().cloned().collect()
87    }
88
89    /// Returns a list of all aggregate functions registered in the registry.
90    pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
91        self.aggregate_functions
92            .read()
93            .unwrap()
94            .values()
95            .cloned()
96            .collect()
97    }
98
99    /// Returns true if an aggregate function with the given name exists in the registry.
100    pub fn is_aggr_func_exist(&self, name: &str) -> bool {
101        self.aggregate_functions.read().unwrap().contains_key(name)
102    }
103}
104
105pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
106    let function_registry = FunctionRegistry::default();
107
108    // Utility functions
109    MathFunction::register(&function_registry);
110    TimestampFunction::register(&function_registry);
111    DateFunction::register(&function_registry);
112    ExpressionFunction::register(&function_registry);
113    UddSketchCalcFunction::register(&function_registry);
114    HllCalcFunction::register(&function_registry);
115
116    // Full text search function
117    MatchesFunction::register(&function_registry);
118    MatchesTermFunction::register(&function_registry);
119
120    // System and administration functions
121    SystemFunction::register(&function_registry);
122    AdminFunction::register(&function_registry);
123
124    // Json related functions
125    JsonFunction::register(&function_registry);
126
127    // Vector related functions
128    VectorScalarFunction::register(&function_registry);
129    VectorAggrFunction::register(&function_registry);
130
131    // Geo functions
132    #[cfg(feature = "geo")]
133    crate::scalars::geo::GeoFunctions::register(&function_registry);
134    #[cfg(feature = "geo")]
135    crate::aggrs::geo::GeoFunction::register(&function_registry);
136
137    // Ip functions
138    IpFunctions::register(&function_registry);
139
140    // Approximate functions
141    ApproximateFunction::register(&function_registry);
142
143    // CountHash function
144    CountHash::register(&function_registry);
145
146    // state function of supported aggregate functions
147    StateMergeHelper::register(&function_registry);
148
149    Arc::new(function_registry)
150});
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// Registering a scalar function makes it visible through both lookup APIs.
    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::default();

        assert!(registry.get_function("test_and").is_none());
        assert!(registry.scalar_functions().is_empty());
        registry.register_scalar(TestAndFunction);
        let _ = registry.get_function("test_and").unwrap();
        assert_eq!(1, registry.scalar_functions().len());
    }

    /// `register` documents that a same-named registration replaces the old entry;
    /// pin that so duplicates never accumulate.
    #[test]
    fn test_register_same_name_replaces() {
        let registry = FunctionRegistry::default();

        registry.register_scalar(TestAndFunction);
        registry.register_scalar(TestAndFunction);

        assert_eq!(1, registry.scalar_functions().len());
        assert!(registry.get_function("test_and").is_some());
    }

    /// A fresh registry exposes no aggregates, and scalar names do not leak into
    /// the aggregate namespace.
    #[test]
    fn test_aggregate_lookups_on_empty_registry() {
        let registry = FunctionRegistry::default();

        assert!(registry.aggregate_functions().is_empty());
        assert!(!registry.is_aggr_func_exist("test_and"));

        registry.register_scalar(TestAndFunction);
        assert!(!registry.is_aggr_func_exist("test_and"));
        assert!(registry.aggregate_functions().is_empty());
    }
}