1use std::time::Duration;
16
17use common_macro::admin_fn;
18use common_meta::rpc::procedure::MigrateRegionRequest;
19use common_query::error::{InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result};
20use datafusion_expr::{Signature, TypeSignature, Volatility};
21use datatypes::data_type::DataType;
22use datatypes::prelude::ConcreteDataType;
23use datatypes::value::{Value, ValueRef};
24use session::context::QueryContextRef;
25
26use crate::handlers::ProcedureServiceHandlerRef;
27use crate::helper::cast_u64;
28
/// Timeout, in seconds, used when `migrate_region` is called with only three
/// arguments (i.e. without an explicit timeout). 5 minutes.
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
31
32#[admin_fn(
43 name = MigrateRegionFunction,
44 display_name = migrate_region,
45 sig_fn = signature,
46 ret = string
47)]
48pub(crate) async fn migrate_region(
49 procedure_service_handler: &ProcedureServiceHandlerRef,
50 _ctx: &QueryContextRef,
51 params: &[ValueRef<'_>],
52) -> Result<Value> {
53 let (region_id, from_peer, to_peer, timeout) = match params.len() {
54 3 => {
55 let region_id = cast_u64(¶ms[0])?;
56 let from_peer = cast_u64(¶ms[1])?;
57 let to_peer = cast_u64(¶ms[2])?;
58
59 (region_id, from_peer, to_peer, Some(DEFAULT_TIMEOUT_SECS))
60 }
61
62 4 => {
63 let region_id = cast_u64(¶ms[0])?;
64 let from_peer = cast_u64(¶ms[1])?;
65 let to_peer = cast_u64(¶ms[2])?;
66 let replay_timeout = cast_u64(¶ms[3])?;
67
68 (region_id, from_peer, to_peer, replay_timeout)
69 }
70
71 size => {
72 return InvalidFuncArgsSnafu {
73 err_msg: format!(
74 "The length of the args is not correct, expect exactly 3 or 4, have: {}",
75 size
76 ),
77 }
78 .fail();
79 }
80 };
81
82 match (region_id, from_peer, to_peer, timeout) {
83 (Some(region_id), Some(from_peer), Some(to_peer), Some(timeout)) => {
84 let pid = procedure_service_handler
85 .migrate_region(MigrateRegionRequest {
86 region_id,
87 from_peer,
88 to_peer,
89 timeout: Duration::from_secs(timeout),
90 })
91 .await?;
92
93 match pid {
94 Some(pid) => Ok(Value::from(pid)),
95 None => Ok(Value::Null),
96 }
97 }
98
99 _ => Ok(Value::Null),
100 }
101}
102
103fn signature() -> Signature {
104 Signature::one_of(
105 vec![
106 TypeSignature::Uniform(
108 3,
109 ConcreteDataType::numerics()
110 .into_iter()
111 .map(|dt| dt.as_arrow_type())
112 .collect(),
113 ),
114 TypeSignature::Uniform(
116 4,
117 ConcreteDataType::numerics()
118 .into_iter()
119 .map(|dt| dt.as_arrow_type())
120 .collect(),
121 ),
122 ],
123 Volatility::Immutable,
124 )
125}
126
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{StringArray, UInt64Array};
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::ColumnarValue;

    use super::*;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// Builds `count` one-row `UInt64` argument columns (each holding `1`) together
    /// with their matching non-nullable `arg_{i}` field definitions.
    fn u64_args(count: usize) -> (Vec<ColumnarValue>, Vec<Arc<Field>>) {
        let columns: Vec<ColumnarValue> = (0..count)
            .map(|_| ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1u64]))))
            .collect();
        let fields: Vec<Arc<Field>> = (0..count)
            .map(|idx| Arc::new(Field::new(format!("arg_{idx}"), DataType::UInt64, false)))
            .collect();
        (columns, fields)
    }

    #[test]
    fn test_migrate_region_misc() {
        let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
        let func = factory.provide(FunctionContext::mock());

        assert_eq!("migrate_region", func.name());
        assert_eq!(DataType::Utf8, func.return_type(&[]).unwrap());
        // Both the 3-argument and the 4-argument forms must be registered.
        assert!(matches!(func.signature(),
                         datafusion_expr::Signature {
                             type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
                             volatility: datafusion_expr::Volatility::Immutable
                         } if sigs.len() == 2));
    }

    #[tokio::test]
    async fn test_missing_procedure_service() {
        let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
        let provider = factory.provide(FunctionContext::default());
        let func = provider.as_async().unwrap();

        let (args, arg_fields) = u64_args(3);
        let call_args = datafusion::logical_expr::ScalarFunctionArgs {
            args,
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        };

        let err = func.invoke_async_with_args(call_args).await.unwrap_err();
        assert_eq!(
            "Execution error: Handler error: Missing ProcedureServiceHandler, not expected",
            err.to_string()
        );
    }

    #[tokio::test]
    async fn test_migrate_region() {
        let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
        let provider = factory.provide(FunctionContext::mock());
        let func = provider.as_async().unwrap();

        let (args, arg_fields) = u64_args(3);
        let call_args = datafusion::logical_expr::ScalarFunctionArgs {
            args,
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        };

        let output = func.invoke_async_with_args(call_args).await.unwrap();
        match output {
            ColumnarValue::Array(array) => {
                let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
                assert_eq!(strings.value(0), "test_pid");
            }
            ColumnarValue::Scalar(scalar) => assert_eq!(
                scalar,
                datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
            ),
        }
    }
}