common_function/system/
procedure_state.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use api::v1::meta::ProcedureStatus;
16use common_macro::admin_fn;
17use common_meta::rpc::procedure::ProcedureStateResponse;
18use common_query::error::{
19    InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
20    UnsupportedInputDataTypeSnafu,
21};
22use common_query::prelude::{Signature, Volatility};
23use datatypes::prelude::*;
24use serde::Serialize;
25use session::context::QueryContextRef;
26use snafu::ensure;
27
28use crate::handlers::ProcedureServiceHandlerRef;
29
30#[derive(Serialize)]
31struct ProcedureStateJson {
32    status: String,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    error: Option<String>,
35}
36
37/// A function to query procedure state by its id.
38/// Such as `procedure_state(pid)`.
39#[admin_fn(
40    name = ProcedureStateFunction,
41    display_name = procedure_state,
42    sig_fn = signature,
43    ret = string
44)]
45pub(crate) async fn procedure_state(
46    procedure_service_handler: &ProcedureServiceHandlerRef,
47    _ctx: &QueryContextRef,
48    params: &[ValueRef<'_>],
49) -> Result<Value> {
50    ensure!(
51        params.len() == 1,
52        InvalidFuncArgsSnafu {
53            err_msg: format!(
54                "The length of the args is not correct, expect 1, have: {}",
55                params.len()
56            ),
57        }
58    );
59
60    let ValueRef::String(pid) = params[0] else {
61        return UnsupportedInputDataTypeSnafu {
62            function: "procedure_state",
63            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
64        }
65        .fail();
66    };
67
68    let ProcedureStateResponse { status, error, .. } =
69        procedure_service_handler.query_procedure_state(pid).await?;
70    let status = ProcedureStatus::try_from(status)
71        .map(|v| v.as_str_name())
72        .unwrap_or("Unknown");
73
74    let state = ProcedureStateJson {
75        status: status.to_string(),
76        error: if error.is_empty() { None } else { Some(error) },
77    };
78    let json = serde_json::to_string(&state).unwrap_or_default();
79
80    Ok(Value::from(json))
81}
82
83fn signature() -> Signature {
84    Signature::uniform(
85        1,
86        vec![ConcreteDataType::string_datatype()],
87        Volatility::Immutable,
88    )
89}
90
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::prelude::TypeSignature;
    use datatypes::vectors::StringVector;

    use super::*;
    use crate::function::{AsyncFunction, FunctionContext};

    /// Argument columns shared by the evaluation tests: a single one-row
    /// string vector holding the literal `"pid"`.
    fn pid_columns() -> Vec<VectorRef> {
        vec![Arc::new(StringVector::from_slice(&["pid"])) as VectorRef]
    }

    /// Checks the function's name, return type and signature.
    #[test]
    fn test_procedure_state_misc() {
        let f = ProcedureStateFunction;

        assert_eq!("procedure_state", f.name());
        assert_eq!(
            ConcreteDataType::string_datatype(),
            f.return_type(&[]).unwrap()
        );

        let signature = f.signature();
        assert!(matches!(
            signature,
            Signature {
                type_signature: TypeSignature::Uniform(1, accepted),
                volatility: Volatility::Immutable,
            } if accepted == vec![ConcreteDataType::string_datatype()]
        ));
    }

    /// Evaluating with the default context must fail with the
    /// missing-handler error message.
    #[tokio::test]
    async fn test_missing_procedure_service() {
        let f = ProcedureStateFunction;
        let columns = pid_columns();

        let failure = f
            .eval(FunctionContext::default(), &columns)
            .await
            .unwrap_err();

        assert_eq!(
            "Missing ProcedureServiceHandler, not expected",
            failure.to_string()
        );
    }

    /// Evaluating with the mock context yields the expected JSON string.
    #[tokio::test]
    async fn test_procedure_state() {
        let f = ProcedureStateFunction;
        let columns = pid_columns();

        let output = f.eval(FunctionContext::mock(), &columns).await.unwrap();

        let expected: VectorRef = Arc::new(StringVector::from(vec![
            r#"{"status":"Done","error":"OK"}"#,
        ]));
        assert_eq!(expected, output);
    }
}
153}