common_function/admin/
reconcile_database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use api::v1::meta::reconcile_request::Target;
16use api::v1::meta::{ReconcileDatabase, ReconcileRequest};
17use arrow::datatypes::DataType as ArrowDataType;
18use common_macro::admin_fn;
19use common_query::error::{
20    InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
21    UnsupportedInputDataTypeSnafu,
22};
23use common_telemetry::info;
24use datafusion_expr::{Signature, TypeSignature, Volatility};
25use datatypes::data_type::DataType;
26use datatypes::prelude::*;
27use session::context::QueryContextRef;
28
29use crate::handlers::ProcedureServiceHandlerRef;
30use crate::helper::{
31    cast_u32, default_parallelism, default_resolve_strategy, get_string_from_params,
32    parse_resolve_strategy,
33};
34
35const FN_NAME: &str = "reconcile_database";
36
37/// A function to reconcile a database.
38/// Returns the procedure id if success.
39///
40/// - `reconcile_database(database_name)`.
41/// - `reconcile_database(database_name, resolve_strategy)`.
42/// - `reconcile_database(database_name, resolve_strategy, parallelism)`.
43///
44/// The parameters:
45/// - `database_name`:  the database name
46#[admin_fn(
47    name = ReconcileDatabaseFunction,
48    display_name = reconcile_database,
49    sig_fn = signature,
50    ret = string
51)]
52pub(crate) async fn reconcile_database(
53    procedure_service_handler: &ProcedureServiceHandlerRef,
54    query_ctx: &QueryContextRef,
55    params: &[ValueRef<'_>],
56) -> Result<Value> {
57    let (database_name, resolve_strategy, parallelism) = match params.len() {
58        1 => (
59            get_string_from_params(params, 0, FN_NAME)?,
60            default_resolve_strategy(),
61            default_parallelism(),
62        ),
63        2 => (
64            get_string_from_params(params, 0, FN_NAME)?,
65            parse_resolve_strategy(get_string_from_params(params, 1, FN_NAME)?)?,
66            default_parallelism(),
67        ),
68        3 => {
69            let Some(parallelism) = cast_u32(&params[2])? else {
70                return UnsupportedInputDataTypeSnafu {
71                    function: FN_NAME,
72                    datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
73                }
74                .fail();
75            };
76            (
77                get_string_from_params(params, 0, FN_NAME)?,
78                parse_resolve_strategy(get_string_from_params(params, 1, FN_NAME)?)?,
79                parallelism,
80            )
81        }
82        size => {
83            return InvalidFuncArgsSnafu {
84                err_msg: format!(
85                    "The length of the args is not correct, expect 1, 2 or 3, have: {}",
86                    size
87                ),
88            }
89            .fail();
90        }
91    };
92    info!(
93        "Reconciling database: {}, resolve_strategy: {:?}, parallelism: {}",
94        database_name, resolve_strategy, parallelism
95    );
96    let pid = procedure_service_handler
97        .reconcile(ReconcileRequest {
98            target: Some(Target::ReconcileDatabase(ReconcileDatabase {
99                catalog_name: query_ctx.current_catalog().to_string(),
100                database_name: database_name.to_string(),
101                parallelism,
102                resolve_strategy: resolve_strategy as i32,
103            })),
104            ..Default::default()
105        })
106        .await?;
107    match pid {
108        Some(pid) => Ok(Value::from(pid)),
109        None => Ok(Value::Null),
110    }
111}
112
113fn signature() -> Signature {
114    let nums = ConcreteDataType::numerics();
115    let mut signs = Vec::with_capacity(2 + nums.len());
116    signs.extend([
117        // reconcile_database(datanode_name)
118        TypeSignature::Exact(vec![ArrowDataType::Utf8]),
119        // reconcile_database(database_name, resolve_strategy)
120        TypeSignature::Exact(vec![ArrowDataType::Utf8, ArrowDataType::Utf8]),
121    ]);
122    for sign in nums {
123        // reconcile_database(database_name, resolve_strategy, parallelism)
124        signs.push(TypeSignature::Exact(vec![
125            ArrowDataType::Utf8,
126            ArrowDataType::Utf8,
127            sign.as_arrow_type(),
128        ]));
129    }
130    Signature::one_of(signs, Volatility::Immutable)
131}
132
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringArray, UInt32Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion::logical_expr::ScalarFunctionArgs;
    use datafusion_expr::ColumnarValue;

    use crate::admin::reconcile_database::ReconcileDatabaseFunction;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// Wraps `value` as a one-row `Utf8` argument.
    fn utf8_arg(value: &str) -> (ArrayRef, DataType) {
        let array: ArrayRef = Arc::new(StringArray::from(vec![value]));
        (array, DataType::Utf8)
    }

    /// Wraps `value` as a one-row `UInt32` argument.
    fn u32_arg(value: u32) -> (ArrayRef, DataType) {
        let array: ArrayRef = Arc::new(UInt32Array::from(vec![value]));
        (array, DataType::UInt32)
    }

    /// Builds one-row call arguments. Argument fields are named `arg_0`, `arg_1`, ...
    /// and the declared result is a nullable `Utf8` column.
    fn single_row_args(columns: Vec<(ArrayRef, DataType)>) -> ScalarFunctionArgs {
        let (args, arg_fields): (Vec<_>, Vec<_>) = columns
            .into_iter()
            .enumerate()
            .map(|(i, (array, data_type))| {
                (
                    ColumnarValue::Array(array),
                    Arc::new(Field::new(format!("arg_{i}"), data_type, false)),
                )
            })
            .unzip();
        ScalarFunctionArgs {
            args,
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Asserts that the call returned the mock procedure id `test_pid`, whether it
    /// came back as an array or as a scalar.
    fn assert_test_pid(result: ColumnarValue) {
        match result {
            ColumnarValue::Array(array) => {
                let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                assert_eq!(result_array.value(0), "test_pid");
            }
            ColumnarValue::Scalar(scalar) => {
                assert_eq!(
                    scalar,
                    datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
                );
            }
        }
    }

    #[tokio::test]
    async fn test_reconcile_database() {
        common_telemetry::init_default_ut_logging();

        let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
        let provider = factory.provide(FunctionContext::mock());
        let f = provider.as_async().unwrap();

        // reconcile_database(database_name)
        let result = f
            .invoke_async_with_args(single_row_args(vec![utf8_arg("test")]))
            .await
            .unwrap();
        assert_test_pid(result);

        // reconcile_database(database_name, resolve_strategy)
        let result = f
            .invoke_async_with_args(single_row_args(vec![
                utf8_arg("test"),
                utf8_arg("UseLatest"),
            ]))
            .await
            .unwrap();
        assert_test_pid(result);

        // reconcile_database(database_name, resolve_strategy, parallelism)
        let result = f
            .invoke_async_with_args(single_row_args(vec![
                utf8_arg("test"),
                utf8_arg("UseLatest"),
                u32_arg(10),
            ]))
            .await
            .unwrap();
        assert_test_pid(result);

        // Invalid function args: four arguments exceed the accepted arity.
        let _err = f
            .invoke_async_with_args(single_row_args(vec![
                utf8_arg("UseLatest"),
                u32_arg(10),
                utf8_arg("v1"),
                utf8_arg("v2"),
            ]))
            .await
            .unwrap_err();
        // Note: Error type is DataFusionError at this level, not common_query::Error

        // Unsupported input data type: the third (`parallelism`) argument is a string.
        let _err = f
            .invoke_async_with_args(single_row_args(vec![
                utf8_arg("UseLatest"),
                u32_arg(10),
                utf8_arg("v1"),
            ]))
            .await
            .unwrap_err();
        // Note: Error type is DataFusionError at this level, not common_query::Error
    }
}