operator/statement/
admin.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_function::function::FunctionContext;
18use common_function::function_registry::FUNCTION_REGISTRY;
19use common_query::prelude::TypeSignature;
20use common_query::Output;
21use common_recordbatch::{RecordBatch, RecordBatches};
22use common_sql::convert::sql_value_to_value;
23use common_telemetry::tracing;
24use common_time::Timezone;
25use datatypes::data_type::DataType;
26use datatypes::prelude::ConcreteDataType;
27use datatypes::schema::{ColumnSchema, Schema};
28use datatypes::value::Value;
29use datatypes::vectors::VectorRef;
30use session::context::QueryContextRef;
31use snafu::{ensure, OptionExt, ResultExt};
32use sql::ast::{Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Value as SqlValue};
33use sql::statements::admin::Admin;
34
35use crate::error::{self, ExecuteAdminFunctionSnafu, IntoVectorsSnafu, Result};
36use crate::statement::StatementExecutor;
37
/// Name passed as the first argument of `sql_value_to_value` when casting admin function
/// arguments in this file.
// NOTE(review): presumably a placeholder column name, since admin function arguments are
// literals that do not belong to any table column — confirm against `sql_value_to_value`.
const DUMMY_COLUMN: &str = "<dummy>";
39
40impl StatementExecutor {
41    /// Execute the [`Admin`] statement and returns the output.
42    #[tracing::instrument(skip_all)]
43    pub(super) async fn execute_admin_command(
44        &self,
45        stmt: Admin,
46        query_ctx: QueryContextRef,
47    ) -> Result<Output> {
48        let Admin::Func(func) = &stmt;
49        // the function name should be in lower case.
50        let func_name = func.name.to_string().to_lowercase();
51        let factory = FUNCTION_REGISTRY
52            .get_function(&func_name)
53            .context(error::AdminFunctionNotFoundSnafu { name: func_name })?;
54        let func_ctx = FunctionContext {
55            query_ctx: query_ctx.clone(),
56            state: self.query_engine.engine_state().function_state(),
57        };
58
59        let admin_udf = factory.provide(func_ctx);
60        let fn_name = admin_udf.name();
61        let signature = admin_udf.signature();
62
63        // Parse function arguments
64        let FunctionArguments::List(args) = &func.args else {
65            return error::BuildAdminFunctionArgsSnafu {
66                msg: format!("unsupported function args {} for {}", func.args, fn_name),
67            }
68            .fail();
69        };
70        let arg_values = args
71            .args
72            .iter()
73            .map(|arg| {
74                let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(value))) = arg else {
75                    return error::BuildAdminFunctionArgsSnafu {
76                        msg: format!("unsupported function arg {arg} for {}", fn_name),
77                    }
78                    .fail();
79                };
80                Ok(&value.value)
81            })
82            .collect::<Result<Vec<_>>>()?;
83
84        let type_sig = (&signature.type_signature).into();
85        let args = args_to_vector(&type_sig, &arg_values, &query_ctx)?;
86        let arg_types = args
87            .iter()
88            .map(|arg| arg.data_type().as_arrow_type())
89            .collect::<Vec<_>>();
90        let ret_type = admin_udf.return_type(&arg_types).map_err(|e| {
91            error::Error::BuildAdminFunctionArgs {
92                msg: format!(
93                    "Failed to get return type of admin function {}: {}",
94                    fn_name, e
95                ),
96            }
97        })?;
98
99        // Convert arguments to DataFusion ColumnarValue format
100        let columnar_args: Vec<datafusion_expr::ColumnarValue> = args
101            .iter()
102            .map(|vector| datafusion_expr::ColumnarValue::Array(vector.to_arrow_array()))
103            .collect();
104
105        // Create ScalarFunctionArgs following the same pattern as udf.rs
106        let func_args = datafusion::logical_expr::ScalarFunctionArgs {
107            args: columnar_args,
108            arg_fields: args
109                .iter()
110                .enumerate()
111                .map(|(i, vector)| {
112                    Arc::new(arrow::datatypes::Field::new(
113                        format!("arg_{}", i),
114                        arg_types[i].clone(),
115                        vector.null_count() > 0,
116                    ))
117                })
118                .collect(),
119            return_field: Arc::new(arrow::datatypes::Field::new("result", ret_type, true)),
120            number_rows: if args.is_empty() { 1 } else { args[0].len() },
121            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
122        };
123
124        // Execute the async UDF
125        let result_columnar = admin_udf
126            .as_async()
127            .with_context(|| error::BuildAdminFunctionArgsSnafu {
128                msg: format!("Function {} is not async", fn_name),
129            })?
130            .invoke_async_with_args(func_args)
131            .await
132            .with_context(|_| ExecuteAdminFunctionSnafu {
133                msg: format!("Failed to execute admin function {}", fn_name),
134            })?;
135
136        // Convert result back to VectorRef
137        let result = match result_columnar {
138            datafusion_expr::ColumnarValue::Array(array) => {
139                datatypes::vectors::Helper::try_into_vector(array).context(IntoVectorsSnafu)?
140            }
141            datafusion_expr::ColumnarValue::Scalar(scalar) => {
142                let array =
143                    scalar
144                        .to_array_of_size(1)
145                        .with_context(|_| ExecuteAdminFunctionSnafu {
146                            msg: format!("Failed to convert scalar to array for {}", fn_name),
147                        })?;
148                datatypes::vectors::Helper::try_into_vector(array).context(IntoVectorsSnafu)?
149            }
150        };
151
152        let result_vector: VectorRef = result;
153        let column_schemas = vec![ColumnSchema::new(
154            // Use statement as the result column name
155            stmt.to_string(),
156            result_vector.data_type(),
157            false,
158        )];
159        let schema = Arc::new(Schema::new(column_schemas));
160        let batch = RecordBatch::new(schema.clone(), vec![result_vector])
161            .context(error::BuildRecordBatchSnafu)?;
162        let batches =
163            RecordBatches::try_new(schema, vec![batch]).context(error::BuildRecordBatchSnafu)?;
164
165        Ok(Output::new_with_record_batches(batches))
166    }
167}
168
169/// Try to cast the arguments to vectors by function's signature.
170fn args_to_vector(
171    type_signature: &TypeSignature,
172    args: &Vec<&SqlValue>,
173    query_ctx: &QueryContextRef,
174) -> Result<Vec<VectorRef>> {
175    let tz = query_ctx.timezone();
176
177    match type_signature {
178        TypeSignature::Variadic(valid_types) => {
179            values_to_vectors_by_valid_types(valid_types, args, Some(&tz))
180        }
181
182        TypeSignature::Uniform(arity, valid_types) => {
183            ensure!(
184                *arity == args.len(),
185                error::FunctionArityMismatchSnafu {
186                    actual: args.len(),
187                    expected: *arity,
188                }
189            );
190
191            values_to_vectors_by_valid_types(valid_types, args, Some(&tz))
192        }
193
194        TypeSignature::Exact(data_types) => {
195            values_to_vectors_by_exact_types(data_types, args, Some(&tz))
196        }
197
198        TypeSignature::VariadicAny => {
199            let data_types = args
200                .iter()
201                .map(|value| try_get_data_type_for_sql_value(value))
202                .collect::<Result<Vec<_>>>()?;
203
204            values_to_vectors_by_exact_types(&data_types, args, Some(&tz))
205        }
206
207        TypeSignature::Any(arity) => {
208            ensure!(
209                *arity == args.len(),
210                error::FunctionArityMismatchSnafu {
211                    actual: args.len(),
212                    expected: *arity,
213                }
214            );
215
216            let data_types = args
217                .iter()
218                .map(|value| try_get_data_type_for_sql_value(value))
219                .collect::<Result<Vec<_>>>()?;
220
221            values_to_vectors_by_exact_types(&data_types, args, Some(&tz))
222        }
223
224        TypeSignature::OneOf(type_sigs) => {
225            for type_sig in type_sigs {
226                if let Ok(vectors) = args_to_vector(type_sig, args, query_ctx) {
227                    return Ok(vectors);
228                }
229            }
230
231            error::BuildAdminFunctionArgsSnafu {
232                msg: "function signature not match",
233            }
234            .fail()
235        }
236
237        TypeSignature::NullAry => Ok(vec![]),
238    }
239}
240
241/// Try to cast sql values to vectors by exact data types.
242fn values_to_vectors_by_exact_types(
243    exact_types: &[ConcreteDataType],
244    args: &[&SqlValue],
245    tz: Option<&Timezone>,
246) -> Result<Vec<VectorRef>> {
247    args.iter()
248        .zip(exact_types.iter())
249        .map(|(value, data_type)| {
250            let value = sql_value_to_value(DUMMY_COLUMN, data_type, value, tz, None, false)
251                .context(error::SqlCommonSnafu)?;
252
253            Ok(value_to_vector(value))
254        })
255        .collect()
256}
257
258/// Try to cast sql values to vectors by valid data types.
259fn values_to_vectors_by_valid_types(
260    valid_types: &[ConcreteDataType],
261    args: &[&SqlValue],
262    tz: Option<&Timezone>,
263) -> Result<Vec<VectorRef>> {
264    args.iter()
265        .map(|value| {
266            for data_type in valid_types {
267                if let Ok(value) =
268                    sql_value_to_value(DUMMY_COLUMN, data_type, value, tz, None, false)
269                {
270                    return Ok(value_to_vector(value));
271                }
272            }
273
274            error::BuildAdminFunctionArgsSnafu {
275                msg: format!("failed to cast {value}"),
276            }
277            .fail()
278        })
279        .collect::<Result<Vec<_>>>()
280}
281
282/// Build a [`VectorRef`] from [`Value`]
283fn value_to_vector(value: Value) -> VectorRef {
284    let data_type = value.data_type();
285    let mut mutable_vector = data_type.create_mutable_vector(1);
286    mutable_vector.push_value_ref(value.as_value_ref());
287
288    mutable_vector.to_vector()
289}
290
291/// Try to infer the data type from sql value.
292fn try_get_data_type_for_sql_value(value: &SqlValue) -> Result<ConcreteDataType> {
293    match value {
294        SqlValue::Number(_, _) => Ok(ConcreteDataType::float64_datatype()),
295        SqlValue::Null => Ok(ConcreteDataType::null_datatype()),
296        SqlValue::Boolean(_) => Ok(ConcreteDataType::boolean_datatype()),
297        SqlValue::HexStringLiteral(_)
298        | SqlValue::DoubleQuotedString(_)
299        | SqlValue::SingleQuotedString(_) => Ok(ConcreteDataType::string_datatype()),
300        _ => error::BuildAdminFunctionArgsSnafu {
301            msg: format!("unsupported sql value: {value}"),
302        }
303        .fail(),
304    }
305}