common_function/system/procedure_state.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use api::v1::meta::ProcedureStatus;
16use arrow::datatypes::DataType as ArrowDataType;
17use common_macro::admin_fn;
18use common_meta::rpc::procedure::ProcedureStateResponse;
19use common_query::error::{
20    InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
21    UnsupportedInputDataTypeSnafu,
22};
23use datafusion_expr::{Signature, Volatility};
24use datatypes::prelude::*;
25use serde::Serialize;
26use session::context::QueryContextRef;
27use snafu::ensure;
28
29use crate::handlers::ProcedureServiceHandlerRef;
30
/// JSON payload produced by `procedure_state`.
///
/// Serde serializes fields in declaration order, so the emitted JSON always
/// has `status` first and `error` (when present) second.
#[derive(Serialize)]
struct ProcedureStateJson {
    /// Status name as given by `ProcedureStatus::as_str_name`, or `"Unknown"`
    /// when the raw status code does not map to a known `ProcedureStatus`.
    status: String,
    /// Error message reported for the procedure. `None` when the handler
    /// reported an empty message, in which case the key is omitted entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
37
38/// A function to query procedure state by its id.
39/// Such as `procedure_state(pid)`.
40#[admin_fn(
41    name = ProcedureStateFunction,
42    display_name = procedure_state,
43    sig_fn = signature,
44    ret = string
45)]
46pub(crate) async fn procedure_state(
47    procedure_service_handler: &ProcedureServiceHandlerRef,
48    _ctx: &QueryContextRef,
49    params: &[ValueRef<'_>],
50) -> Result<Value> {
51    ensure!(
52        params.len() == 1,
53        InvalidFuncArgsSnafu {
54            err_msg: format!(
55                "The length of the args is not correct, expect 1, have: {}",
56                params.len()
57            ),
58        }
59    );
60
61    let ValueRef::String(pid) = params[0] else {
62        return UnsupportedInputDataTypeSnafu {
63            function: "procedure_state",
64            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
65        }
66        .fail();
67    };
68
69    let ProcedureStateResponse { status, error, .. } =
70        procedure_service_handler.query_procedure_state(pid).await?;
71    let status = ProcedureStatus::try_from(status)
72        .map(|v| v.as_str_name())
73        .unwrap_or("Unknown");
74
75    let state = ProcedureStateJson {
76        status: status.to_string(),
77        error: if error.is_empty() { None } else { Some(error) },
78    };
79    let json = serde_json::to_string(&state).unwrap_or_default();
80
81    Ok(Value::from(json))
82}
83
84fn signature() -> Signature {
85    Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::ColumnarValue;

    use super::*;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// JSON the mocked procedure service is expected to yield for `"pid"`.
    const EXPECTED_STATE: &str = r#"{"status":"Done","error":"OK"}"#;

    /// Builds a one-row invocation whose only argument is the string `"pid"`.
    fn single_pid_args() -> datafusion::logical_expr::ScalarFunctionArgs {
        datafusion::logical_expr::ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "pid",
            ])))],
            arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    #[test]
    fn test_procedure_state_misc() {
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        let func = factory.provide(FunctionContext::mock());

        assert_eq!("procedure_state", func.name());
        assert_eq!(DataType::Utf8, func.return_type(&[]).unwrap());

        let signature_matches = matches!(func.signature(),
                         datafusion_expr::Signature {
                             type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
                             volatility: datafusion_expr::Volatility::Immutable
                         } if valid_types == &vec![ArrowDataType::Utf8]);
        assert!(signature_matches);
    }

    #[tokio::test]
    async fn test_missing_procedure_service() {
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        // The default context carries no procedure service handler.
        let provided = factory.provide(FunctionContext::default());
        let func = provided.as_async().unwrap();

        let outcome = func.invoke_async_with_args(single_pid_args()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_procedure_state() {
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        let provided = factory.provide(FunctionContext::mock());
        let func = provided.as_async().unwrap();

        let output = func
            .invoke_async_with_args(single_pid_args())
            .await
            .unwrap();

        match output {
            ColumnarValue::Array(array) => {
                let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
                assert_eq!(strings.value(0), EXPECTED_STATE);
            }
            ColumnarValue::Scalar(scalar) => {
                assert_eq!(
                    scalar,
                    datafusion_common::ScalarValue::Utf8(Some(EXPECTED_STATE.to_string()))
                );
            }
        }
    }
}