// common_function/system/procedure_state.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use api::v1::meta::ProcedureStatus;
16use arrow::datatypes::DataType as ArrowDataType;
17use common_macro::admin_fn;
18use common_meta::rpc::procedure::ProcedureStateResponse;
19use common_query::error::{
20    InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
21    UnsupportedInputDataTypeSnafu,
22};
23use datafusion_expr::{Signature, Volatility};
24use datatypes::prelude::*;
25use serde::Serialize;
26use session::context::QueryContextRef;
27use snafu::ensure;
28
29use crate::handlers::ProcedureServiceHandlerRef;
30
31#[derive(Serialize)]
32struct ProcedureStateJson {
33    status: String,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    error: Option<String>,
36}
37
38/// A function to query procedure state by its id.
39/// Such as `procedure_state(pid)`.
40#[admin_fn(
41    name = ProcedureStateFunction,
42    display_name = procedure_state,
43    sig_fn = signature,
44    ret = string
45)]
46pub(crate) async fn procedure_state(
47    procedure_service_handler: &ProcedureServiceHandlerRef,
48    _ctx: &QueryContextRef,
49    params: &[ValueRef<'_>],
50) -> Result<Value> {
51    ensure!(
52        params.len() == 1,
53        InvalidFuncArgsSnafu {
54            err_msg: format!(
55                "The length of the args is not correct, expect 1, have: {}",
56                params.len()
57            ),
58        }
59    );
60
61    let ValueRef::String(pid) = params[0] else {
62        return UnsupportedInputDataTypeSnafu {
63            function: "procedure_state",
64            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
65        }
66        .fail();
67    };
68
69    let ProcedureStateResponse { status, error, .. } =
70        procedure_service_handler.query_procedure_state(pid).await?;
71    let status = ProcedureStatus::try_from(status)
72        .map(|v| v.as_str_name())
73        .unwrap_or("Unknown");
74
75    let state = ProcedureStateJson {
76        status: status.to_string(),
77        error: if error.is_empty() { None } else { Some(error) },
78    };
79    let json = serde_json::to_string(&state).unwrap_or_default();
80
81    Ok(Value::from(json))
82}
83
84fn signature() -> Signature {
85    Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::ColumnarValue;

    use super::*;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// JSON that the mocked procedure service is expected to yield.
    const EXPECTED_JSON: &str = "{\"status\":\"Done\",\"error\":\"OK\"}";

    /// Builds single-row function arguments carrying the procedure id `"pid"`.
    fn pid_args() -> datafusion::logical_expr::ScalarFunctionArgs {
        datafusion::logical_expr::ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
                "pid",
            ])))],
            arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    #[test]
    fn test_procedure_state_misc() {
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        let f = factory.provide(FunctionContext::mock());

        assert_eq!("procedure_state", f.name());
        assert_eq!(DataType::Utf8, f.return_type(&[]).unwrap());

        let sig = f.signature();
        assert!(matches!(sig,
                         datafusion_expr::Signature {
                             type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
                             volatility: datafusion_expr::Volatility::Immutable,
                             ..
                         } if valid_types == &vec![ArrowDataType::Utf8]));
    }

    #[tokio::test]
    async fn test_missing_procedure_service() {
        // A default context carries no procedure service handler.
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        let provided = factory.provide(FunctionContext::default());
        let f = provided.as_async().unwrap();

        let outcome = f.invoke_async_with_args(pid_args()).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_procedure_state() {
        let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
        let provided = factory.provide(FunctionContext::mock());
        let f = provided.as_async().unwrap();

        let outcome = f.invoke_async_with_args(pid_args()).await.unwrap();

        // Either result shape is acceptable; both must carry the same JSON.
        match outcome {
            ColumnarValue::Array(array) => {
                let strings = array.as_any().downcast_ref::<StringArray>().unwrap();
                assert_eq!(strings.value(0), EXPECTED_JSON);
            }
            ColumnarValue::Scalar(scalar) => {
                assert_eq!(
                    scalar,
                    datafusion_common::ScalarValue::Utf8(Some(EXPECTED_JSON.to_string()))
                );
            }
        }
    }
}