common_function/
function_registry.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! Function registry for scalar, aggregate, window and table functions, plus function rewrite rules.
16use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, RwLock};
18
19use datafusion::catalog::TableFunction;
20use datafusion_expr::expr_rewriter::FunctionRewrite;
21use datafusion_expr::{AggregateUDF, WindowUDF};
22
23use crate::admin::AdminFunction;
24use crate::aggrs::aggr_wrapper::StateMergeHelper;
25use crate::aggrs::approximate::ApproximateFunction;
26use crate::aggrs::count_hash::CountHash;
27use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
28use crate::function::{Function, FunctionRef};
29use crate::function_factory::ScalarFunctionFactory;
30use crate::scalars::anomaly::AnomalyFunction;
31use crate::scalars::date::DateFunction;
32use crate::scalars::expression::ExpressionFunction;
33use crate::scalars::hll_count::HllCalcFunction;
34use crate::scalars::ip::IpFunctions;
35use crate::scalars::json::JsonFunction;
36use crate::scalars::matches::MatchesFunction;
37use crate::scalars::matches_term::MatchesTermFunction;
38use crate::scalars::math::MathFunction;
39use crate::scalars::primary_key::DecodePrimaryKeyFunction;
40use crate::scalars::string::register_string_functions;
41use crate::scalars::timestamp::TimestampFunction;
42use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
43use crate::scalars::vector::VectorFunction as VectorScalarFunction;
44use crate::system::SystemFunction;
45
46#[derive(Default)]
47pub struct FunctionRegistry {
48    functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
49    aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
50    table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
51    function_rewrites: RwLock<Vec<Arc<dyn FunctionRewrite + Send + Sync>>>,
52    window_functions: RwLock<HashMap<String, WindowUDF>>,
53}
54
55impl FunctionRegistry {
56    /// Register a function in the registry by converting it into a `ScalarFunctionFactory`.
57    ///
58    /// # Arguments
59    ///
60    /// * `func` - An object that can be converted into a `ScalarFunctionFactory`.
61    ///
62    /// The function is inserted into the internal function map, keyed by its name.
63    /// If a function with the same name already exists, it will be replaced.
64    pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
65        let func = func.into();
66        let _ = self
67            .functions
68            .write()
69            .unwrap()
70            .insert(func.name().to_string(), func);
71    }
72
73    /// Register a scalar function in the registry.
74    pub fn register_scalar(&self, func: impl Function + 'static) {
75        let func = Arc::new(func) as FunctionRef;
76
77        for alias in func.aliases() {
78            let func: ScalarFunctionFactory = func.clone().into();
79            let alias = ScalarFunctionFactory {
80                name: alias.clone(),
81                ..func
82            };
83            self.register(alias);
84        }
85
86        self.register(func)
87    }
88
89    /// Register an aggregate function in the registry.
90    pub fn register_aggr(&self, func: AggregateUDF) {
91        let _ = self
92            .aggregate_functions
93            .write()
94            .unwrap()
95            .insert(func.name().to_string(), func);
96    }
97
98    /// Register a table function
99    pub fn register_table_function(&self, func: TableFunction) {
100        let _ = self
101            .table_functions
102            .write()
103            .unwrap()
104            .insert(func.name().to_string(), Arc::new(func));
105    }
106
107    /// Register a function rewrite rule.
108    pub fn register_function_rewrite(&self, func: impl FunctionRewrite + Send + Sync + 'static) {
109        self.function_rewrites.write().unwrap().push(Arc::new(func));
110    }
111
112    /// Register a window function (UDWF).
113    pub fn register_window(&self, func: WindowUDF) {
114        let _ = self
115            .window_functions
116            .write()
117            .unwrap()
118            .insert(func.name().to_string(), func);
119    }
120
121    pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
122        self.functions.read().unwrap().get(name).cloned()
123    }
124
125    /// Returns a list of all scalar functions registered in the registry.
126    pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
127        self.functions.read().unwrap().values().cloned().collect()
128    }
129
130    /// Returns a list of all aggregate functions registered in the registry.
131    pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
132        self.aggregate_functions
133            .read()
134            .unwrap()
135            .values()
136            .cloned()
137            .collect()
138    }
139
140    pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
141        self.table_functions
142            .read()
143            .unwrap()
144            .values()
145            .cloned()
146            .collect()
147    }
148
149    /// Returns a list of all window functions registered in the registry.
150    pub fn window_functions(&self) -> Vec<WindowUDF> {
151        self.window_functions
152            .read()
153            .unwrap()
154            .values()
155            .cloned()
156            .collect()
157    }
158
159    /// Returns true if an aggregate function with the given name exists in the registry.
160    pub fn is_aggr_func_exist(&self, name: &str) -> bool {
161        self.aggregate_functions.read().unwrap().contains_key(name)
162    }
163
164    /// Returns a list of all function rewrite rules registered in the registry.
165    pub fn function_rewrites(&self) -> Vec<Arc<dyn FunctionRewrite + Send + Sync>> {
166        self.function_rewrites.read().unwrap().clone()
167    }
168}
169
170pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(|| {
171    let function_registry = FunctionRegistry::default();
172
173    // Utility functions
174    MathFunction::register(&function_registry);
175    TimestampFunction::register(&function_registry);
176    DateFunction::register(&function_registry);
177    ExpressionFunction::register(&function_registry);
178    UddSketchCalcFunction::register(&function_registry);
179    HllCalcFunction::register(&function_registry);
180    DecodePrimaryKeyFunction::register(&function_registry);
181
182    // Full text search function
183    MatchesFunction::register(&function_registry);
184    MatchesTermFunction::register(&function_registry);
185
186    // System and administration functions
187    SystemFunction::register(&function_registry);
188    AdminFunction::register(&function_registry);
189
190    // Json related functions
191    JsonFunction::register(&function_registry);
192
193    // String related functions
194    register_string_functions(&function_registry);
195
196    // Vector related functions
197    VectorScalarFunction::register(&function_registry);
198    VectorAggrFunction::register(&function_registry);
199
200    // Geo functions
201    #[cfg(feature = "geo")]
202    crate::scalars::geo::GeoFunctions::register(&function_registry);
203    #[cfg(feature = "geo")]
204    crate::aggrs::geo::GeoFunction::register(&function_registry);
205
206    // Ip functions
207    IpFunctions::register(&function_registry);
208
209    // Approximate functions
210    ApproximateFunction::register(&function_registry);
211
212    // CountHash function
213    CountHash::register(&function_registry);
214
215    // state function of supported aggregate functions
216    StateMergeHelper::register(&function_registry);
217
218    // Anomaly detection window functions
219    AnomalyFunction::register(&function_registry);
220
221    Arc::new(function_registry)
222});
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// A fresh registry is empty. It exposes a scalar function after registration, and
    /// re-registering under the same name replaces the entry instead of duplicating it.
    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::default();

        assert!(registry.get_function("test_and").is_none());
        assert!(registry.scalar_functions().is_empty());

        registry.register_scalar(TestAndFunction::default());
        assert!(registry.get_function("test_and").is_some());
        assert_eq!(1, registry.scalar_functions().len());

        // Registering the same function again must overwrite the existing entry.
        registry.register_scalar(TestAndFunction::default());
        assert!(registry.get_function("test_and").is_some());
        assert_eq!(1, registry.scalar_functions().len());

        // Scalar registration must not leak into the other function collections.
        assert!(registry.aggregate_functions().is_empty());
        assert!(!registry.is_aggr_func_exist("test_and"));
        assert!(registry.window_functions().is_empty());
        assert!(registry.table_functions().is_empty());
        assert!(registry.function_rewrites().is_empty());
    }
}