common_function/
function_registry.rs1use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, RwLock};
18
19use datafusion::catalog::TableFunction;
20use datafusion_expr::AggregateUDF;
21
22use crate::admin::AdminFunction;
23use crate::aggrs::aggr_wrapper::StateMergeHelper;
24use crate::aggrs::approximate::ApproximateFunction;
25use crate::aggrs::count_hash::CountHash;
26use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
27use crate::function::{Function, FunctionRef};
28use crate::function_factory::ScalarFunctionFactory;
29use crate::scalars::date::DateFunction;
30use crate::scalars::expression::ExpressionFunction;
31use crate::scalars::hll_count::HllCalcFunction;
32use crate::scalars::ip::IpFunctions;
33use crate::scalars::json::JsonFunction;
34use crate::scalars::matches::MatchesFunction;
35use crate::scalars::matches_term::MatchesTermFunction;
36use crate::scalars::math::MathFunction;
37use crate::scalars::timestamp::TimestampFunction;
38use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
39use crate::scalars::vector::VectorFunction as VectorScalarFunction;
40use crate::system::SystemFunction;
41
42#[derive(Default)]
43pub struct FunctionRegistry {
44 functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
45 aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
46 table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
47}
48
49impl FunctionRegistry {
50 pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
59 let func = func.into();
60 let _ = self
61 .functions
62 .write()
63 .unwrap()
64 .insert(func.name().to_string(), func);
65 }
66
67 pub fn register_scalar(&self, func: impl Function + 'static) {
69 let func = Arc::new(func) as FunctionRef;
70
71 for alias in func.aliases() {
72 let func: ScalarFunctionFactory = func.clone().into();
73 let alias = ScalarFunctionFactory {
74 name: alias.to_string(),
75 ..func
76 };
77 self.register(alias);
78 }
79
80 self.register(func)
81 }
82
83 pub fn register_aggr(&self, func: AggregateUDF) {
85 let _ = self
86 .aggregate_functions
87 .write()
88 .unwrap()
89 .insert(func.name().to_string(), func);
90 }
91
92 pub fn register_table_function(&self, func: TableFunction) {
94 let _ = self
95 .table_functions
96 .write()
97 .unwrap()
98 .insert(func.name().to_string(), Arc::new(func));
99 }
100
101 pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
102 self.functions.read().unwrap().get(name).cloned()
103 }
104
105 pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
107 self.functions.read().unwrap().values().cloned().collect()
108 }
109
110 pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
112 self.aggregate_functions
113 .read()
114 .unwrap()
115 .values()
116 .cloned()
117 .collect()
118 }
119
120 pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
121 self.table_functions
122 .read()
123 .unwrap()
124 .values()
125 .cloned()
126 .collect()
127 }
128
129 pub fn is_aggr_func_exist(&self, name: &str) -> bool {
131 self.aggregate_functions.read().unwrap().contains_key(name)
132 }
133}
134
135pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(|| {
136 let function_registry = FunctionRegistry::default();
137
138 MathFunction::register(&function_registry);
140 TimestampFunction::register(&function_registry);
141 DateFunction::register(&function_registry);
142 ExpressionFunction::register(&function_registry);
143 UddSketchCalcFunction::register(&function_registry);
144 HllCalcFunction::register(&function_registry);
145
146 MatchesFunction::register(&function_registry);
148 MatchesTermFunction::register(&function_registry);
149
150 SystemFunction::register(&function_registry);
152 AdminFunction::register(&function_registry);
153
154 JsonFunction::register(&function_registry);
156
157 VectorScalarFunction::register(&function_registry);
159 VectorAggrFunction::register(&function_registry);
160
161 #[cfg(feature = "geo")]
163 crate::scalars::geo::GeoFunctions::register(&function_registry);
164 #[cfg(feature = "geo")]
165 crate::aggrs::geo::GeoFunction::register(&function_registry);
166
167 IpFunctions::register(&function_registry);
169
170 ApproximateFunction::register(&function_registry);
172
173 CountHash::register(&function_registry);
175
176 StateMergeHelper::register(&function_registry);
178
179 Arc::new(function_registry)
180});
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// Exercises registration and lookup on a fresh registry.
    ///
    /// Checks that it starts empty, that a scalar can be found after
    /// registration, that re-registering replaces the entry instead of
    /// duplicating it, and that scalar registration does not touch the
    /// aggregate or table registries.
    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::default();

        // A default registry starts out empty for every function kind.
        assert!(registry.get_function("test_and").is_none());
        assert!(registry.scalar_functions().is_empty());
        assert!(registry.aggregate_functions().is_empty());
        assert!(registry.table_functions().is_empty());
        assert!(!registry.is_aggr_func_exist("test_and"));

        registry.register_scalar(TestAndFunction);
        let _ = registry
            .get_function("test_and")
            .expect("test_and should be registered");
        assert_eq!(1, registry.scalar_functions().len());

        // Registration uses `HashMap::insert`, so registering the same function
        // again replaces the entry rather than adding a second one.
        registry.register_scalar(TestAndFunction);
        assert!(registry.get_function("test_and").is_some());
        assert_eq!(1, registry.scalar_functions().len());

        // Scalar registration must not leak into the other registries.
        assert!(!registry.is_aggr_func_exist("test_and"));
        assert!(registry.aggregate_functions().is_empty());
        assert!(registry.table_functions().is_empty());
    }
}