common_function/
function_registry.rs1use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, RwLock};
18
19use datafusion::catalog::TableFunction;
20use datafusion_expr::AggregateUDF;
21
22use crate::admin::AdminFunction;
23use crate::aggrs::aggr_wrapper::StateMergeHelper;
24use crate::aggrs::approximate::ApproximateFunction;
25use crate::aggrs::count_hash::CountHash;
26use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
27use crate::function::{Function, FunctionRef};
28use crate::function_factory::ScalarFunctionFactory;
29use crate::scalars::date::DateFunction;
30use crate::scalars::expression::ExpressionFunction;
31use crate::scalars::hll_count::HllCalcFunction;
32use crate::scalars::ip::IpFunctions;
33use crate::scalars::json::JsonFunction;
34use crate::scalars::matches::MatchesFunction;
35use crate::scalars::matches_term::MatchesTermFunction;
36use crate::scalars::math::MathFunction;
37use crate::scalars::primary_key::DecodePrimaryKeyFunction;
38use crate::scalars::string::register_string_functions;
39use crate::scalars::timestamp::TimestampFunction;
40use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
41use crate::scalars::vector::VectorFunction as VectorScalarFunction;
42use crate::system::SystemFunction;
43
44#[derive(Default)]
45pub struct FunctionRegistry {
46 functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
47 aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
48 table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
49}
50
51impl FunctionRegistry {
52 pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
61 let func = func.into();
62 let _ = self
63 .functions
64 .write()
65 .unwrap()
66 .insert(func.name().to_string(), func);
67 }
68
69 pub fn register_scalar(&self, func: impl Function + 'static) {
71 let func = Arc::new(func) as FunctionRef;
72
73 for alias in func.aliases() {
74 let func: ScalarFunctionFactory = func.clone().into();
75 let alias = ScalarFunctionFactory {
76 name: alias.clone(),
77 ..func
78 };
79 self.register(alias);
80 }
81
82 self.register(func)
83 }
84
85 pub fn register_aggr(&self, func: AggregateUDF) {
87 let _ = self
88 .aggregate_functions
89 .write()
90 .unwrap()
91 .insert(func.name().to_string(), func);
92 }
93
94 pub fn register_table_function(&self, func: TableFunction) {
96 let _ = self
97 .table_functions
98 .write()
99 .unwrap()
100 .insert(func.name().to_string(), Arc::new(func));
101 }
102
103 pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
104 self.functions.read().unwrap().get(name).cloned()
105 }
106
107 pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
109 self.functions.read().unwrap().values().cloned().collect()
110 }
111
112 pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
114 self.aggregate_functions
115 .read()
116 .unwrap()
117 .values()
118 .cloned()
119 .collect()
120 }
121
122 pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
123 self.table_functions
124 .read()
125 .unwrap()
126 .values()
127 .cloned()
128 .collect()
129 }
130
131 pub fn is_aggr_func_exist(&self, name: &str) -> bool {
133 self.aggregate_functions.read().unwrap().contains_key(name)
134 }
135}
136
137pub static FUNCTION_REGISTRY: LazyLock<Arc<FunctionRegistry>> = LazyLock::new(|| {
138 let function_registry = FunctionRegistry::default();
139
140 MathFunction::register(&function_registry);
142 TimestampFunction::register(&function_registry);
143 DateFunction::register(&function_registry);
144 ExpressionFunction::register(&function_registry);
145 UddSketchCalcFunction::register(&function_registry);
146 HllCalcFunction::register(&function_registry);
147 DecodePrimaryKeyFunction::register(&function_registry);
148
149 MatchesFunction::register(&function_registry);
151 MatchesTermFunction::register(&function_registry);
152
153 SystemFunction::register(&function_registry);
155 AdminFunction::register(&function_registry);
156
157 JsonFunction::register(&function_registry);
159
160 register_string_functions(&function_registry);
162
163 VectorScalarFunction::register(&function_registry);
165 VectorAggrFunction::register(&function_registry);
166
167 #[cfg(feature = "geo")]
169 crate::scalars::geo::GeoFunctions::register(&function_registry);
170 #[cfg(feature = "geo")]
171 crate::aggrs::geo::GeoFunction::register(&function_registry);
172
173 IpFunctions::register(&function_registry);
175
176 ApproximateFunction::register(&function_registry);
178
179 CountHash::register(&function_registry);
181
182 StateMergeHelper::register(&function_registry);
184
185 Arc::new(function_registry)
186});
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// Registering a scalar function makes it retrievable by name.
    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::default();

        assert!(registry.get_function("test_and").is_none());
        assert!(registry.scalar_functions().is_empty());
        registry.register_scalar(TestAndFunction::default());
        let _ = registry.get_function("test_and").unwrap();
        assert_eq!(1, registry.scalar_functions().len());
    }

    /// Registering the same function twice replaces the entry rather than duplicating it.
    #[test]
    fn test_register_scalar_replaces_same_name() {
        let registry = FunctionRegistry::default();

        registry.register_scalar(TestAndFunction::default());
        registry.register_scalar(TestAndFunction::default());

        assert_eq!(1, registry.scalar_functions().len());
        assert!(registry.get_function("test_and").is_some());
    }

    /// A default registry starts with no aggregate or table functions.
    #[test]
    fn test_default_registry_is_empty() {
        let registry = FunctionRegistry::default();

        assert!(registry.aggregate_functions().is_empty());
        assert!(registry.table_functions().is_empty());
        assert!(!registry.is_aggr_func_exist("test_and"));
    }
}