common_function/
function_registry.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Function registry: name-keyed storage for scalar, async, and aggregate functions.
16use std::collections::HashMap;
17use std::sync::{Arc, RwLock};
18
19use datafusion_expr::AggregateUDF;
20use once_cell::sync::Lazy;
21
22use crate::admin::AdminFunction;
23use crate::aggrs::approximate::ApproximateFunction;
24use crate::aggrs::count_hash::CountHash;
25use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
26use crate::function::{AsyncFunctionRef, Function, FunctionRef};
27use crate::function_factory::ScalarFunctionFactory;
28use crate::scalars::date::DateFunction;
29use crate::scalars::expression::ExpressionFunction;
30use crate::scalars::hll_count::HllCalcFunction;
31use crate::scalars::ip::IpFunctions;
32use crate::scalars::json::JsonFunction;
33use crate::scalars::matches::MatchesFunction;
34use crate::scalars::matches_term::MatchesTermFunction;
35use crate::scalars::math::MathFunction;
36use crate::scalars::timestamp::TimestampFunction;
37use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
38use crate::scalars::vector::VectorFunction as VectorScalarFunction;
39use crate::system::SystemFunction;
40
41#[derive(Default)]
42pub struct FunctionRegistry {
43    functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
44    async_functions: RwLock<HashMap<String, AsyncFunctionRef>>,
45    aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
46}
47
48impl FunctionRegistry {
49    pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
50        let func = func.into();
51        let _ = self
52            .functions
53            .write()
54            .unwrap()
55            .insert(func.name().to_string(), func);
56    }
57
58    pub fn register_scalar(&self, func: impl Function + 'static) {
59        self.register(Arc::new(func) as FunctionRef);
60    }
61
62    pub fn register_async(&self, func: AsyncFunctionRef) {
63        let _ = self
64            .async_functions
65            .write()
66            .unwrap()
67            .insert(func.name().to_string(), func);
68    }
69
70    pub fn register_aggr(&self, func: AggregateUDF) {
71        let _ = self
72            .aggregate_functions
73            .write()
74            .unwrap()
75            .insert(func.name().to_string(), func);
76    }
77
78    pub fn get_async_function(&self, name: &str) -> Option<AsyncFunctionRef> {
79        self.async_functions.read().unwrap().get(name).cloned()
80    }
81
82    pub fn async_functions(&self) -> Vec<AsyncFunctionRef> {
83        self.async_functions
84            .read()
85            .unwrap()
86            .values()
87            .cloned()
88            .collect()
89    }
90
91    #[cfg(test)]
92    pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
93        self.functions.read().unwrap().get(name).cloned()
94    }
95
96    pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
97        self.functions.read().unwrap().values().cloned().collect()
98    }
99
100    pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
101        self.aggregate_functions
102            .read()
103            .unwrap()
104            .values()
105            .cloned()
106            .collect()
107    }
108}
109
110pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
111    let function_registry = FunctionRegistry::default();
112
113    // Utility functions
114    MathFunction::register(&function_registry);
115    TimestampFunction::register(&function_registry);
116    DateFunction::register(&function_registry);
117    ExpressionFunction::register(&function_registry);
118    UddSketchCalcFunction::register(&function_registry);
119    HllCalcFunction::register(&function_registry);
120
121    // Full text search function
122    MatchesFunction::register(&function_registry);
123    MatchesTermFunction::register(&function_registry);
124
125    // System and administration functions
126    SystemFunction::register(&function_registry);
127    AdminFunction::register(&function_registry);
128
129    // Json related functions
130    JsonFunction::register(&function_registry);
131
132    // Vector related functions
133    VectorScalarFunction::register(&function_registry);
134    VectorAggrFunction::register(&function_registry);
135
136    // Geo functions
137    #[cfg(feature = "geo")]
138    crate::scalars::geo::GeoFunctions::register(&function_registry);
139    #[cfg(feature = "geo")]
140    crate::aggrs::geo::GeoFunction::register(&function_registry);
141
142    // Ip functions
143    IpFunctions::register(&function_registry);
144
145    // Approximate functions
146    ApproximateFunction::register(&function_registry);
147
148    // CountHash function
149    CountHash::register(&function_registry);
150
151    Arc::new(function_registry)
152});
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    #[test]
    fn test_function_registry() {
        let reg = FunctionRegistry::default();

        // A fresh registry has no scalar functions at all.
        assert!(reg.get_function("test_and").is_none());
        assert!(reg.scalar_functions().is_empty());

        reg.register_scalar(TestAndFunction);

        // Once registered, the function is found by name and is the only entry.
        assert!(reg.get_function("test_and").is_some());
        assert_eq!(reg.scalar_functions().len(), 1);
    }
}
169}