common_macro/
admin_fn.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro::TokenStream;
16use quote::quote;
17use syn::spanned::Spanned;
18use syn::{
19    parse_macro_input, Attribute, Ident, ItemFn, Path, Signature, Type, TypePath, TypeReference,
20    Visibility,
21};
22
23use crate::utils::extract_input_types;
24
/// Internal util macro: unwrap a `Result` whose error type provides `into_compile_error`
/// (e.g. `syn::Error`), or early-return that error from the enclosing function, converted
/// with `.into()` into the function's return type (compile-error tokens in this file).
macro_rules! ok {
    ($result:expr) => {
        match $result {
            Err(err) => return err.into_compile_error().into(),
            Ok(value) => value,
        }
    };
}
34
/// Internal util macro: evaluates to `Err(syn::Error::new($span, $msg))`, i.e. an error
/// located at `$span` carrying `$msg`. Wrap it in `ok!` to abort expansion with that error.
macro_rules! error {
    ($span:expr, $msg:expr) => {
        ::std::result::Result::Err(syn::Error::new($span, $msg))
    };
}
41
42pub(crate) fn process_admin_fn(args: TokenStream, input: TokenStream) -> TokenStream {
43    let mut name: Option<Ident> = None;
44    let mut display_name: Option<Ident> = None;
45    let mut sig_fn: Option<Ident> = None;
46    let mut ret: Option<Ident> = None;
47    let mut user_path: Option<Path> = None;
48
49    let parser = syn::meta::parser(|meta| {
50        if meta.path.is_ident("name") {
51            name = Some(meta.value()?.parse()?);
52            Ok(())
53        } else if meta.path.is_ident("display_name") {
54            display_name = Some(meta.value()?.parse()?);
55            Ok(())
56        } else if meta.path.is_ident("sig_fn") {
57            sig_fn = Some(meta.value()?.parse()?);
58            Ok(())
59        } else if meta.path.is_ident("ret") {
60            ret = Some(meta.value()?.parse()?);
61            Ok(())
62        } else if meta.path.is_ident("user_path") {
63            user_path = Some(meta.value()?.parse()?);
64            Ok(())
65        } else {
66            Err(meta.error("unsupported property"))
67        }
68    });
69
70    // extract arg map
71    parse_macro_input!(args with parser);
72
73    if user_path.is_none() {
74        user_path = Some(syn::parse_str("crate").expect("failed to parse user path"));
75    }
76
77    // decompose the fn block
78    let compute_fn = parse_macro_input!(input as ItemFn);
79    let ItemFn {
80        attrs,
81        vis,
82        sig,
83        block,
84    } = compute_fn;
85
86    // extract fn arg list
87    let Signature {
88        inputs,
89        ident: fn_name,
90        ..
91    } = &sig;
92
93    let arg_types = ok!(extract_input_types(inputs));
94    if arg_types.len() < 2 {
95        ok!(error!(
96            sig.span(),
97            "Expect at least two argument for admin fn: (handler, query_ctx)"
98        ));
99    }
100    let handler_type = ok!(extract_handler_type(&arg_types));
101
102    let mut result = TokenStream::new();
103    // build the struct and its impl block
104    // only do this when `display_name` is specified
105    if let Some(display_name) = display_name {
106        let struct_code = build_struct(
107            attrs,
108            vis,
109            fn_name,
110            name.expect("name required"),
111            sig_fn.expect("sig_fn required"),
112            ret.expect("ret required"),
113            handler_type,
114            display_name,
115            user_path.expect("user_path required"),
116        );
117        result.extend(struct_code);
118    }
119
120    // preserve this fn
121    let input_fn_code: TokenStream = quote! {
122        #sig { #block }
123    }
124    .into();
125
126    result.extend(input_fn_code);
127    result
128}
129
130/// Retrieve the handler type, `ProcedureServiceHandlerRef` or `TableMutationHandlerRef`.
131fn extract_handler_type(arg_types: &[Type]) -> Result<&Ident, syn::Error> {
132    match &arg_types[0] {
133        Type::Reference(TypeReference { elem, .. }) => match &**elem {
134            Type::Path(TypePath { path, .. }) => Ok(&path
135                .segments
136                .first()
137                .expect("Expected a reference of handler")
138                .ident),
139            other => {
140                error!(other.span(), "Expected a reference of handler")
141            }
142        },
143        other => {
144            error!(other.span(), "Expected a reference of handler")
145        }
146    }
147}
148
149/// Build the function struct
150#[allow(clippy::too_many_arguments)]
151fn build_struct(
152    attrs: Vec<Attribute>,
153    vis: Visibility,
154    fn_name: &Ident,
155    name: Ident,
156    sig_fn: Ident,
157    ret: Ident,
158    handler_type: &Ident,
159    display_name_ident: Ident,
160    user_path: Path,
161) -> TokenStream {
162    let display_name = display_name_ident.to_string();
163    let ret = Ident::new(&format!("{ret}_datatype"), ret.span());
164    let uppcase_display_name = display_name.to_uppercase();
165    // Get the handler name in function state by the argument ident
166    // TODO(discord9): consider simple depend injection if more handlers are needed
167    let (handler, snafu_type) = match handler_type.to_string().as_str() {
168        "ProcedureServiceHandlerRef" => (
169            Ident::new("procedure_service_handler", handler_type.span()),
170            Ident::new("MissingProcedureServiceHandlerSnafu", handler_type.span()),
171        ),
172
173        "TableMutationHandlerRef" => (
174            Ident::new("table_mutation_handler", handler_type.span()),
175            Ident::new("MissingTableMutationHandlerSnafu", handler_type.span()),
176        ),
177
178        "FlowServiceHandlerRef" => (
179            Ident::new("flow_service_handler", handler_type.span()),
180            Ident::new("MissingFlowServiceHandlerSnafu", handler_type.span()),
181        ),
182        handler => ok!(error!(
183            handler_type.span(),
184            format!("Unknown handler type: {handler}")
185        )),
186    };
187
188    quote! {
189        #(#attrs)*
190        #[derive(Debug)]
191        #vis struct #name;
192
193        impl std::fmt::Display for #name {
194            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
195                write!(f, #uppcase_display_name)
196            }
197        }
198
199
200        #[async_trait::async_trait]
201        impl #user_path::function::AsyncFunction for #name {
202            fn name(&self) -> &'static str {
203                #display_name
204            }
205
206            fn return_type(&self, _input_types: &[store_api::storage::ConcreteDataType]) -> common_query::error::Result<store_api::storage::ConcreteDataType> {
207                Ok(store_api::storage::ConcreteDataType::#ret())
208            }
209
210            fn signature(&self) -> Signature {
211                #sig_fn()
212            }
213
214            async fn eval(&self, func_ctx: #user_path::function::FunctionContext, columns: &[datatypes::vectors::VectorRef]) ->  common_query::error::Result<datatypes::vectors::VectorRef> {
215                // Ensure under the `greptime` catalog for security
216                #user_path::ensure_greptime!(func_ctx);
217
218                let columns_num = columns.len();
219                let rows_num = if columns.is_empty() {
220                    1
221                } else {
222                    columns[0].len()
223                };
224                let columns = Vec::from(columns);
225
226                use snafu::OptionExt;
227                use datatypes::data_type::DataType;
228
229                let query_ctx = &func_ctx.query_ctx;
230                let handler = func_ctx
231                    .state
232                    .#handler
233                    .as_ref()
234                    .context(#snafu_type)?;
235
236                let mut builder = store_api::storage::ConcreteDataType::#ret()
237                    .create_mutable_vector(rows_num);
238
239                if columns_num == 0 {
240                    let result = #fn_name(handler, query_ctx, &[]).await?;
241
242                    builder.push_value_ref(result.as_value_ref());
243                } else {
244                    for i in 0..rows_num {
245                        let args: Vec<_> = columns.iter()
246                            .map(|vector| vector.get_ref(i))
247                            .collect();
248
249                        let result = #fn_name(handler, query_ctx, &args).await?;
250
251                        builder.push_value_ref(result.as_value_ref());
252                    }
253                }
254
255                Ok(builder.to_vector())
256            }
257
258        }
259    }
260    .into()
261}