common_function/
function_registry.rs1use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, RwLock};
18
19use datafusion::catalog::TableFunction;
20use datafusion_expr::AggregateUDF;
21
22use crate::admin::AdminFunction;
23use crate::aggrs::aggr_wrapper::StateMergeHelper;
24use crate::aggrs::approximate::ApproximateFunction;
25use crate::aggrs::count_hash::CountHash;
26use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
27use crate::function::{Function, FunctionRef};
28use crate::function_factory::ScalarFunctionFactory;
29use crate::scalars::date::DateFunction;
30use crate::scalars::expression::ExpressionFunction;
31use crate::scalars::hll_count::HllCalcFunction;
32use crate::scalars::ip::IpFunctions;
33use crate::scalars::json::JsonFunction;
34use crate::scalars::matches::MatchesFunction;
35use crate::scalars::matches_term::MatchesTermFunction;
36use crate::scalars::math::MathFunction;
37use crate::scalars::string::register_string_functions;
38use crate::scalars::timestamp::TimestampFunction;
39use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
40use crate::scalars::vector::VectorFunction as VectorScalarFunction;
41use crate::system::SystemFunction;
42
43#[derive(Default)]
44pub struct FunctionRegistry {
45 functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
46 aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
47 table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
48}
49
50impl FunctionRegistry {
51 pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
60 let func = func.into();
61 let _ = self
62 .functions
63 .write()
64 .unwrap()
65 .insert(func.name().to_string(), func);
66 }
67
68 pub fn register_scalar(&self, func: impl Function + 'static) {
70 let func = Arc::new(func) as FunctionRef;
71
72 for alias in func.aliases() {
73 let func: ScalarFunctionFactory = func.clone().into();
74 let alias = ScalarFunctionFactory {
75 name: alias.clone(),
76 ..func
77 };
78 self.register(alias);
79 }
80
81 self.register(func)
82 }
83
84 pub fn register_aggr(&self, func: AggregateUDF) {
86 let _ = self
87 .aggregate_functions
88 .write()
89 .unwrap()
90 .insert(func.name().to_string(), func);
91 }
92
93 pub fn register_table_function(&self, func: TableFunction) {
95 let _ = self
96 .table_functions
97 .write()
98 .unwrap()
99 .insert(func.name().to_string(), Arc::new(func));
100 }
101
102 pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
103 self.functions.read().unwrap().get(name).cloned()
104 }
105
106 pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
108 self.functions.read().unwrap().values().cloned().collect()
109 }
110
111 pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
113 self.aggregate_functions
114 .read()
115 .unwrap()
116 .values()
117 .cloned()
118 .collect()
119 }
120
121 pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
122 self.table_functions
123 .read()
124 .unwrap()
125 .values()
126 .cloned()
127 .collect()
128 }
129
130 pub fn is_aggr_func_exist(&self, name: &str) -> bool {
132 self.aggregate_functions.read().unwrap().contains_key(name)
133 }
134}
135
136pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(|| {
137 let function_registry = FunctionRegistry::default();
138
139 MathFunction::register(&function_registry);
141 TimestampFunction::register(&function_registry);
142 DateFunction::register(&function_registry);
143 ExpressionFunction::register(&function_registry);
144 UddSketchCalcFunction::register(&function_registry);
145 HllCalcFunction::register(&function_registry);
146
147 MatchesFunction::register(&function_registry);
149 MatchesTermFunction::register(&function_registry);
150
151 SystemFunction::register(&function_registry);
153 AdminFunction::register(&function_registry);
154
155 JsonFunction::register(&function_registry);
157
158 register_string_functions(&function_registry);
160
161 VectorScalarFunction::register(&function_registry);
163 VectorAggrFunction::register(&function_registry);
164
165 #[cfg(feature = "geo")]
167 crate::scalars::geo::GeoFunctions::register(&function_registry);
168 #[cfg(feature = "geo")]
169 crate::aggrs::geo::GeoFunction::register(&function_registry);
170
171 IpFunctions::register(&function_registry);
173
174 ApproximateFunction::register(&function_registry);
176
177 CountHash::register(&function_registry);
179
180 StateMergeHelper::register(&function_registry);
182
183 Arc::new(function_registry)
184});
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// A fresh registry starts empty and exposes a scalar function once it has
    /// been registered via `register_scalar`.
    #[test]
    fn test_function_registry() {
        let reg = FunctionRegistry::default();

        // Nothing is registered yet.
        assert!(reg.get_function("test_and").is_none());
        assert!(reg.scalar_functions().is_empty());

        reg.register_scalar(TestAndFunction::default());

        // The function is now resolvable and is the only entry.
        assert!(reg.get_function("test_and").is_some());
        assert_eq!(reg.scalar_functions().len(), 1);
    }
}