common_function/
function_registry.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! functions registry
16use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, RwLock};
18
19use datafusion::catalog::TableFunction;
20use datafusion_expr::AggregateUDF;
21
22use crate::admin::AdminFunction;
23use crate::aggrs::aggr_wrapper::StateMergeHelper;
24use crate::aggrs::approximate::ApproximateFunction;
25use crate::aggrs::count_hash::CountHash;
26use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
27use crate::function::{Function, FunctionRef};
28use crate::function_factory::ScalarFunctionFactory;
29use crate::scalars::date::DateFunction;
30use crate::scalars::expression::ExpressionFunction;
31use crate::scalars::hll_count::HllCalcFunction;
32use crate::scalars::ip::IpFunctions;
33use crate::scalars::json::JsonFunction;
34use crate::scalars::matches::MatchesFunction;
35use crate::scalars::matches_term::MatchesTermFunction;
36use crate::scalars::math::MathFunction;
37use crate::scalars::primary_key::DecodePrimaryKeyFunction;
38use crate::scalars::string::register_string_functions;
39use crate::scalars::timestamp::TimestampFunction;
40use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
41use crate::scalars::vector::VectorFunction as VectorScalarFunction;
42use crate::system::SystemFunction;
43
44#[derive(Default)]
45pub struct FunctionRegistry {
46    functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
47    aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
48    table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
49}
50
51impl FunctionRegistry {
52    /// Register a function in the registry by converting it into a `ScalarFunctionFactory`.
53    ///
54    /// # Arguments
55    ///
56    /// * `func` - An object that can be converted into a `ScalarFunctionFactory`.
57    ///
58    /// The function is inserted into the internal function map, keyed by its name.
59    /// If a function with the same name already exists, it will be replaced.
60    pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
61        let func = func.into();
62        let _ = self
63            .functions
64            .write()
65            .unwrap()
66            .insert(func.name().to_string(), func);
67    }
68
69    /// Register a scalar function in the registry.
70    pub fn register_scalar(&self, func: impl Function + 'static) {
71        let func = Arc::new(func) as FunctionRef;
72
73        for alias in func.aliases() {
74            let func: ScalarFunctionFactory = func.clone().into();
75            let alias = ScalarFunctionFactory {
76                name: alias.clone(),
77                ..func
78            };
79            self.register(alias);
80        }
81
82        self.register(func)
83    }
84
85    /// Register an aggregate function in the registry.
86    pub fn register_aggr(&self, func: AggregateUDF) {
87        let _ = self
88            .aggregate_functions
89            .write()
90            .unwrap()
91            .insert(func.name().to_string(), func);
92    }
93
94    /// Register a table function
95    pub fn register_table_function(&self, func: TableFunction) {
96        let _ = self
97            .table_functions
98            .write()
99            .unwrap()
100            .insert(func.name().to_string(), Arc::new(func));
101    }
102
103    pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
104        self.functions.read().unwrap().get(name).cloned()
105    }
106
107    /// Returns a list of all scalar functions registered in the registry.
108    pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
109        self.functions.read().unwrap().values().cloned().collect()
110    }
111
112    /// Returns a list of all aggregate functions registered in the registry.
113    pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
114        self.aggregate_functions
115            .read()
116            .unwrap()
117            .values()
118            .cloned()
119            .collect()
120    }
121
122    pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
123        self.table_functions
124            .read()
125            .unwrap()
126            .values()
127            .cloned()
128            .collect()
129    }
130
131    /// Returns true if an aggregate function with the given name exists in the registry.
132    pub fn is_aggr_func_exist(&self, name: &str) -> bool {
133        self.aggregate_functions.read().unwrap().contains_key(name)
134    }
135}
136
137pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(|| {
138    let function_registry = FunctionRegistry::default();
139
140    // Utility functions
141    MathFunction::register(&function_registry);
142    TimestampFunction::register(&function_registry);
143    DateFunction::register(&function_registry);
144    ExpressionFunction::register(&function_registry);
145    UddSketchCalcFunction::register(&function_registry);
146    HllCalcFunction::register(&function_registry);
147    DecodePrimaryKeyFunction::register(&function_registry);
148
149    // Full text search function
150    MatchesFunction::register(&function_registry);
151    MatchesTermFunction::register(&function_registry);
152
153    // System and administration functions
154    SystemFunction::register(&function_registry);
155    AdminFunction::register(&function_registry);
156
157    // Json related functions
158    JsonFunction::register(&function_registry);
159
160    // String related functions
161    register_string_functions(&function_registry);
162
163    // Vector related functions
164    VectorScalarFunction::register(&function_registry);
165    VectorAggrFunction::register(&function_registry);
166
167    // Geo functions
168    #[cfg(feature = "geo")]
169    crate::scalars::geo::GeoFunctions::register(&function_registry);
170    #[cfg(feature = "geo")]
171    crate::aggrs::geo::GeoFunction::register(&function_registry);
172
173    // Ip functions
174    IpFunctions::register(&function_registry);
175
176    // Approximate functions
177    ApproximateFunction::register(&function_registry);
178
179    // CountHash function
180    CountHash::register(&function_registry);
181
182    // state function of supported aggregate functions
183    StateMergeHelper::register(&function_registry);
184
185    Arc::new(function_registry)
186});
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// A fresh registry is empty; after registering a scalar function it can
    /// be looked up by name and is the only entry.
    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::default();

        // Nothing has been registered yet.
        assert!(registry.scalar_functions().is_empty());
        assert!(registry.get_function("test_and").is_none());

        registry.register_scalar(TestAndFunction::default());

        // The function is now retrievable and is the single registered entry.
        assert!(registry.get_function("test_and").is_some());
        assert_eq!(registry.scalar_functions().len(), 1);
    }
}