operator/statement/
admin.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use common_function::function::FunctionContext;
18use common_function::function_registry::FUNCTION_REGISTRY;
19use common_query::Output;
20use common_recordbatch::{RecordBatch, RecordBatches};
21use common_sql::convert::sql_value_to_value;
22use common_telemetry::tracing;
23use common_time::Timezone;
24use datafusion_expr::TypeSignature;
25use datatypes::arrow::datatypes::DataType as ArrowDataType;
26use datatypes::data_type::DataType;
27use datatypes::prelude::ConcreteDataType;
28use datatypes::schema::{ColumnSchema, Schema};
29use datatypes::value::Value;
30use datatypes::vectors::VectorRef;
31use session::context::QueryContextRef;
32use snafu::{OptionExt, ResultExt, ensure};
33use sql::ast::{Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Value as SqlValue};
34use sql::statements::admin::Admin;
35
36use crate::error::{self, CastSnafu, ExecuteAdminFunctionSnafu, Result};
37use crate::statement::StatementExecutor;
38
/// Placeholder column name handed to `sql_value_to_value` when converting
/// admin-function arguments, which are literals rather than table columns.
const DUMMY_COLUMN: &str = "<dummy>";
40
41impl StatementExecutor {
42    /// Execute the [`Admin`] statement and returns the output.
43    #[tracing::instrument(skip_all)]
44    pub(super) async fn execute_admin_command(
45        &self,
46        stmt: Admin,
47        query_ctx: QueryContextRef,
48    ) -> Result<Output> {
49        let Admin::Func(func) = &stmt;
50        // the function name should be in lower case.
51        let func_name = func.name.to_string().to_lowercase();
52        let factory = FUNCTION_REGISTRY.get_function(&func_name).context(
53            error::AdminFunctionNotFoundSnafu {
54                name: func_name.clone(),
55            },
56        )?;
57
58        let func_ctx = FunctionContext {
59            query_ctx: query_ctx.clone(),
60            state: self.query_engine.engine_state().function_state(),
61        };
62
63        let admin_udf = factory.provide(func_ctx);
64        let admin_async_fn = admin_udf
65            .as_async()
66            .context(error::AdminFunctionNotFoundSnafu { name: func_name })?;
67
68        let fn_name = admin_udf.name();
69        let signature = admin_udf.signature();
70
71        // Parse function arguments
72        let FunctionArguments::List(args) = &func.args else {
73            return error::BuildAdminFunctionArgsSnafu {
74                msg: format!("unsupported function args {} for {}", func.args, fn_name),
75            }
76            .fail();
77        };
78        let arg_values = args
79            .args
80            .iter()
81            .map(|arg| {
82                let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(value))) = arg else {
83                    return error::BuildAdminFunctionArgsSnafu {
84                        msg: format!("unsupported function arg {arg} for {}", fn_name),
85                    }
86                    .fail();
87                };
88                Ok(&value.value)
89            })
90            .collect::<Result<Vec<_>>>()?;
91
92        let args = args_to_vector(&signature.type_signature, &arg_values, &query_ctx)?;
93        let arg_types = args
94            .iter()
95            .map(|arg| arg.data_type().as_arrow_type())
96            .collect::<Vec<_>>();
97        let ret_type = admin_udf.return_type(&arg_types).map_err(|e| {
98            error::Error::BuildAdminFunctionArgs {
99                msg: format!(
100                    "Failed to get return type of admin function {}: {}",
101                    fn_name, e
102                ),
103            }
104        })?;
105
106        // Convert arguments to DataFusion ColumnarValue format
107        let columnar_args: Vec<datafusion_expr::ColumnarValue> = args
108            .iter()
109            .map(|vector| datafusion_expr::ColumnarValue::Array(vector.to_arrow_array()))
110            .collect();
111
112        // Create ScalarFunctionArgs following the same pattern as udf.rs
113        let func_args = datafusion::logical_expr::ScalarFunctionArgs {
114            args: columnar_args,
115            arg_fields: args
116                .iter()
117                .enumerate()
118                .map(|(i, vector)| {
119                    Arc::new(arrow::datatypes::Field::new(
120                        format!("arg_{}", i),
121                        arg_types[i].clone(),
122                        vector.null_count() > 0,
123                    ))
124                })
125                .collect(),
126            return_field: Arc::new(arrow::datatypes::Field::new("result", ret_type, true)),
127            number_rows: if args.is_empty() { 1 } else { args[0].len() },
128            config_options: Arc::new(query_ctx.create_config_options()),
129        };
130
131        // Execute the async UDF
132        let result_columnar = admin_async_fn
133            .invoke_async_with_args(func_args)
134            .await
135            .with_context(|_| ExecuteAdminFunctionSnafu {
136                msg: format!("Failed to execute admin function {}", fn_name),
137            })?;
138
139        // Convert result back to VectorRef
140        let result_columnar: common_query::prelude::ColumnarValue =
141            (&result_columnar).try_into().context(CastSnafu)?;
142
143        let result_vector: VectorRef = result_columnar.try_into_vector(1).context(CastSnafu)?;
144
145        let column_schemas = vec![ColumnSchema::new(
146            // Use statement as the result column name
147            stmt.to_string(),
148            result_vector.data_type(),
149            false,
150        )];
151        let schema = Arc::new(Schema::new(column_schemas));
152        let batch = RecordBatch::new(schema.clone(), vec![result_vector])
153            .context(error::BuildRecordBatchSnafu)?;
154        let batches =
155            RecordBatches::try_new(schema, vec![batch]).context(error::BuildRecordBatchSnafu)?;
156
157        Ok(Output::new_with_record_batches(batches))
158    }
159}
160
161/// Try to cast the arguments to vectors by function's signature.
162fn args_to_vector(
163    type_signature: &TypeSignature,
164    args: &Vec<&SqlValue>,
165    query_ctx: &QueryContextRef,
166) -> Result<Vec<VectorRef>> {
167    let tz = query_ctx.timezone();
168
169    match type_signature {
170        TypeSignature::Variadic(valid_types) => {
171            values_to_vectors_by_valid_types(valid_types, args, Some(&tz))
172        }
173
174        TypeSignature::Uniform(arity, valid_types) => {
175            ensure!(
176                *arity == args.len(),
177                error::FunctionArityMismatchSnafu {
178                    actual: args.len(),
179                    expected: *arity,
180                }
181            );
182
183            values_to_vectors_by_valid_types(valid_types, args, Some(&tz))
184        }
185
186        TypeSignature::Exact(data_types) => {
187            values_to_vectors_by_exact_types(data_types, args, Some(&tz))
188        }
189
190        TypeSignature::VariadicAny => {
191            let data_types = args
192                .iter()
193                .map(|value| try_get_data_type_for_sql_value(value))
194                .collect::<Result<Vec<_>>>()?;
195
196            values_to_vectors_by_exact_types(&data_types, args, Some(&tz))
197        }
198
199        TypeSignature::Any(arity) => {
200            ensure!(
201                *arity == args.len(),
202                error::FunctionArityMismatchSnafu {
203                    actual: args.len(),
204                    expected: *arity,
205                }
206            );
207
208            let data_types = args
209                .iter()
210                .map(|value| try_get_data_type_for_sql_value(value))
211                .collect::<Result<Vec<_>>>()?;
212
213            values_to_vectors_by_exact_types(&data_types, args, Some(&tz))
214        }
215
216        TypeSignature::OneOf(type_sigs) => {
217            for type_sig in type_sigs {
218                if let Ok(vectors) = args_to_vector(type_sig, args, query_ctx) {
219                    return Ok(vectors);
220                }
221            }
222
223            error::BuildAdminFunctionArgsSnafu {
224                msg: "function signature not match",
225            }
226            .fail()
227        }
228
229        _ => error::BuildAdminFunctionArgsSnafu {
230            msg: format!("unknown function type signature: {type_signature:?}"),
231        }
232        .fail(),
233    }
234}
235
236/// Try to cast sql values to vectors by exact data types.
237fn values_to_vectors_by_exact_types(
238    exact_types: &[ArrowDataType],
239    args: &[&SqlValue],
240    tz: Option<&Timezone>,
241) -> Result<Vec<VectorRef>> {
242    args.iter()
243        .zip(exact_types.iter())
244        .map(|(value, data_type)| {
245            let data_type = &ConcreteDataType::from_arrow_type(data_type);
246            let value = sql_value_to_value(DUMMY_COLUMN, data_type, value, tz, None, false)
247                .context(error::SqlCommonSnafu)?;
248
249            Ok(value_to_vector(value))
250        })
251        .collect()
252}
253
254/// Try to cast sql values to vectors by valid data types.
255fn values_to_vectors_by_valid_types(
256    valid_types: &[ArrowDataType],
257    args: &[&SqlValue],
258    tz: Option<&Timezone>,
259) -> Result<Vec<VectorRef>> {
260    args.iter()
261        .map(|value| {
262            for data_type in valid_types {
263                let data_type = &ConcreteDataType::from_arrow_type(data_type);
264                if let Ok(value) =
265                    sql_value_to_value(DUMMY_COLUMN, data_type, value, tz, None, false)
266                {
267                    return Ok(value_to_vector(value));
268                }
269            }
270
271            error::BuildAdminFunctionArgsSnafu {
272                msg: format!("failed to cast {value}"),
273            }
274            .fail()
275        })
276        .collect::<Result<Vec<_>>>()
277}
278
279/// Build a [`VectorRef`] from [`Value`]
280fn value_to_vector(value: Value) -> VectorRef {
281    let data_type = value.data_type();
282    let mut mutable_vector = data_type.create_mutable_vector(1);
283    mutable_vector.push_value_ref(&value.as_value_ref());
284
285    mutable_vector.to_vector()
286}
287
288/// Try to infer the data type from sql value.
289fn try_get_data_type_for_sql_value(value: &SqlValue) -> Result<ArrowDataType> {
290    match value {
291        SqlValue::Number(_, _) => Ok(ArrowDataType::Float64),
292        SqlValue::Null => Ok(ArrowDataType::Null),
293        SqlValue::Boolean(_) => Ok(ArrowDataType::Boolean),
294        SqlValue::HexStringLiteral(_)
295        | SqlValue::DoubleQuotedString(_)
296        | SqlValue::SingleQuotedString(_) => Ok(ArrowDataType::Utf8),
297        _ => error::BuildAdminFunctionArgsSnafu {
298            msg: format!("unsupported sql value: {value}"),
299        }
300        .fail(),
301    }
302}