common_function/system/
procedure_state.rs1use api::v1::meta::ProcedureStatus;
16use arrow::datatypes::DataType as ArrowDataType;
17use common_macro::admin_fn;
18use common_meta::rpc::procedure::ProcedureStateResponse;
19use common_query::error::{
20 InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
21 UnsupportedInputDataTypeSnafu,
22};
23use datafusion_expr::{Signature, Volatility};
24use datatypes::prelude::*;
25use serde::Serialize;
26use session::context::QueryContextRef;
27use snafu::ensure;
28
29use crate::handlers::ProcedureServiceHandlerRef;
30
31#[derive(Serialize)]
32struct ProcedureStateJson {
33 status: String,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 error: Option<String>,
36}
37
38#[admin_fn(
41 name = ProcedureStateFunction,
42 display_name = procedure_state,
43 sig_fn = signature,
44 ret = string
45)]
46pub(crate) async fn procedure_state(
47 procedure_service_handler: &ProcedureServiceHandlerRef,
48 _ctx: &QueryContextRef,
49 params: &[ValueRef<'_>],
50) -> Result<Value> {
51 ensure!(
52 params.len() == 1,
53 InvalidFuncArgsSnafu {
54 err_msg: format!(
55 "The length of the args is not correct, expect 1, have: {}",
56 params.len()
57 ),
58 }
59 );
60
61 let ValueRef::String(pid) = params[0] else {
62 return UnsupportedInputDataTypeSnafu {
63 function: "procedure_state",
64 datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
65 }
66 .fail();
67 };
68
69 let ProcedureStateResponse { status, error, .. } =
70 procedure_service_handler.query_procedure_state(pid).await?;
71 let status = ProcedureStatus::try_from(status)
72 .map(|v| v.as_str_name())
73 .unwrap_or("Unknown");
74
75 let state = ProcedureStateJson {
76 status: status.to_string(),
77 error: if error.is_empty() { None } else { Some(error) },
78 };
79 let json = serde_json::to_string(&state).unwrap_or_default();
80
81 Ok(Value::from(json))
82}
83
84fn signature() -> Signature {
85 Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::ColumnarValue;

    use super::*;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// JSON the mocked procedure service yields for any procedure id.
    const EXPECTED_STATE: &str = "{\"status\":\"Done\",\"error\":\"OK\"}";

    /// Builds single-row call arguments carrying the string `"pid"`.
    fn pid_args() -> datafusion::logical_expr::ScalarFunctionArgs {
        let pid_array: arrow::array::ArrayRef = Arc::new(StringArray::from(vec!["pid"]));
        datafusion::logical_expr::ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(pid_array)],
            arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    #[test]
    fn test_procedure_state_misc() {
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        let f = factory.provide(FunctionContext::mock());
        assert_eq!("procedure_state", f.name());
        assert_eq!(DataType::Utf8, f.return_type(&[]).unwrap());
        assert!(matches!(f.signature(),
                         datafusion_expr::Signature {
                             type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
                             volatility: datafusion_expr::Volatility::Immutable
                         } if valid_types == &vec![ArrowDataType::Utf8]));
    }

    #[tokio::test]
    async fn test_missing_procedure_service() {
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        // The default context has no procedure service handler, so the call must fail.
        let binding = factory.provide(FunctionContext::default());
        let f = binding.as_async().unwrap();

        let outcome = f.invoke_async_with_args(pid_args()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_procedure_state() {
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        let provider = factory.provide(FunctionContext::mock());
        let f = provider.as_async().unwrap();

        let outcome = f.invoke_async_with_args(pid_args()).await.unwrap();

        match outcome {
            ColumnarValue::Array(array) => {
                let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
                assert_eq!(strings.value(0), EXPECTED_STATE);
            }
            ColumnarValue::Scalar(scalar) => {
                let expected =
                    datafusion_common::ScalarValue::Utf8(Some(EXPECTED_STATE.to_string()));
                assert_eq!(scalar, expected);
            }
        }
    }
}