common_macro/
admin_fn.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro::TokenStream;
16use quote::quote;
17use syn::spanned::Spanned;
18use syn::{
19    Attribute, Ident, ItemFn, Path, Signature, Type, TypePath, TypeReference, Visibility,
20    parse_macro_input,
21};
22
23use crate::utils::extract_input_types;
24
/// Internal util macro: unwraps an `Ok` value, or returns early from the enclosing
/// function with the `Err` rendered through `into_compile_error()` and converted with `.into()`.
macro_rules! ok {
    ($item:expr) => {
        match $item {
            Err(err) => return err.into_compile_error().into(),
            Ok(value) => value,
        }
    };
}
34
/// Internal util macro: evaluates to `Err(syn::Error::new($span, $msg))`.
///
/// Typically wrapped in `ok!` to turn it into an early-returned compile error.
macro_rules! error {
    ($span:expr, $msg: expr) => {{
        let err = syn::Error::new($span, $msg);
        Err(err)
    }};
}
41
42pub(crate) fn process_admin_fn(args: TokenStream, input: TokenStream) -> TokenStream {
43    let mut name: Option<Ident> = None;
44    let mut display_name: Option<Ident> = None;
45    let mut sig_fn: Option<Ident> = None;
46    let mut ret: Option<Ident> = None;
47    let mut user_path: Option<Path> = None;
48
49    let parser = syn::meta::parser(|meta| {
50        if meta.path.is_ident("name") {
51            name = Some(meta.value()?.parse()?);
52            Ok(())
53        } else if meta.path.is_ident("display_name") {
54            display_name = Some(meta.value()?.parse()?);
55            Ok(())
56        } else if meta.path.is_ident("sig_fn") {
57            sig_fn = Some(meta.value()?.parse()?);
58            Ok(())
59        } else if meta.path.is_ident("ret") {
60            ret = Some(meta.value()?.parse()?);
61            Ok(())
62        } else if meta.path.is_ident("user_path") {
63            user_path = Some(meta.value()?.parse()?);
64            Ok(())
65        } else {
66            Err(meta.error("unsupported property"))
67        }
68    });
69
70    // extract arg map
71    parse_macro_input!(args with parser);
72
73    if user_path.is_none() {
74        user_path = Some(syn::parse_str("crate").expect("failed to parse user path"));
75    }
76
77    // decompose the fn block
78    let compute_fn = parse_macro_input!(input as ItemFn);
79    let ItemFn {
80        attrs,
81        vis,
82        sig,
83        block,
84    } = compute_fn;
85
86    // extract fn arg list
87    let Signature {
88        inputs,
89        ident: fn_name,
90        ..
91    } = &sig;
92
93    let arg_types = ok!(extract_input_types(inputs));
94    if arg_types.len() < 2 {
95        ok!(error!(
96            sig.span(),
97            "Expect at least two argument for admin fn: (handler, query_ctx)"
98        ));
99    }
100    let handler_type = ok!(extract_handler_type(&arg_types));
101
102    let mut result = TokenStream::new();
103    // build the struct and its impl block
104    // only do this when `display_name` is specified
105    if let Some(display_name) = display_name {
106        let struct_code = build_struct(
107            attrs,
108            vis,
109            fn_name,
110            name.expect("name required"),
111            sig_fn.expect("sig_fn required"),
112            ret.expect("ret required"),
113            handler_type,
114            display_name,
115            user_path.expect("user_path required"),
116        );
117        result.extend(struct_code);
118    }
119
120    // preserve this fn
121    let input_fn_code: TokenStream = quote! {
122        #sig { #block }
123    }
124    .into();
125
126    result.extend(input_fn_code);
127    result
128}
129
130/// Retrieve the handler type, `ProcedureServiceHandlerRef` or `TableMutationHandlerRef`.
131fn extract_handler_type(arg_types: &[Type]) -> Result<&Ident, syn::Error> {
132    match &arg_types[0] {
133        Type::Reference(TypeReference { elem, .. }) => match &**elem {
134            Type::Path(TypePath { path, .. }) => Ok(&path
135                .segments
136                .first()
137                .expect("Expected a reference of handler")
138                .ident),
139            other => {
140                error!(other.span(), "Expected a reference of handler")
141            }
142        },
143        other => {
144            error!(other.span(), "Expected a reference of handler")
145        }
146    }
147}
148
149/// Build the function struct
150#[allow(clippy::too_many_arguments)]
151fn build_struct(
152    attrs: Vec<Attribute>,
153    vis: Visibility,
154    fn_name: &Ident,
155    name: Ident,
156    sig_fn: Ident,
157    ret: Ident,
158    handler_type: &Ident,
159    display_name_ident: Ident,
160    user_path: Path,
161) -> TokenStream {
162    let display_name = display_name_ident.to_string();
163    let ret = Ident::new(&format!("{ret}_datatype"), ret.span());
164    let uppcase_display_name = display_name.to_uppercase();
165    // Get the handler name in function state by the argument ident
166    // TODO(discord9): consider simple depend injection if more handlers are needed
167    let (handler, snafu_type) = match handler_type.to_string().as_str() {
168        "ProcedureServiceHandlerRef" => (
169            Ident::new("procedure_service_handler", handler_type.span()),
170            Ident::new("MissingProcedureServiceHandlerSnafu", handler_type.span()),
171        ),
172
173        "TableMutationHandlerRef" => (
174            Ident::new("table_mutation_handler", handler_type.span()),
175            Ident::new("MissingTableMutationHandlerSnafu", handler_type.span()),
176        ),
177
178        "FlowServiceHandlerRef" => (
179            Ident::new("flow_service_handler", handler_type.span()),
180            Ident::new("MissingFlowServiceHandlerSnafu", handler_type.span()),
181        ),
182        handler => ok!(error!(
183            handler_type.span(),
184            format!("Unknown handler type: {handler}")
185        )),
186    };
187
188    quote! {
189        #(#attrs)*
190        #vis struct #name {
191            signature: datafusion_expr::Signature,
192            func_ctx: #user_path::function::FunctionContext,
193        }
194
195        impl #name {
196            /// Creates a new instance of the function with function context.
197            fn create(signature: datafusion_expr::Signature, func_ctx: #user_path::function::FunctionContext) -> Self {
198                Self {
199                    signature,
200                    func_ctx,
201                }
202            }
203
204            /// Returns the [`ScalarFunctionFactory`] of the function.
205            pub fn factory() -> impl Into< #user_path::function_factory::ScalarFunctionFactory>  {
206                Self {
207                    signature: #sig_fn().into(),
208                    func_ctx: #user_path::function::FunctionContext::default(),
209                }
210            }
211        }
212
213        impl std::fmt::Display for #name {
214            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
215                write!(f, #uppcase_display_name)
216            }
217        }
218
219        impl std::fmt::Debug for #name {
220            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
221                write!(f, "{}({})", #uppcase_display_name, self.func_ctx)
222            }
223        }
224
225        // Implement DataFusion's ScalarUDFImpl trait
226        impl datafusion::logical_expr::ScalarUDFImpl for #name {
227            fn as_any(&self) -> &dyn std::any::Any {
228                self
229            }
230
231            fn name(&self) -> &str {
232                #display_name
233            }
234
235            fn signature(&self) -> &datafusion_expr::Signature {
236                &self.signature
237            }
238
239            fn return_type(&self, _arg_types: &[datafusion::arrow::datatypes::DataType]) -> datafusion_common::Result<datafusion::arrow::datatypes::DataType> {
240                use datatypes::data_type::DataType;
241                Ok(store_api::storage::ConcreteDataType::#ret().as_arrow_type())
242            }
243
244            fn invoke_with_args(
245                &self,
246                _args: datafusion::logical_expr::ScalarFunctionArgs,
247            ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
248                Err(datafusion_common::DataFusionError::NotImplemented(
249                    format!("{} can only be called from async contexts", #display_name)
250                ))
251            }
252        }
253
254        /// Implement From trait for ScalarFunctionFactory
255        impl From<#name> for  #user_path::function_factory::ScalarFunctionFactory {
256            fn from(func: #name) -> Self {
257                 use std::sync::Arc;
258                 use datafusion_expr::ScalarUDFImpl;
259                 use datafusion_expr::async_udf::AsyncScalarUDF;
260
261                let name = func.name().to_string();
262
263                let func = Arc::new(move |ctx: #user_path::function::FunctionContext| {
264                    // create the UDF dynamically with function context
265                    let udf_impl = #name::create(func.signature.clone(), ctx);
266                    let async_udf = AsyncScalarUDF::new(Arc::new(udf_impl));
267                    async_udf.into_scalar_udf()
268                });
269                Self {
270                    name,
271                    factory: func,
272                }
273            }
274        }
275
276        // Implement DataFusion's AsyncScalarUDFImpl trait
277        #[async_trait::async_trait]
278        impl datafusion_expr::async_udf::AsyncScalarUDFImpl for #name {
279            async fn invoke_async_with_args(
280                &self,
281                args: datafusion::logical_expr::ScalarFunctionArgs,
282            ) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
283                let columns = args.args
284                    .iter()
285                    .map(|arg| {
286                        common_query::prelude::ColumnarValue::try_from(arg)
287                            .and_then(|cv| match cv {
288                                common_query::prelude::ColumnarValue::Vector(v) => Ok(v),
289                                common_query::prelude::ColumnarValue::Scalar(s) => {
290                                    datatypes::vectors::Helper::try_from_scalar_value(s, args.number_rows)
291                                        .context(common_query::error::FromScalarValueSnafu)
292                                }
293                            })
294                    })
295                    .collect::<common_query::error::Result<Vec<_>>>()
296                    .map_err(|e| datafusion_common::DataFusionError::Execution(format!("Column conversion error: {}", e)))?;
297
298                // Safety check: Ensure under the `greptime` catalog for security
299                #user_path::ensure_greptime!(self.func_ctx);
300
301                let columns_num = columns.len();
302                let rows_num = if columns.is_empty() {
303                    1
304                } else {
305                    columns[0].len()
306                };
307
308                use snafu::{OptionExt, ResultExt};
309                use datatypes::data_type::DataType;
310
311                let query_ctx = &self.func_ctx.query_ctx;
312                let handler = self.func_ctx
313                    .state
314                    .#handler
315                    .as_ref()
316                    .context(#snafu_type)
317                    .map_err(|e| datafusion_common::DataFusionError::Execution(format!("Handler error: {}", e)))?;
318
319                let mut builder = store_api::storage::ConcreteDataType::#ret()
320                    .create_mutable_vector(rows_num);
321
322                if columns_num == 0 {
323                    let result = #fn_name(handler, query_ctx, &[]).await
324                        .map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
325
326                    builder.push_value_ref(result.as_value_ref());
327                } else {
328                    for i in 0..rows_num {
329                        let args: Vec<_> = columns.iter()
330                            .map(|vector| vector.get_ref(i))
331                            .collect();
332
333                        let result = #fn_name(handler, query_ctx, &args).await
334                            .map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
335
336                        builder.push_value_ref(result.as_value_ref());
337                    }
338                }
339
340                let result_vector = builder.to_vector();
341
342                // Convert result back to DataFusion ColumnarValue
343                Ok(datafusion_expr::ColumnarValue::Array(result_vector.to_arrow_array()))
344            }
345        }
346    }
347    .into()
348}