common_function/
function_registry.rs1use std::collections::HashMap;
17use std::sync::{Arc, RwLock};
18
19use datafusion_expr::AggregateUDF;
20use once_cell::sync::Lazy;
21
22use crate::admin::AdminFunction;
23use crate::aggrs::approximate::ApproximateFunction;
24use crate::aggrs::count_hash::CountHash;
25use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
26use crate::function::{AsyncFunctionRef, Function, FunctionRef};
27use crate::function_factory::ScalarFunctionFactory;
28use crate::scalars::date::DateFunction;
29use crate::scalars::expression::ExpressionFunction;
30use crate::scalars::hll_count::HllCalcFunction;
31use crate::scalars::ip::IpFunctions;
32use crate::scalars::json::JsonFunction;
33use crate::scalars::matches::MatchesFunction;
34use crate::scalars::matches_term::MatchesTermFunction;
35use crate::scalars::math::MathFunction;
36use crate::scalars::timestamp::TimestampFunction;
37use crate::scalars::uddsketch_calc::UddSketchCalcFunction;
38use crate::scalars::vector::VectorFunction as VectorScalarFunction;
39use crate::system::SystemFunction;
40
41#[derive(Default)]
42pub struct FunctionRegistry {
43 functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
44 async_functions: RwLock<HashMap<String, AsyncFunctionRef>>,
45 aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
46}
47
48impl FunctionRegistry {
49 pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
50 let func = func.into();
51 let _ = self
52 .functions
53 .write()
54 .unwrap()
55 .insert(func.name().to_string(), func);
56 }
57
58 pub fn register_scalar(&self, func: impl Function + 'static) {
59 self.register(Arc::new(func) as FunctionRef);
60 }
61
62 pub fn register_async(&self, func: AsyncFunctionRef) {
63 let _ = self
64 .async_functions
65 .write()
66 .unwrap()
67 .insert(func.name().to_string(), func);
68 }
69
70 pub fn register_aggr(&self, func: AggregateUDF) {
71 let _ = self
72 .aggregate_functions
73 .write()
74 .unwrap()
75 .insert(func.name().to_string(), func);
76 }
77
78 pub fn get_async_function(&self, name: &str) -> Option<AsyncFunctionRef> {
79 self.async_functions.read().unwrap().get(name).cloned()
80 }
81
82 pub fn async_functions(&self) -> Vec<AsyncFunctionRef> {
83 self.async_functions
84 .read()
85 .unwrap()
86 .values()
87 .cloned()
88 .collect()
89 }
90
91 #[cfg(test)]
92 pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
93 self.functions.read().unwrap().get(name).cloned()
94 }
95
96 pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
97 self.functions.read().unwrap().values().cloned().collect()
98 }
99
100 pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
101 self.aggregate_functions
102 .read()
103 .unwrap()
104 .values()
105 .cloned()
106 .collect()
107 }
108}
109
110pub static FUNCTION_REGISTRY: Lazy<Arc<FunctionRegistry>> = Lazy::new(|| {
111 let function_registry = FunctionRegistry::default();
112
113 MathFunction::register(&function_registry);
115 TimestampFunction::register(&function_registry);
116 DateFunction::register(&function_registry);
117 ExpressionFunction::register(&function_registry);
118 UddSketchCalcFunction::register(&function_registry);
119 HllCalcFunction::register(&function_registry);
120
121 MatchesFunction::register(&function_registry);
123 MatchesTermFunction::register(&function_registry);
124
125 SystemFunction::register(&function_registry);
127 AdminFunction::register(&function_registry);
128
129 JsonFunction::register(&function_registry);
131
132 VectorScalarFunction::register(&function_registry);
134 VectorAggrFunction::register(&function_registry);
135
136 #[cfg(feature = "geo")]
138 crate::scalars::geo::GeoFunctions::register(&function_registry);
139 #[cfg(feature = "geo")]
140 crate::aggrs::geo::GeoFunction::register(&function_registry);
141
142 IpFunctions::register(&function_registry);
144
145 ApproximateFunction::register(&function_registry);
147
148 CountHash::register(&function_registry);
150
151 Arc::new(function_registry)
152});
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalars::test::TestAndFunction;

    /// Registering one scalar function makes it retrievable by name and
    /// visible in the scalar-function listing.
    #[test]
    fn test_function_registry() {
        let registry = FunctionRegistry::default();

        // A fresh registry knows no scalar functions at all.
        assert!(registry.scalar_functions().is_empty());
        assert!(registry.get_function("test_and").is_none());

        registry.register_scalar(TestAndFunction);

        // The function is now reachable by name, and it is the only entry.
        assert!(registry.get_function("test_and").is_some());
        assert_eq!(registry.scalar_functions().len(), 1);
    }
}