operator/statement/
admin.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_function::function::FunctionContext;
18use common_function::function_registry::FUNCTION_REGISTRY;
19use common_query::Output;
20use common_recordbatch::{RecordBatch, RecordBatches};
21use common_sql::convert::sql_value_to_value;
22use common_telemetry::tracing;
23use common_time::Timezone;
24use datafusion_expr::TypeSignature;
25use datatypes::arrow::datatypes::DataType as ArrowDataType;
26use datatypes::data_type::DataType;
27use datatypes::prelude::ConcreteDataType;
28use datatypes::schema::{ColumnSchema, Schema};
29use datatypes::value::Value;
30use datatypes::vectors::VectorRef;
31use session::context::QueryContextRef;
32use snafu::{OptionExt, ResultExt, ensure};
33use sql::ast::{Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Value as SqlValue};
34use sql::statements::admin::Admin;
35
36use crate::error::{self, CastSnafu, ExecuteAdminFunctionSnafu, Result};
37use crate::statement::StatementExecutor;
38
/// Placeholder passed as the first argument of `sql_value_to_value` when
/// converting admin function arguments.
// NOTE(review): presumably this is the column name used in conversion error
// messages, since admin function arguments belong to no real column — confirm
// against `common_sql::convert::sql_value_to_value`.
const DUMMY_COLUMN: &str = "<dummy>";
40
41impl StatementExecutor {
42    /// Execute the [`Admin`] statement and returns the output.
43    #[tracing::instrument(skip_all)]
44    pub(super) async fn execute_admin_command(
45        &self,
46        stmt: Admin,
47        query_ctx: QueryContextRef,
48    ) -> Result<Output> {
49        let Admin::Func(func) = &stmt;
50        // the function name should be in lower case.
51        let func_name = func.name.to_string().to_lowercase();
52        let factory = FUNCTION_REGISTRY
53            .get_function(&func_name)
54            .context(error::AdminFunctionNotFoundSnafu { name: func_name })?;
55        let func_ctx = FunctionContext {
56            query_ctx: query_ctx.clone(),
57            state: self.query_engine.engine_state().function_state(),
58        };
59
60        let admin_udf = factory.provide(func_ctx);
61        let fn_name = admin_udf.name();
62        let signature = admin_udf.signature();
63
64        // Parse function arguments
65        let FunctionArguments::List(args) = &func.args else {
66            return error::BuildAdminFunctionArgsSnafu {
67                msg: format!("unsupported function args {} for {}", func.args, fn_name),
68            }
69            .fail();
70        };
71        let arg_values = args
72            .args
73            .iter()
74            .map(|arg| {
75                let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(value))) = arg else {
76                    return error::BuildAdminFunctionArgsSnafu {
77                        msg: format!("unsupported function arg {arg} for {}", fn_name),
78                    }
79                    .fail();
80                };
81                Ok(&value.value)
82            })
83            .collect::<Result<Vec<_>>>()?;
84
85        let args = args_to_vector(&signature.type_signature, &arg_values, &query_ctx)?;
86        let arg_types = args
87            .iter()
88            .map(|arg| arg.data_type().as_arrow_type())
89            .collect::<Vec<_>>();
90        let ret_type = admin_udf.return_type(&arg_types).map_err(|e| {
91            error::Error::BuildAdminFunctionArgs {
92                msg: format!(
93                    "Failed to get return type of admin function {}: {}",
94                    fn_name, e
95                ),
96            }
97        })?;
98
99        // Convert arguments to DataFusion ColumnarValue format
100        let columnar_args: Vec<datafusion_expr::ColumnarValue> = args
101            .iter()
102            .map(|vector| datafusion_expr::ColumnarValue::Array(vector.to_arrow_array()))
103            .collect();
104
105        // Create ScalarFunctionArgs following the same pattern as udf.rs
106        let func_args = datafusion::logical_expr::ScalarFunctionArgs {
107            args: columnar_args,
108            arg_fields: args
109                .iter()
110                .enumerate()
111                .map(|(i, vector)| {
112                    Arc::new(arrow::datatypes::Field::new(
113                        format!("arg_{}", i),
114                        arg_types[i].clone(),
115                        vector.null_count() > 0,
116                    ))
117                })
118                .collect(),
119            return_field: Arc::new(arrow::datatypes::Field::new("result", ret_type, true)),
120            number_rows: if args.is_empty() { 1 } else { args[0].len() },
121            config_options: Arc::new(query_ctx.create_config_options()),
122        };
123
124        // Execute the async UDF
125        let result_columnar = admin_udf
126            .as_async()
127            .with_context(|| error::BuildAdminFunctionArgsSnafu {
128                msg: format!("Function {} is not async", fn_name),
129            })?
130            .invoke_async_with_args(func_args)
131            .await
132            .with_context(|_| ExecuteAdminFunctionSnafu {
133                msg: format!("Failed to execute admin function {}", fn_name),
134            })?;
135
136        // Convert result back to VectorRef
137        let result_columnar: common_query::prelude::ColumnarValue =
138            (&result_columnar).try_into().context(CastSnafu)?;
139
140        let result_vector: VectorRef = result_columnar.try_into_vector(1).context(CastSnafu)?;
141
142        let column_schemas = vec![ColumnSchema::new(
143            // Use statement as the result column name
144            stmt.to_string(),
145            result_vector.data_type(),
146            false,
147        )];
148        let schema = Arc::new(Schema::new(column_schemas));
149        let batch = RecordBatch::new(schema.clone(), vec![result_vector])
150            .context(error::BuildRecordBatchSnafu)?;
151        let batches =
152            RecordBatches::try_new(schema, vec![batch]).context(error::BuildRecordBatchSnafu)?;
153
154        Ok(Output::new_with_record_batches(batches))
155    }
156}
157
158/// Try to cast the arguments to vectors by function's signature.
159fn args_to_vector(
160    type_signature: &TypeSignature,
161    args: &Vec<&SqlValue>,
162    query_ctx: &QueryContextRef,
163) -> Result<Vec<VectorRef>> {
164    let tz = query_ctx.timezone();
165
166    match type_signature {
167        TypeSignature::Variadic(valid_types) => {
168            values_to_vectors_by_valid_types(valid_types, args, Some(&tz))
169        }
170
171        TypeSignature::Uniform(arity, valid_types) => {
172            ensure!(
173                *arity == args.len(),
174                error::FunctionArityMismatchSnafu {
175                    actual: args.len(),
176                    expected: *arity,
177                }
178            );
179
180            values_to_vectors_by_valid_types(valid_types, args, Some(&tz))
181        }
182
183        TypeSignature::Exact(data_types) => {
184            values_to_vectors_by_exact_types(data_types, args, Some(&tz))
185        }
186
187        TypeSignature::VariadicAny => {
188            let data_types = args
189                .iter()
190                .map(|value| try_get_data_type_for_sql_value(value))
191                .collect::<Result<Vec<_>>>()?;
192
193            values_to_vectors_by_exact_types(&data_types, args, Some(&tz))
194        }
195
196        TypeSignature::Any(arity) => {
197            ensure!(
198                *arity == args.len(),
199                error::FunctionArityMismatchSnafu {
200                    actual: args.len(),
201                    expected: *arity,
202                }
203            );
204
205            let data_types = args
206                .iter()
207                .map(|value| try_get_data_type_for_sql_value(value))
208                .collect::<Result<Vec<_>>>()?;
209
210            values_to_vectors_by_exact_types(&data_types, args, Some(&tz))
211        }
212
213        TypeSignature::OneOf(type_sigs) => {
214            for type_sig in type_sigs {
215                if let Ok(vectors) = args_to_vector(type_sig, args, query_ctx) {
216                    return Ok(vectors);
217                }
218            }
219
220            error::BuildAdminFunctionArgsSnafu {
221                msg: "function signature not match",
222            }
223            .fail()
224        }
225
226        _ => error::BuildAdminFunctionArgsSnafu {
227            msg: format!("unknown function type signature: {type_signature:?}"),
228        }
229        .fail(),
230    }
231}
232
233/// Try to cast sql values to vectors by exact data types.
234fn values_to_vectors_by_exact_types(
235    exact_types: &[ArrowDataType],
236    args: &[&SqlValue],
237    tz: Option<&Timezone>,
238) -> Result<Vec<VectorRef>> {
239    args.iter()
240        .zip(exact_types.iter())
241        .map(|(value, data_type)| {
242            let data_type = &ConcreteDataType::from_arrow_type(data_type);
243            let value = sql_value_to_value(DUMMY_COLUMN, data_type, value, tz, None, false)
244                .context(error::SqlCommonSnafu)?;
245
246            Ok(value_to_vector(value))
247        })
248        .collect()
249}
250
251/// Try to cast sql values to vectors by valid data types.
252fn values_to_vectors_by_valid_types(
253    valid_types: &[ArrowDataType],
254    args: &[&SqlValue],
255    tz: Option<&Timezone>,
256) -> Result<Vec<VectorRef>> {
257    args.iter()
258        .map(|value| {
259            for data_type in valid_types {
260                let data_type = &ConcreteDataType::from_arrow_type(data_type);
261                if let Ok(value) =
262                    sql_value_to_value(DUMMY_COLUMN, data_type, value, tz, None, false)
263                {
264                    return Ok(value_to_vector(value));
265                }
266            }
267
268            error::BuildAdminFunctionArgsSnafu {
269                msg: format!("failed to cast {value}"),
270            }
271            .fail()
272        })
273        .collect::<Result<Vec<_>>>()
274}
275
276/// Build a [`VectorRef`] from [`Value`]
277fn value_to_vector(value: Value) -> VectorRef {
278    let data_type = value.data_type();
279    let mut mutable_vector = data_type.create_mutable_vector(1);
280    mutable_vector.push_value_ref(value.as_value_ref());
281
282    mutable_vector.to_vector()
283}
284
285/// Try to infer the data type from sql value.
286fn try_get_data_type_for_sql_value(value: &SqlValue) -> Result<ArrowDataType> {
287    match value {
288        SqlValue::Number(_, _) => Ok(ArrowDataType::Float64),
289        SqlValue::Null => Ok(ArrowDataType::Null),
290        SqlValue::Boolean(_) => Ok(ArrowDataType::Boolean),
291        SqlValue::HexStringLiteral(_)
292        | SqlValue::DoubleQuotedString(_)
293        | SqlValue::SingleQuotedString(_) => Ok(ArrowDataType::Utf8),
294        _ => error::BuildAdminFunctionArgsSnafu {
295            msg: format!("unsupported sql value: {value}"),
296        }
297        .fail(),
298    }
299}