1use proc_macro::TokenStream;
16use quote::quote;
17use syn::spanned::Spanned;
18use syn::{
19 Attribute, Ident, ItemFn, Path, Signature, Type, TypePath, TypeReference, Visibility,
20 parse_macro_input,
21};
22
23use crate::utils::extract_input_types;
24
/// Unwraps a `Result`, or returns early from the enclosing function with the error
/// converted into compile-error tokens.
///
/// The error value must provide an `into_compile_error()` method (as `syn::Error` does),
/// and its output must be convertible via `.into()` into the enclosing function's
/// return type (e.g. `proc_macro::TokenStream`).
macro_rules! ok {
    ($res:expr) => {
        match $res {
            Err(err) => return err.into_compile_error().into(),
            Ok(value) => value,
        }
    };
}
34
/// Builds an `Err(syn::Error)` positioned at the span `$at` and carrying `$message`.
///
/// Typically wrapped in `ok!` so a validation failure becomes an early compile error.
macro_rules! error {
    ($at:expr, $message:expr) => {
        Err(syn::Error::new($at, $message))
    };
}
41
42pub(crate) fn process_admin_fn(args: TokenStream, input: TokenStream) -> TokenStream {
43 let mut name: Option<Ident> = None;
44 let mut display_name: Option<Ident> = None;
45 let mut sig_fn: Option<Ident> = None;
46 let mut ret: Option<Ident> = None;
47 let mut user_path: Option<Path> = None;
48
49 let parser = syn::meta::parser(|meta| {
50 if meta.path.is_ident("name") {
51 name = Some(meta.value()?.parse()?);
52 Ok(())
53 } else if meta.path.is_ident("display_name") {
54 display_name = Some(meta.value()?.parse()?);
55 Ok(())
56 } else if meta.path.is_ident("sig_fn") {
57 sig_fn = Some(meta.value()?.parse()?);
58 Ok(())
59 } else if meta.path.is_ident("ret") {
60 ret = Some(meta.value()?.parse()?);
61 Ok(())
62 } else if meta.path.is_ident("user_path") {
63 user_path = Some(meta.value()?.parse()?);
64 Ok(())
65 } else {
66 Err(meta.error("unsupported property"))
67 }
68 });
69
70 parse_macro_input!(args with parser);
72
73 if user_path.is_none() {
74 user_path = Some(syn::parse_str("crate").expect("failed to parse user path"));
75 }
76
77 let compute_fn = parse_macro_input!(input as ItemFn);
79 let ItemFn {
80 attrs,
81 vis,
82 sig,
83 block,
84 } = compute_fn;
85
86 let Signature {
88 inputs,
89 ident: fn_name,
90 ..
91 } = &sig;
92
93 let arg_types = ok!(extract_input_types(inputs));
94 if arg_types.len() < 2 {
95 ok!(error!(
96 sig.span(),
97 "Expect at least two argument for admin fn: (handler, query_ctx)"
98 ));
99 }
100 let handler_type = ok!(extract_handler_type(&arg_types));
101
102 let mut result = TokenStream::new();
103 if let Some(display_name) = display_name {
106 let struct_code = build_struct(
107 attrs,
108 vis,
109 fn_name,
110 name.expect("name required"),
111 sig_fn.expect("sig_fn required"),
112 ret.expect("ret required"),
113 handler_type,
114 display_name,
115 user_path.expect("user_path required"),
116 );
117 result.extend(struct_code);
118 }
119
120 let input_fn_code: TokenStream = quote! {
122 #sig { #block }
123 }
124 .into();
125
126 result.extend(input_fn_code);
127 result
128}
129
130fn extract_handler_type(arg_types: &[Type]) -> Result<&Ident, syn::Error> {
132 match &arg_types[0] {
133 Type::Reference(TypeReference { elem, .. }) => match &**elem {
134 Type::Path(TypePath { path, .. }) => Ok(&path
135 .segments
136 .first()
137 .expect("Expected a reference of handler")
138 .ident),
139 other => {
140 error!(other.span(), "Expected a reference of handler")
141 }
142 },
143 other => {
144 error!(other.span(), "Expected a reference of handler")
145 }
146 }
147}
148
149#[allow(clippy::too_many_arguments)]
151fn build_struct(
152 attrs: Vec<Attribute>,
153 vis: Visibility,
154 fn_name: &Ident,
155 name: Ident,
156 sig_fn: Ident,
157 ret: Ident,
158 handler_type: &Ident,
159 display_name_ident: Ident,
160 user_path: Path,
161) -> TokenStream {
162 let display_name = display_name_ident.to_string();
163 let ret = Ident::new(&format!("{ret}_datatype"), ret.span());
164 let uppcase_display_name = display_name.to_uppercase();
165 let (handler, snafu_type) = match handler_type.to_string().as_str() {
168 "ProcedureServiceHandlerRef" => (
169 Ident::new("procedure_service_handler", handler_type.span()),
170 Ident::new("MissingProcedureServiceHandlerSnafu", handler_type.span()),
171 ),
172
173 "TableMutationHandlerRef" => (
174 Ident::new("table_mutation_handler", handler_type.span()),
175 Ident::new("MissingTableMutationHandlerSnafu", handler_type.span()),
176 ),
177
178 "FlowServiceHandlerRef" => (
179 Ident::new("flow_service_handler", handler_type.span()),
180 Ident::new("MissingFlowServiceHandlerSnafu", handler_type.span()),
181 ),
182 handler => ok!(error!(
183 handler_type.span(),
184 format!("Unknown handler type: {handler}")
185 )),
186 };
187
188 quote! {
189 #(#attrs)*
190 #vis struct #name {
191 signature: datafusion_expr::Signature,
192 func_ctx: #user_path::function::FunctionContext,
193 }
194
195 impl #name {
196 fn create(signature: datafusion_expr::Signature, func_ctx: #user_path::function::FunctionContext) -> Self {
198 Self {
199 signature,
200 func_ctx,
201 }
202 }
203
204 pub fn factory() -> impl Into< #user_path::function_factory::ScalarFunctionFactory> {
206 Self {
207 signature: #sig_fn().into(),
208 func_ctx: #user_path::function::FunctionContext::default(),
209 }
210 }
211 }
212
213 impl std::fmt::Display for #name {
214 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
215 write!(f, #uppcase_display_name)
216 }
217 }
218
219 impl std::fmt::Debug for #name {
220 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
221 write!(f, "{}({})", #uppcase_display_name, self.func_ctx)
222 }
223 }
224
225 impl datafusion::logical_expr::ScalarUDFImpl for #name {
227 fn as_any(&self) -> &dyn std::any::Any {
228 self
229 }
230
231 fn name(&self) -> &str {
232 #display_name
233 }
234
235 fn signature(&self) -> &datafusion_expr::Signature {
236 &self.signature
237 }
238
239 fn return_type(&self, _arg_types: &[datafusion::arrow::datatypes::DataType]) -> datafusion_common::Result<datafusion::arrow::datatypes::DataType> {
240 use datatypes::data_type::DataType;
241 Ok(store_api::storage::ConcreteDataType::#ret().as_arrow_type())
242 }
243
244 fn invoke_with_args(
245 &self,
246 _args: datafusion::logical_expr::ScalarFunctionArgs,
247 ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
248 Err(datafusion_common::DataFusionError::NotImplemented(
249 format!("{} can only be called from async contexts", #display_name)
250 ))
251 }
252 }
253
254 impl From<#name> for #user_path::function_factory::ScalarFunctionFactory {
256 fn from(func: #name) -> Self {
257 use std::sync::Arc;
258 use datafusion_expr::ScalarUDFImpl;
259 use datafusion_expr::async_udf::AsyncScalarUDF;
260
261 let name = func.name().to_string();
262
263 let func = Arc::new(move |ctx: #user_path::function::FunctionContext| {
264 let udf_impl = #name::create(func.signature.clone(), ctx);
266 let async_udf = AsyncScalarUDF::new(Arc::new(udf_impl));
267 async_udf.into_scalar_udf()
268 });
269 Self {
270 name,
271 factory: func,
272 }
273 }
274 }
275
276 #[async_trait::async_trait]
278 impl datafusion_expr::async_udf::AsyncScalarUDFImpl for #name {
279 async fn invoke_async_with_args(
280 &self,
281 args: datafusion::logical_expr::ScalarFunctionArgs,
282 ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
283 let columns = args.args
284 .iter()
285 .map(|arg| {
286 common_query::prelude::ColumnarValue::try_from(arg)
287 .and_then(|cv| match cv {
288 common_query::prelude::ColumnarValue::Vector(v) => Ok(v),
289 common_query::prelude::ColumnarValue::Scalar(s) => {
290 datatypes::vectors::Helper::try_from_scalar_value(s, args.number_rows)
291 .context(common_query::error::FromScalarValueSnafu)
292 }
293 })
294 })
295 .collect::<common_query::error::Result<Vec<_>>>()
296 .map_err(|e| datafusion_common::DataFusionError::Execution(format!("Column conversion error: {}", e)))?;
297
298 #user_path::ensure_greptime!(self.func_ctx);
300
301 let columns_num = columns.len();
302 let rows_num = if columns.is_empty() {
303 1
304 } else {
305 columns[0].len()
306 };
307
308 use snafu::{OptionExt, ResultExt};
309 use datatypes::data_type::DataType;
310
311 let query_ctx = &self.func_ctx.query_ctx;
312 let handler = self.func_ctx
313 .state
314 .#handler
315 .as_ref()
316 .context(#snafu_type)
317 .map_err(|e| datafusion_common::DataFusionError::Execution(format!("Handler error: {}", e)))?;
318
319 let mut builder = store_api::storage::ConcreteDataType::#ret()
320 .create_mutable_vector(rows_num);
321
322 if columns_num == 0 {
323 let result = #fn_name(handler, query_ctx, &[]).await
324 .map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
325
326 builder.push_value_ref(result.as_value_ref());
327 } else {
328 for i in 0..rows_num {
329 let args: Vec<_> = columns.iter()
330 .map(|vector| vector.get_ref(i))
331 .collect();
332
333 let result = #fn_name(handler, query_ctx, &args).await
334 .map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
335
336 builder.push_value_ref(result.as_value_ref());
337 }
338 }
339
340 let result_vector = builder.to_vector();
341
342 Ok(datafusion_expr::ColumnarValue::Array(result_vector.to_arrow_array()))
344 }
345 }
346 }
347 .into()
348}