1use proc_macro::TokenStream;
16use quote::quote;
17use syn::spanned::Spanned;
18use syn::{
19 parse_macro_input, Attribute, Ident, ItemFn, Signature, Type, TypeReference, Visibility,
20};
21
22use crate::utils::extract_input_types;
23
// Evaluates `$item` (a `Result`). On `Ok`, the macro yields the contained
// value. On `Err`, it returns early from the enclosing function with
// `err.into_compile_error().into()`. This lets a macro entry point bail out
// with an error instead of panicking.
macro_rules! ok {
    ($item:expr) => {
        match $item {
            Err(err) => return err.into_compile_error().into(),
            Ok(value) => value,
        }
    };
}
32
33pub(crate) fn process_range_fn(args: TokenStream, input: TokenStream) -> TokenStream {
34 let mut name: Option<Ident> = None;
35 let mut display_name: Option<Ident> = None;
36 let mut ret: Option<Ident> = None;
37
38 let parser = syn::meta::parser(|meta| {
39 if meta.path.is_ident("name") {
40 name = Some(meta.value()?.parse()?);
41 Ok(())
42 } else if meta.path.is_ident("display_name") {
43 display_name = Some(meta.value()?.parse()?);
44 Ok(())
45 } else if meta.path.is_ident("ret") {
46 ret = Some(meta.value()?.parse()?);
47 Ok(())
48 } else {
49 Err(meta.error("unsupported property"))
50 }
51 });
52
53 parse_macro_input!(args with parser);
55
56 let compute_fn = parse_macro_input!(input as ItemFn);
58 let ItemFn {
59 attrs,
60 vis,
61 sig,
62 block,
63 } = compute_fn;
64
65 let Signature {
67 inputs,
68 ident: fn_name,
69 ..
70 } = &sig;
71 let arg_types = ok!(extract_input_types(inputs));
72
73 let array_types = arg_types
75 .iter()
76 .map(|ty| {
77 if let Type::Reference(TypeReference { elem, .. }) = ty {
78 elem.as_ref().clone()
79 } else {
80 ty.clone()
81 }
82 })
83 .collect::<Vec<_>>();
84
85 let mut result = TokenStream::new();
86
87 if let Some(display_name) = display_name {
90 let struct_code = build_struct(
91 attrs,
92 vis,
93 name.clone().expect("name required"),
94 display_name,
95 array_types,
96 ret.clone().expect("ret required"),
97 );
98 result.extend(struct_code);
99 }
100
101 let calc_fn_code = build_calc_fn(
102 name.expect("name required"),
103 arg_types,
104 fn_name.clone(),
105 ret.expect("ret required"),
106 );
107 let input_fn_code: TokenStream = quote! {
109 #sig { #block }
110 }
111 .into();
112
113 result.extend(calc_fn_code);
114 result.extend(input_fn_code);
115 result
116}
117
118fn build_struct(
119 attrs: Vec<Attribute>,
120 vis: Visibility,
121 name: Ident,
122 display_name_ident: Ident,
123 array_types: Vec<Type>,
124 return_array_type: Ident,
125) -> TokenStream {
126 let display_name = display_name_ident.to_string();
127 quote! {
128 #(#attrs)*
129 #[derive(Debug)]
130 #vis struct #name {}
131
132 impl #name {
133 pub const fn name() -> &'static str {
134 #display_name
135 }
136
137 pub fn scalar_udf() -> ScalarUDF {
138 datafusion_expr::create_udf(
139 Self::name(),
140 Self::input_type(),
141 Self::return_type(),
142 Volatility::Volatile,
143 Arc::new(Self::calc) as _,
144 )
145 }
146
147 fn input_type() -> Vec<DataType> {
148 vec![#( RangeArray::convert_data_type(#array_types::new_null(0).data_type().clone()), )*]
149 }
150
151 fn return_type() -> DataType {
152 #return_array_type::new_null(0).data_type().clone()
153 }
154 }
155 }
156 .into()
157}
158
159fn build_calc_fn(
160 name: Ident,
161 param_types: Vec<Type>,
162 fn_name: Ident,
163 ret_type: Ident,
164) -> TokenStream {
165 let param_names = param_types
166 .iter()
167 .enumerate()
168 .map(|(i, ty)| Ident::new(&format!("param_{}", i), ty.span()))
169 .collect::<Vec<_>>();
170 let unref_param_types = param_types
171 .iter()
172 .map(|ty| {
173 if let Type::Reference(TypeReference { elem, .. }) = ty {
174 elem.as_ref().clone()
175 } else {
176 ty.clone()
177 }
178 })
179 .collect::<Vec<_>>();
180 let num_params = param_types.len();
181 let param_numbers = (0..num_params).collect::<Vec<_>>();
182 let range_array_names = param_names
183 .iter()
184 .map(|name| Ident::new(&format!("{}_range_array", name), name.span()))
185 .collect::<Vec<_>>();
186 let first_range_array_name = range_array_names.first().unwrap().clone();
187 let first_param_name = param_names.first().unwrap().clone();
188
189 quote! {
190 impl #name {
191 fn calc(input: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
192 assert_eq!(input.len(), #num_params);
193
194 #( let #range_array_names = RangeArray::try_new(extract_array(&input[#param_numbers])?.to_data().into())?; )*
195
196 {
198 let len_first = #first_range_array_name.len();
199 #(
200 if len_first != #range_array_names.len() {
201 return Err(DataFusionError::Execution(format!("RangeArray have different lengths in PromQL function {}: array1={}, array2={}", #name::name(), len_first, #range_array_names.len())));
202 }
203 )*
204 }
205
206 let mut result_array = Vec::new();
207 for index in 0..#first_range_array_name.len(){
208 #( let #param_names = #range_array_names.get(index).unwrap().as_any().downcast_ref::<#unref_param_types>().unwrap().clone(); )*
209
210 {
212 let len_first = #first_param_name.len();
213 #(
214 if len_first != #param_names.len() {
215 return Err(DataFusionError::Execution(format!("RangeArray's element {} have different lengths in PromQL function {}: array1={}, array2={}", index, #name::name(), len_first, #param_names.len())));
216 }
217 )*
218 }
219
220 let result = #fn_name(#( &#param_names, )*);
221 result_array.push(result);
222 }
223
224 let result = ColumnarValue::Array(Arc::new(#ret_type::from_iter(result_array)));
225 Ok(result)
226 }
227 }
228 }
229 .into()
230}