// common_function/admin/migrate_region.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::time::Duration;
16
17use common_macro::admin_fn;
18use common_meta::rpc::procedure::MigrateRegionRequest;
19use common_query::error::{InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result};
20use datafusion_expr::{Signature, TypeSignature, Volatility};
21use datatypes::data_type::DataType;
22use datatypes::prelude::ConcreteDataType;
23use datatypes::value::{Value, ValueRef};
24use session::context::QueryContextRef;
25
26use crate::handlers::ProcedureServiceHandlerRef;
27use crate::helper::cast_u64;
28
/// Timeout, in seconds, applied to the migrate region procedure when the
/// caller does not pass one explicitly (five minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
31
32/// A function to migrate a region from source peer to target peer.
33/// Returns the submitted procedure id if success. Only available in cluster mode.
34///
35/// - `migrate_region(region_id, from_peer, to_peer)`, with timeout(300 seconds).
36/// - `migrate_region(region_id, from_peer, to_peer, timeout(secs))`.
37///
38/// The parameters:
39/// - `region_id`:  the region id
40/// - `from_peer`:  the source peer id
41/// - `to_peer`:  the target peer id
42#[admin_fn(
43    name = MigrateRegionFunction,
44    display_name = migrate_region,
45    sig_fn = signature,
46    ret = string
47)]
48pub(crate) async fn migrate_region(
49    procedure_service_handler: &ProcedureServiceHandlerRef,
50    _ctx: &QueryContextRef,
51    params: &[ValueRef<'_>],
52) -> Result<Value> {
53    let (region_id, from_peer, to_peer, timeout) = match params.len() {
54        3 => {
55            let region_id = cast_u64(&params[0])?;
56            let from_peer = cast_u64(&params[1])?;
57            let to_peer = cast_u64(&params[2])?;
58
59            (region_id, from_peer, to_peer, Some(DEFAULT_TIMEOUT_SECS))
60        }
61
62        4 => {
63            let region_id = cast_u64(&params[0])?;
64            let from_peer = cast_u64(&params[1])?;
65            let to_peer = cast_u64(&params[2])?;
66            let replay_timeout = cast_u64(&params[3])?;
67
68            (region_id, from_peer, to_peer, replay_timeout)
69        }
70
71        size => {
72            return InvalidFuncArgsSnafu {
73                err_msg: format!(
74                    "The length of the args is not correct, expect exactly 3 or 4, have: {}",
75                    size
76                ),
77            }
78            .fail();
79        }
80    };
81
82    match (region_id, from_peer, to_peer, timeout) {
83        (Some(region_id), Some(from_peer), Some(to_peer), Some(timeout)) => {
84            let pid = procedure_service_handler
85                .migrate_region(MigrateRegionRequest {
86                    region_id,
87                    from_peer,
88                    to_peer,
89                    timeout: Duration::from_secs(timeout),
90                })
91                .await?;
92
93            match pid {
94                Some(pid) => Ok(Value::from(pid)),
95                None => Ok(Value::Null),
96            }
97        }
98
99        _ => Ok(Value::Null),
100    }
101}
102
103fn signature() -> Signature {
104    Signature::one_of(
105        vec![
106            // migrate_region(region_id, from_peer, to_peer)
107            TypeSignature::Uniform(
108                3,
109                ConcreteDataType::numerics()
110                    .into_iter()
111                    .map(|dt| dt.as_arrow_type())
112                    .collect(),
113            ),
114            // migrate_region(region_id, from_peer, to_peer, timeout(secs))
115            TypeSignature::Uniform(
116                4,
117                ConcreteDataType::numerics()
118                    .into_iter()
119                    .map(|dt| dt.as_arrow_type())
120                    .collect(),
121            ),
122        ],
123        Volatility::Immutable,
124    )
125}
126
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{StringArray, UInt64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::ColumnarValue;

    use super::*;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// Builds call arguments made of `n` single-row `UInt64` columns, each holding `1`.
    fn u64_args(n: usize) -> datafusion::logical_expr::ScalarFunctionArgs {
        datafusion::logical_expr::ScalarFunctionArgs {
            args: (0..n)
                .map(|_| ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))))
                .collect(),
            arg_fields: (0..n)
                .map(|i| Arc::new(Field::new(format!("arg_{i}"), DataType::UInt64, false)))
                .collect(),
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Asserts that the function output is the mock procedure id `"test_pid"`,
    /// whether it comes back as an array or as a scalar.
    fn assert_test_pid(result: ColumnarValue) {
        match result {
            ColumnarValue::Array(array) => {
                let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
                assert_eq!(result_array.value(0), "test_pid");
            }
            ColumnarValue::Scalar(scalar) => {
                assert_eq!(
                    scalar,
                    datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
                );
            }
        }
    }

    #[test]
    fn test_migrate_region_misc() {
        let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
        let f = factory.provide(FunctionContext::mock());
        assert_eq!("migrate_region", f.name());
        assert_eq!(DataType::Utf8, f.return_type(&[]).unwrap());
        assert!(matches!(f.signature(),
                         datafusion_expr::Signature {
                             type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
                             volatility: datafusion_expr::Volatility::Immutable
                         } if sigs.len() == 2));
    }

    #[tokio::test]
    async fn test_missing_procedure_service() {
        let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
        let provider = factory.provide(FunctionContext::default());
        let f = provider.as_async().unwrap();

        let result = f.invoke_async_with_args(u64_args(3)).await.unwrap_err();
        assert_eq!(
            "Execution error: Handler error: Missing ProcedureServiceHandler, not expected",
            result.to_string()
        );
    }

    #[tokio::test]
    async fn test_migrate_region() {
        let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
        let provider = factory.provide(FunctionContext::mock());
        let f = provider.as_async().unwrap();

        let result = f.invoke_async_with_args(u64_args(3)).await.unwrap();
        assert_test_pid(result);
    }

    /// Covers the 4-argument form with an explicit timeout.
    #[tokio::test]
    async fn test_migrate_region_with_timeout() {
        let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
        let provider = factory.provide(FunctionContext::mock());
        let f = provider.as_async().unwrap();

        let result = f.invoke_async_with_args(u64_args(4)).await.unwrap();
        assert_test_pid(result);
    }
}