common_function/
function_registry.rs1use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, RwLock};
18
19use datafusion::catalog::TableFunction;
20use datafusion_expr::expr_rewriter::FunctionRewrite;
21use datafusion_expr::{AggregateUDF, WindowUDF};
22
23use crate::admin::AdminFunction;
24use crate::aggrs::aggr_wrapper::StateMergeHelper;
25use crate::aggrs::approximate::ApproximateFunction;
26use crate::aggrs::count_hash::CountHash;
27use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
28use crate::function::{Function, FunctionRef};
29use crate::function_factory::ScalarFunctionFactory;
30use crate::scalars::anomaly::AnomalyFunction;
31use crate::scalars::date::DateFunction;
32use crate::scalars::expression::ExpressionFunction;
33use crate::scalars::hll_count::HllCalcFunction;
34use crate::scalars::ip::IpFunctions;
35use crate::scalars::json::JsonFunction;
36use crate::scalars::matches::MatchesFunction;
37use crate::scalars::matches_term::MatchesTermFunction;
38use crate::scalars::math::MathFunction;
39use crate::scalars::primary_key::DecodePrimaryKeyFunction;
40use crate::scalars::string::register_string_functions;
41use crate::scalars::timestamp::TimestampFunction;
42use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
43use crate::scalars::vector::VectorFunction as VectorScalarFunction;
44use crate::system::SystemFunction;
45
46#[derive(Default)]
47pub struct FunctionRegistry {
48 functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
49 aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
50 table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
51 function_rewrites: RwLock<Vec<Arc<dyn FunctionRewrite + Send + Sync>>>,
52 window_functions: RwLock<HashMap<String, WindowUDF>>,
53}
54
55impl FunctionRegistry {
56 pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
65 let func = func.into();
66 let _ = self
67 .functions
68 .write()
69 .unwrap()
70 .insert(func.name().to_string(), func);
71 }
72
73 pub fn register_scalar(&self, func: impl Function + 'static) {
75 let func = Arc::new(func) as FunctionRef;
76
77 for alias in func.aliases() {
78 let func: ScalarFunctionFactory = func.clone().into();
79 let alias = ScalarFunctionFactory {
80 name: alias.clone(),
81 ..func
82 };
83 self.register(alias);
84 }
85
86 self.register(func)
87 }
88
89 pub fn register_aggr(&self, func: AggregateUDF) {
91 let _ = self
92 .aggregate_functions
93 .write()
94 .unwrap()
95 .insert(func.name().to_string(), func);
96 }
97
98 pub fn register_table_function(&self, func: TableFunction) {
100 let _ = self
101 .table_functions
102 .write()
103 .unwrap()
104 .insert(func.name().to_string(), Arc::new(func));
105 }
106
107 pub fn register_function_rewrite(&self, func: impl FunctionRewrite + Send + Sync + 'static) {
109 self.function_rewrites.write().unwrap().push(Arc::new(func));
110 }
111
112 pub fn register_window(&self, func: WindowUDF) {
114 let _ = self
115 .window_functions
116 .write()
117 .unwrap()
118 .insert(func.name().to_string(), func);
119 }
120
121 pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
122 self.functions.read().unwrap().get(name).cloned()
123 }
124
125 pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
127 self.functions.read().unwrap().values().cloned().collect()
128 }
129
130 pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
132 self.aggregate_functions
133 .read()
134 .unwrap()
135 .values()
136 .cloned()
137 .collect()
138 }
139
140 pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
141 self.table_functions
142 .read()
143 .unwrap()
144 .values()
145 .cloned()
146 .collect()
147 }
148
149 pub fn window_functions(&self) -> Vec<WindowUDF> {
151 self.window_functions
152 .read()
153 .unwrap()
154 .values()
155 .cloned()
156 .collect()
157 }
158
159 pub fn is_aggr_func_exist(&self, name: &str) -> bool {
161 self.aggregate_functions.read().unwrap().contains_key(name)
162 }
163
164 pub fn function_rewrites(&self) -> Vec<Arc<dyn FunctionRewrite + Send + Sync>> {
166 self.function_rewrites.read().unwrap().clone()
167 }
168}
169
170pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(|| {
171 let function_registry = FunctionRegistry::default();
172
173 MathFunction::register(&function_registry);
175 TimestampFunction::register(&function_registry);
176 DateFunction::register(&function_registry);
177 ExpressionFunction::register(&function_registry);
178 UddSketchCalcFunction::register(&function_registry);
179 HllCalcFunction::register(&function_registry);
180 DecodePrimaryKeyFunction::register(&function_registry);
181
182 MatchesFunction::register(&function_registry);
184 MatchesTermFunction::register(&function_registry);
185
186 SystemFunction::register(&function_registry);
188 AdminFunction::register(&function_registry);
189
190 JsonFunction::register(&function_registry);
192
193 register_string_functions(&function_registry);
195
196 VectorScalarFunction::register(&function_registry);
198 VectorAggrFunction::register(&function_registry);
199
200 #[cfg(feature = "geo")]
202 crate::scalars::geo::GeoFunctions::register(&function_registry);
203 #[cfg(feature = "geo")]
204 crate::aggrs::geo::GeoFunction::register(&function_registry);
205
206 IpFunctions::register(&function_registry);
208
209 ApproximateFunction::register(&function_registry);
211
212 CountHash::register(&function_registry);
214
215 StateMergeHelper::register(&function_registry);
217
218 AnomalyFunction::register(&function_registry);
220
221 Arc::new(function_registry)
222});
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// A scalar registered via `register_scalar` becomes reachable both by
    /// name lookup and through the full scalar listing.
    #[test]
    fn test_function_registry() {
        let reg = FunctionRegistry::default();

        // A freshly constructed registry has no scalar functions at all.
        assert!(reg.get_function("test_and").is_none());
        assert_eq!(reg.scalar_functions().len(), 0);

        reg.register_scalar(TestAndFunction::default());

        // The function can now be found by name ...
        assert!(reg.get_function("test_and").is_some());
        // ... and it is the only scalar entry.
        assert_eq!(reg.scalar_functions().len(), 1);
    }
}