common_function/admin/
reconcile_catalog.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use api::v1::meta::reconcile_request::Target;
16use api::v1::meta::{ReconcileCatalog, ReconcileRequest};
17use arrow::datatypes::DataType as ArrowDataType;
18use common_macro::admin_fn;
19use common_query::error::{
20    InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
21    UnsupportedInputDataTypeSnafu,
22};
23use common_telemetry::info;
24use datafusion_expr::{Signature, TypeSignature, Volatility};
25use datatypes::data_type::DataType;
26use datatypes::prelude::*;
27use session::context::QueryContextRef;
28
29use crate::handlers::ProcedureServiceHandlerRef;
30use crate::helper::{
31    cast_u32, default_parallelism, default_resolve_strategy, get_string_from_params,
32    parse_resolve_strategy,
33};
34
35const FN_NAME: &str = "reconcile_catalog";
36
37/// A function to reconcile a catalog.
38/// Returns the procedure id if success.
39///
40/// - `reconcile_catalog(resolve_strategy)`.
41/// - `reconcile_catalog(resolve_strategy, parallelism)`.
42///
43/// - `reconcile_catalog()`.
44#[admin_fn(
45    name = ReconcileCatalogFunction,
46    display_name = reconcile_catalog,
47    sig_fn = signature,
48    ret = string
49)]
50pub(crate) async fn reconcile_catalog(
51    procedure_service_handler: &ProcedureServiceHandlerRef,
52    query_ctx: &QueryContextRef,
53    params: &[ValueRef<'_>],
54) -> Result<Value> {
55    let (resolve_strategy, parallelism) = match params.len() {
56        0 => (default_resolve_strategy(), default_parallelism()),
57        1 => (
58            parse_resolve_strategy(get_string_from_params(params, 0, FN_NAME)?)?,
59            default_parallelism(),
60        ),
61        2 => {
62            let Some(parallelism) = cast_u32(&params[1])? else {
63                return UnsupportedInputDataTypeSnafu {
64                    function: FN_NAME,
65                    datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
66                }
67                .fail();
68            };
69            (
70                parse_resolve_strategy(get_string_from_params(params, 0, FN_NAME)?)?,
71                parallelism,
72            )
73        }
74        size => {
75            return InvalidFuncArgsSnafu {
76                err_msg: format!(
77                    "The length of the args is not correct, expect 0, 1 or 2, have: {}",
78                    size
79                ),
80            }
81            .fail();
82        }
83    };
84    info!(
85        "Reconciling catalog with resolve_strategy: {:?}, parallelism: {}",
86        resolve_strategy, parallelism
87    );
88    let pid = procedure_service_handler
89        .reconcile(ReconcileRequest {
90            target: Some(Target::ReconcileCatalog(ReconcileCatalog {
91                catalog_name: query_ctx.current_catalog().to_string(),
92                parallelism,
93                resolve_strategy: resolve_strategy as i32,
94            })),
95            ..Default::default()
96        })
97        .await?;
98    match pid {
99        Some(pid) => Ok(Value::from(pid)),
100        None => Ok(Value::Null),
101    }
102}
103
104fn signature() -> Signature {
105    let nums = ConcreteDataType::numerics();
106    let mut signs = Vec::with_capacity(2 + nums.len());
107    signs.extend([
108        // reconcile_catalog()
109        TypeSignature::Nullary,
110        // reconcile_catalog(resolve_strategy)
111        TypeSignature::Exact(vec![ArrowDataType::Utf8]),
112    ]);
113    for sign in nums {
114        // reconcile_catalog(resolve_strategy, parallelism)
115        signs.push(TypeSignature::Exact(vec![
116            ArrowDataType::Utf8,
117            sign.as_arrow_type(),
118        ]));
119    }
120    Signature::one_of(signs, Volatility::Immutable)
121}
122
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{StringArray, UInt64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::ColumnarValue;

    use crate::admin::reconcile_catalog::ReconcileCatalogFunction;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// Wraps a single string into a one-row `Utf8` array argument.
    fn utf8_arg(value: &str) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringArray::from(vec![value])))
    }

    /// Wraps a single integer into a one-row `UInt64` array argument.
    fn u64_arg(value: u64) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(UInt64Array::from(vec![value])))
    }

    /// Assembles one-row function arguments; argument fields are named
    /// `arg_0`, `arg_1`, ... and are non-nullable, the return field is a
    /// nullable `Utf8` named `result`.
    fn build_args(
        args: Vec<ColumnarValue>,
        arg_types: &[DataType],
    ) -> datafusion::logical_expr::ScalarFunctionArgs {
        let arg_fields: Vec<_> = arg_types
            .iter()
            .enumerate()
            .map(|(idx, ty)| Arc::new(Field::new(format!("arg_{}", idx), ty.clone(), false)))
            .collect();

        datafusion::logical_expr::ScalarFunctionArgs {
            args,
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Asserts that the function result carries the procedure id `"test_pid"`,
    /// whether it comes back as an array or as a scalar.
    fn assert_returns_test_pid(result: ColumnarValue) {
        match result {
            ColumnarValue::Array(array) => {
                let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                assert_eq!(result_array.value(0), "test_pid");
            }
            ColumnarValue::Scalar(scalar) => {
                assert_eq!(
                    scalar,
                    datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
                );
            }
        }
    }

    #[tokio::test]
    async fn test_reconcile_catalog() {
        common_telemetry::init_default_ut_logging();

        let accepted: Vec<(Vec<ColumnarValue>, Vec<DataType>)> = vec![
            // reconcile_catalog()
            (vec![], vec![]),
            // reconcile_catalog(resolve_strategy)
            (vec![utf8_arg("UseMetasrv")], vec![DataType::Utf8]),
            // reconcile_catalog(resolve_strategy, parallelism)
            (
                vec![utf8_arg("UseLatest"), u64_arg(10)],
                vec![DataType::Utf8, DataType::UInt64],
            ),
        ];
        for (args, arg_types) in accepted {
            // A fresh function instance is built for every case.
            let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
            let provider = factory.provide(FunctionContext::mock());
            let f = provider.as_async().unwrap();

            let result = f
                .invoke_async_with_args(build_args(args, &arg_types))
                .await
                .unwrap();
            assert_returns_test_pid(result);
        }

        let rejected: Vec<(Vec<ColumnarValue>, Vec<DataType>)> = vec![
            // unsupported input data type
            (
                vec![utf8_arg("UseLatest"), utf8_arg("test")],
                vec![DataType::Utf8, DataType::Utf8],
            ),
            // invalid function args
            (
                vec![utf8_arg("UseLatest"), u64_arg(10), utf8_arg("10")],
                vec![DataType::Utf8, DataType::UInt64, DataType::Utf8],
            ),
        ];
        for (args, arg_types) in rejected {
            let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
            let provider = factory.provide(FunctionContext::mock());
            let f = provider.as_async().unwrap();

            // Note: Error type is DataFusionError at this level, not common_query::Error
            let _err = f
                .invoke_async_with_args(build_args(args, &arg_types))
                .await
                .unwrap_err();
        }
    }
}