common_function/
function_registry.rs1use std::collections::HashMap;
17use std::sync::{Arc, RwLock};
18
19use datafusion_expr::AggregateUDF;
20use once_cell::sync::Lazy;
21
22use crate::admin::AdminFunction;
23use crate::aggrs::aggr_wrapper::StateMergeHelper;
24use crate::aggrs::approximate::ApproximateFunction;
25use crate::aggrs::count_hash::CountHash;
26use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
27use crate::function::{Function, FunctionRef};
28use crate::function_factory::ScalarFunctionFactory;
29use crate::scalars::date::DateFunction;
30use crate::scalars::expression::ExpressionFunction;
31use crate::scalars::hll_count::HllCalcFunction;
32use crate::scalars::ip::IpFunctions;
33use crate::scalars::json::JsonFunction;
34use crate::scalars::matches::MatchesFunction;
35use crate::scalars::matches_term::MatchesTermFunction;
36use crate::scalars::math::MathFunction;
37use crate::scalars::timestamp::TimestampFunction;
38use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
39use crate::scalars::vector::VectorFunction as VectorScalarFunction;
40use crate::system::SystemFunction;
41
42#[derive(Default)]
43pub struct FunctionRegistry {
44 functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
45 aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
46}
47
48impl FunctionRegistry {
49 pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
58 let func = func.into();
59 let _ = self
60 .functions
61 .write()
62 .unwrap()
63 .insert(func.name().to_string(), func);
64 }
65
66 pub fn register_scalar(&self, func: impl Function + 'static) {
68 self.register(Arc::new(func) as FunctionRef);
69 }
70
71 pub fn register_aggr(&self, func: AggregateUDF) {
73 let _ = self
74 .aggregate_functions
75 .write()
76 .unwrap()
77 .insert(func.name().to_string(), func);
78 }
79
80 pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
81 self.functions.read().unwrap().get(name).cloned()
82 }
83
84 pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
86 self.functions.read().unwrap().values().cloned().collect()
87 }
88
89 pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
91 self.aggregate_functions
92 .read()
93 .unwrap()
94 .values()
95 .cloned()
96 .collect()
97 }
98
99 pub fn is_aggr_func_exist(&self, name: &str) -> bool {
101 self.aggregate_functions.read().unwrap().contains_key(name)
102 }
103}
104
105pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
106 let function_registry = FunctionRegistry::default();
107
108 MathFunction::register(&function_registry);
110 TimestampFunction::register(&function_registry);
111 DateFunction::register(&function_registry);
112 ExpressionFunction::register(&function_registry);
113 UddSketchCalcFunction::register(&function_registry);
114 HllCalcFunction::register(&function_registry);
115
116 MatchesFunction::register(&function_registry);
118 MatchesTermFunction::register(&function_registry);
119
120 SystemFunction::register(&function_registry);
122 AdminFunction::register(&function_registry);
123
124 JsonFunction::register(&function_registry);
126
127 VectorScalarFunction::register(&function_registry);
129 VectorAggrFunction::register(&function_registry);
130
131 #[cfg(feature = "geo")]
133 crate::scalars::geo::GeoFunctions::register(&function_registry);
134 #[cfg(feature = "geo")]
135 crate::aggrs::geo::GeoFunction::register(&function_registry);
136
137 IpFunctions::register(&function_registry);
139
140 ApproximateFunction::register(&function_registry);
142
143 CountHash::register(&function_registry);
145
146 StateMergeHelper::register(&function_registry);
148
149 Arc::new(function_registry)
150});
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// A fresh registry starts empty. After `register_scalar`, the function is visible
    /// both by name and in the full listing.
    #[test]
    fn test_function_registry() {
        let reg = FunctionRegistry::default();

        // Nothing is registered up front.
        assert!(reg.get_function("test_and").is_none());
        assert!(reg.scalar_functions().is_empty());

        reg.register_scalar(TestAndFunction);

        // The new function can be found by name and is the only scalar entry.
        assert!(reg.get_function("test_and").is_some());
        assert_eq!(reg.scalar_functions().len(), 1);
    }
}