common_macro/
range_fn.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro::TokenStream;
16use quote::quote;
17use syn::spanned::Spanned;
18use syn::{
19    parse_macro_input, Attribute, Ident, ItemFn, Signature, Type, TypeReference, Visibility,
20};
21
22use crate::utils::extract_input_types;
23
/// Unwraps a `Result` whose error type provides `into_compile_error()`
/// (e.g. `syn::Error`).
///
/// On `Ok`, evaluates to the contained value. On `Err`, *returns from the
/// enclosing function* with the error rendered as compile-error tokens
/// (converted with `.into()`), so a macro entry point reports a diagnostic
/// instead of panicking.
macro_rules! ok {
    ($result:expr) => {
        match $result {
            Err(err) => return err.into_compile_error().into(),
            Ok(value) => value,
        }
    };
}
32
33pub(crate) fn process_range_fn(args: TokenStream, input: TokenStream) -> TokenStream {
34    let mut name: Option<Ident> = None;
35    let mut display_name: Option<Ident> = None;
36    let mut ret: Option<Ident> = None;
37
38    let parser = syn::meta::parser(|meta| {
39        if meta.path.is_ident("name") {
40            name = Some(meta.value()?.parse()?);
41            Ok(())
42        } else if meta.path.is_ident("display_name") {
43            display_name = Some(meta.value()?.parse()?);
44            Ok(())
45        } else if meta.path.is_ident("ret") {
46            ret = Some(meta.value()?.parse()?);
47            Ok(())
48        } else {
49            Err(meta.error("unsupported property"))
50        }
51    });
52
53    // extract arg map
54    parse_macro_input!(args with parser);
55
56    // decompose the fn block
57    let compute_fn = parse_macro_input!(input as ItemFn);
58    let ItemFn {
59        attrs,
60        vis,
61        sig,
62        block,
63    } = compute_fn;
64
65    // extract fn arg list
66    let Signature {
67        inputs,
68        ident: fn_name,
69        ..
70    } = &sig;
71    let arg_types = ok!(extract_input_types(inputs));
72
73    // with format like Float64Array
74    let array_types = arg_types
75        .iter()
76        .map(|ty| {
77            if let Type::Reference(TypeReference { elem, .. }) = ty {
78                elem.as_ref().clone()
79            } else {
80                ty.clone()
81            }
82        })
83        .collect::<Vec<_>>();
84
85    let mut result = TokenStream::new();
86
87    // build the struct and its impl block
88    // only do this when `display_name` is specified
89    if let Some(display_name) = display_name {
90        let struct_code = build_struct(
91            attrs,
92            vis,
93            name.clone().expect("name required"),
94            display_name,
95            array_types,
96            ret.clone().expect("ret required"),
97        );
98        result.extend(struct_code);
99    }
100
101    let calc_fn_code = build_calc_fn(
102        name.expect("name required"),
103        arg_types,
104        fn_name.clone(),
105        ret.expect("ret required"),
106    );
107    // preserve this fn, but remove its `pub` modifier
108    let input_fn_code: TokenStream = quote! {
109        #sig { #block }
110    }
111    .into();
112
113    result.extend(calc_fn_code);
114    result.extend(input_fn_code);
115    result
116}
117
118fn build_struct(
119    attrs: Vec<Attribute>,
120    vis: Visibility,
121    name: Ident,
122    display_name_ident: Ident,
123    array_types: Vec<Type>,
124    return_array_type: Ident,
125) -> TokenStream {
126    let display_name = display_name_ident.to_string();
127    quote! {
128        #(#attrs)*
129        #[derive(Debug)]
130        #vis struct #name {}
131
132        impl #name {
133            pub const fn name() -> &'static str {
134                #display_name
135            }
136
137            pub fn scalar_udf() -> ScalarUDF {
138                datafusion_expr::create_udf(
139                    Self::name(),
140                    Self::input_type(),
141                    Self::return_type(),
142                    Volatility::Volatile,
143                    Arc::new(Self::calc) as _,
144                )
145            }
146
147            fn input_type() -> Vec<DataType> {
148                vec![#( RangeArray::convert_data_type(#array_types::new_null(0).data_type().clone()), )*]
149            }
150
151            fn return_type() -> DataType {
152                #return_array_type::new_null(0).data_type().clone()
153            }
154        }
155    }
156    .into()
157}
158
159fn build_calc_fn(
160    name: Ident,
161    param_types: Vec<Type>,
162    fn_name: Ident,
163    ret_type: Ident,
164) -> TokenStream {
165    let param_names = param_types
166        .iter()
167        .enumerate()
168        .map(|(i, ty)| Ident::new(&format!("param_{}", i), ty.span()))
169        .collect::<Vec<_>>();
170    let unref_param_types = param_types
171        .iter()
172        .map(|ty| {
173            if let Type::Reference(TypeReference { elem, .. }) = ty {
174                elem.as_ref().clone()
175            } else {
176                ty.clone()
177            }
178        })
179        .collect::<Vec<_>>();
180    let num_params = param_types.len();
181    let param_numbers = (0..num_params).collect::<Vec<_>>();
182    let range_array_names = param_names
183        .iter()
184        .map(|name| Ident::new(&format!("{}_range_array", name), name.span()))
185        .collect::<Vec<_>>();
186    let first_range_array_name = range_array_names.first().unwrap().clone();
187    let first_param_name = param_names.first().unwrap().clone();
188
189    quote! {
190        impl #name {
191            fn calc(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
192                assert_eq!(input.len(), #num_params);
193
194                #( let #range_array_names = RangeArray::try_new(extract_array(&input[#param_numbers])?.to_data().into())?; )*
195
196                // check arrays len
197                {
198                    let len_first = #first_range_array_name.len();
199                    #(
200                        if len_first != #range_array_names.len() {
201                            return Err(DataFusionError::Execution(format!("RangeArray have different lengths in PromQL function {}: array1={}, array2={}", #name::name(), len_first, #range_array_names.len())));
202                        }
203                    )*
204                }
205
206                let mut result_array = Vec::new();
207                for index in 0..#first_range_array_name.len(){
208                    #( let #param_names = #range_array_names.get(index).unwrap().as_any().downcast_ref::<#unref_param_types>().unwrap().clone(); )*
209
210                    // check element len
211                    {
212                        let len_first = #first_param_name.len();
213                        #(
214                            if len_first != #param_names.len() {
215                                return Err(DataFusionError::Execution(format!("RangeArray's element {} have different lengths in PromQL function {}: array1={}, array2={}", index, #name::name(), len_first, #param_names.len())));
216                            }
217                        )*
218                    }
219
220                    let result = #fn_name(#( &#param_names, )*);
221                    result_array.push(result);
222                }
223
224                let result = ColumnarValue::Array(Arc::new(#ret_type::from_iter(result_array)));
225                Ok(result)
226            }
227        }
228    }
229    .into()
230}