common_function/admin/
remove_region_follower.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use common_macro::admin_fn;
16use common_meta::rpc::procedure::RemoveRegionFollowerRequest;
17use common_query::error::{
18    InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
19    UnsupportedInputDataTypeSnafu,
20};
21use datafusion_expr::{Signature, TypeSignature, Volatility};
22use datatypes::data_type::DataType;
23use datatypes::prelude::ConcreteDataType;
24use datatypes::value::{Value, ValueRef};
25use session::context::QueryContextRef;
26use snafu::ensure;
27
28use crate::handlers::ProcedureServiceHandlerRef;
29use crate::helper::cast_u64;
30
31/// A function to remove a follower from a region.
32//// Only available in cluster mode.
33///
34/// - `remove_region_follower(region_id, peer_id)`.
35///
36/// The parameters:
37/// - `region_id`:  the region id
38/// - `peer_id`:  the peer id
39#[admin_fn(
40    name = RemoveRegionFollowerFunction,
41    display_name = remove_region_follower,
42    sig_fn = signature,
43    ret = uint64
44)]
45pub(crate) async fn remove_region_follower(
46    procedure_service_handler: &ProcedureServiceHandlerRef,
47    _ctx: &QueryContextRef,
48    params: &[ValueRef<'_>],
49) -> Result<Value> {
50    ensure!(
51        params.len() == 2,
52        InvalidFuncArgsSnafu {
53            err_msg: format!(
54                "The length of the args is not correct, expect exactly 2, have: {}",
55                params.len()
56            ),
57        }
58    );
59
60    let Some(region_id) = cast_u64(&params[0])? else {
61        return UnsupportedInputDataTypeSnafu {
62            function: "add_region_follower",
63            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
64        }
65        .fail();
66    };
67    let Some(peer_id) = cast_u64(&params[1])? else {
68        return UnsupportedInputDataTypeSnafu {
69            function: "add_region_follower",
70            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
71        }
72        .fail();
73    };
74
75    procedure_service_handler
76        .remove_region_follower(RemoveRegionFollowerRequest { region_id, peer_id })
77        .await?;
78
79    Ok(Value::from(0u64))
80}
81
82fn signature() -> Signature {
83    Signature::one_of(
84        vec![
85            // remove_region_follower(region_id, peer_id)
86            TypeSignature::Uniform(
87                2,
88                ConcreteDataType::numerics()
89                    .into_iter()
90                    .map(|dt| dt.as_arrow_type())
91                    .collect(),
92            ),
93        ],
94        Volatility::Immutable,
95    )
96}
97
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::UInt64Array;
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::ColumnarValue;

    use super::*;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// Checks the registered name, the return type and the shape of the signature.
    #[test]
    fn test_remove_region_follower_misc() {
        let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
        let function = factory.provide(FunctionContext::mock());

        assert_eq!("remove_region_follower", function.name());
        assert_eq!(DataType::UInt64, function.return_type(&[]).unwrap());

        // The signature must be an immutable `OneOf` with exactly one variant.
        let is_single_immutable_one_of = matches!(
            function.signature(),
            datafusion_expr::Signature {
                type_signature: datafusion_expr::TypeSignature::OneOf(variants),
                volatility: datafusion_expr::Volatility::Immutable
            } if variants.len() == 1
        );
        assert!(is_single_immutable_one_of);
    }

    /// Invokes the function against the mock context and expects `0` back.
    #[tokio::test]
    async fn test_remove_region_follower() {
        let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
        let provider = factory.provide(FunctionContext::mock());
        let async_function = provider.as_async().unwrap();

        // Both arguments are single-row UInt64 columns holding the value 1.
        let single_row_column = || ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1])));
        let non_null_field = |name: &str| Arc::new(Field::new(name, DataType::UInt64, false));

        let invocation = datafusion::logical_expr::ScalarFunctionArgs {
            args: vec![single_row_column(), single_row_column()],
            arg_fields: vec![non_null_field("arg_0"), non_null_field("arg_1")],
            return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        };

        let outcome = async_function
            .invoke_async_with_args(invocation)
            .await
            .unwrap();

        // The result may come back either as an array or as a scalar.
        match outcome {
            ColumnarValue::Scalar(value) => {
                assert_eq!(value, datafusion_common::ScalarValue::UInt64(Some(0)));
            }
            ColumnarValue::Array(column) => {
                let values = column.as_any().downcast_ref::<UInt64Array>().unwrap();
                assert_eq!(values.value(0), 0u64);
            }
        }
    }
}