1use std::sync::Arc;
16
17use common_function::function::FunctionContext;
18use common_function::function_registry::FUNCTION_REGISTRY;
19use common_query::prelude::TypeSignature;
20use common_query::Output;
21use common_recordbatch::{RecordBatch, RecordBatches};
22use common_sql::convert::sql_value_to_value;
23use common_telemetry::tracing;
24use common_time::Timezone;
25use datatypes::data_type::DataType;
26use datatypes::prelude::ConcreteDataType;
27use datatypes::schema::{ColumnSchema, Schema};
28use datatypes::value::Value;
29use datatypes::vectors::VectorRef;
30use session::context::QueryContextRef;
31use snafu::{ensure, OptionExt, ResultExt};
32use sql::ast::{Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Value as SqlValue};
33use sql::statements::admin::Admin;
34
35use crate::error::{self, ExecuteAdminFunctionSnafu, IntoVectorsSnafu, Result};
36use crate::statement::StatementExecutor;
37
/// Placeholder column name handed to `sql_value_to_value` when converting admin
/// function arguments, since those literals are not tied to any table column.
const DUMMY_COLUMN: &str = "<dummy>";
39
40impl StatementExecutor {
41 #[tracing::instrument(skip_all)]
43 pub(super) async fn execute_admin_command(
44 &self,
45 stmt: Admin,
46 query_ctx: QueryContextRef,
47 ) -> Result<Output> {
48 let Admin::Func(func) = &stmt;
49 let func_name = func.name.to_string().to_lowercase();
51 let factory = FUNCTION_REGISTRY
52 .get_function(&func_name)
53 .context(error::AdminFunctionNotFoundSnafu { name: func_name })?;
54 let func_ctx = FunctionContext {
55 query_ctx: query_ctx.clone(),
56 state: self.query_engine.engine_state().function_state(),
57 };
58
59 let admin_udf = factory.provide(func_ctx);
60 let fn_name = admin_udf.name();
61 let signature = admin_udf.signature();
62
63 let FunctionArguments::List(args) = &func.args else {
65 return error::BuildAdminFunctionArgsSnafu {
66 msg: format!("unsupported function args {} for {}", func.args, fn_name),
67 }
68 .fail();
69 };
70 let arg_values = args
71 .args
72 .iter()
73 .map(|arg| {
74 let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(value))) = arg else {
75 return error::BuildAdminFunctionArgsSnafu {
76 msg: format!("unsupported function arg {arg} for {}", fn_name),
77 }
78 .fail();
79 };
80 Ok(&value.value)
81 })
82 .collect::<Result<Vec<_>>>()?;
83
84 let type_sig = (&signature.type_signature).into();
85 let args = args_to_vector(&type_sig, &arg_values, &query_ctx)?;
86 let arg_types = args
87 .iter()
88 .map(|arg| arg.data_type().as_arrow_type())
89 .collect::<Vec<_>>();
90 let ret_type = admin_udf.return_type(&arg_types).map_err(|e| {
91 error::Error::BuildAdminFunctionArgs {
92 msg: format!(
93 "Failed to get return type of admin function {}: {}",
94 fn_name, e
95 ),
96 }
97 })?;
98
99 let columnar_args: Vec<datafusion_expr::ColumnarValue> = args
101 .iter()
102 .map(|vector| datafusion_expr::ColumnarValue::Array(vector.to_arrow_array()))
103 .collect();
104
105 let func_args = datafusion::logical_expr::ScalarFunctionArgs {
107 args: columnar_args,
108 arg_fields: args
109 .iter()
110 .enumerate()
111 .map(|(i, vector)| {
112 Arc::new(arrow::datatypes::Field::new(
113 format!("arg_{}", i),
114 arg_types[i].clone(),
115 vector.null_count() > 0,
116 ))
117 })
118 .collect(),
119 return_field: Arc::new(arrow::datatypes::Field::new("result", ret_type, true)),
120 number_rows: if args.is_empty() { 1 } else { args[0].len() },
121 config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
122 };
123
124 let result_columnar = admin_udf
126 .as_async()
127 .with_context(|| error::BuildAdminFunctionArgsSnafu {
128 msg: format!("Function {} is not async", fn_name),
129 })?
130 .invoke_async_with_args(func_args)
131 .await
132 .with_context(|_| ExecuteAdminFunctionSnafu {
133 msg: format!("Failed to execute admin function {}", fn_name),
134 })?;
135
136 let result = match result_columnar {
138 datafusion_expr::ColumnarValue::Array(array) => {
139 datatypes::vectors::Helper::try_into_vector(array).context(IntoVectorsSnafu)?
140 }
141 datafusion_expr::ColumnarValue::Scalar(scalar) => {
142 let array =
143 scalar
144 .to_array_of_size(1)
145 .with_context(|_| ExecuteAdminFunctionSnafu {
146 msg: format!("Failed to convert scalar to array for {}", fn_name),
147 })?;
148 datatypes::vectors::Helper::try_into_vector(array).context(IntoVectorsSnafu)?
149 }
150 };
151
152 let result_vector: VectorRef = result;
153 let column_schemas = vec![ColumnSchema::new(
154 stmt.to_string(),
156 result_vector.data_type(),
157 false,
158 )];
159 let schema = Arc::new(Schema::new(column_schemas));
160 let batch = RecordBatch::new(schema.clone(), vec![result_vector])
161 .context(error::BuildRecordBatchSnafu)?;
162 let batches =
163 RecordBatches::try_new(schema, vec![batch]).context(error::BuildRecordBatchSnafu)?;
164
165 Ok(Output::new_with_record_batches(batches))
166 }
167}
168
169fn args_to_vector(
171 type_signature: &TypeSignature,
172 args: &Vec<&SqlValue>,
173 query_ctx: &QueryContextRef,
174) -> Result<Vec<VectorRef>> {
175 let tz = query_ctx.timezone();
176
177 match type_signature {
178 TypeSignature::Variadic(valid_types) => {
179 values_to_vectors_by_valid_types(valid_types, args, Some(&tz))
180 }
181
182 TypeSignature::Uniform(arity, valid_types) => {
183 ensure!(
184 *arity == args.len(),
185 error::FunctionArityMismatchSnafu {
186 actual: args.len(),
187 expected: *arity,
188 }
189 );
190
191 values_to_vectors_by_valid_types(valid_types, args, Some(&tz))
192 }
193
194 TypeSignature::Exact(data_types) => {
195 values_to_vectors_by_exact_types(data_types, args, Some(&tz))
196 }
197
198 TypeSignature::VariadicAny => {
199 let data_types = args
200 .iter()
201 .map(|value| try_get_data_type_for_sql_value(value))
202 .collect::<Result<Vec<_>>>()?;
203
204 values_to_vectors_by_exact_types(&data_types, args, Some(&tz))
205 }
206
207 TypeSignature::Any(arity) => {
208 ensure!(
209 *arity == args.len(),
210 error::FunctionArityMismatchSnafu {
211 actual: args.len(),
212 expected: *arity,
213 }
214 );
215
216 let data_types = args
217 .iter()
218 .map(|value| try_get_data_type_for_sql_value(value))
219 .collect::<Result<Vec<_>>>()?;
220
221 values_to_vectors_by_exact_types(&data_types, args, Some(&tz))
222 }
223
224 TypeSignature::OneOf(type_sigs) => {
225 for type_sig in type_sigs {
226 if let Ok(vectors) = args_to_vector(type_sig, args, query_ctx) {
227 return Ok(vectors);
228 }
229 }
230
231 error::BuildAdminFunctionArgsSnafu {
232 msg: "function signature not match",
233 }
234 .fail()
235 }
236
237 TypeSignature::NullAry => Ok(vec![]),
238 }
239}
240
241fn values_to_vectors_by_exact_types(
243 exact_types: &[ConcreteDataType],
244 args: &[&SqlValue],
245 tz: Option<&Timezone>,
246) -> Result<Vec<VectorRef>> {
247 args.iter()
248 .zip(exact_types.iter())
249 .map(|(value, data_type)| {
250 let value = sql_value_to_value(DUMMY_COLUMN, data_type, value, tz, None, false)
251 .context(error::SqlCommonSnafu)?;
252
253 Ok(value_to_vector(value))
254 })
255 .collect()
256}
257
258fn values_to_vectors_by_valid_types(
260 valid_types: &[ConcreteDataType],
261 args: &[&SqlValue],
262 tz: Option<&Timezone>,
263) -> Result<Vec<VectorRef>> {
264 args.iter()
265 .map(|value| {
266 for data_type in valid_types {
267 if let Ok(value) =
268 sql_value_to_value(DUMMY_COLUMN, data_type, value, tz, None, false)
269 {
270 return Ok(value_to_vector(value));
271 }
272 }
273
274 error::BuildAdminFunctionArgsSnafu {
275 msg: format!("failed to cast {value}"),
276 }
277 .fail()
278 })
279 .collect::<Result<Vec<_>>>()
280}
281
282fn value_to_vector(value: Value) -> VectorRef {
284 let data_type = value.data_type();
285 let mut mutable_vector = data_type.create_mutable_vector(1);
286 mutable_vector.push_value_ref(value.as_value_ref());
287
288 mutable_vector.to_vector()
289}
290
291fn try_get_data_type_for_sql_value(value: &SqlValue) -> Result<ConcreteDataType> {
293 match value {
294 SqlValue::Number(_, _) => Ok(ConcreteDataType::float64_datatype()),
295 SqlValue::Null => Ok(ConcreteDataType::null_datatype()),
296 SqlValue::Boolean(_) => Ok(ConcreteDataType::boolean_datatype()),
297 SqlValue::HexStringLiteral(_)
298 | SqlValue::DoubleQuotedString(_)
299 | SqlValue::SingleQuotedString(_) => Ok(ConcreteDataType::string_datatype()),
300 _ => error::BuildAdminFunctionArgsSnafu {
301 msg: format!("unsupported sql value: {value}"),
302 }
303 .fail(),
304 }
305}