common_function/
flush_flow.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use arrow::datatypes::DataType as ArrowDataType;
16use common_error::ext::BoxedError;
17use common_macro::admin_fn;
18use common_query::error::{
19    ExecuteSnafu, InvalidFuncArgsSnafu, MissingFlowServiceHandlerSnafu, Result,
20    UnsupportedInputDataTypeSnafu,
21};
22use datafusion_expr::{Signature, Volatility};
23use datatypes::value::{Value, ValueRef};
24use session::context::QueryContextRef;
25use snafu::{ensure, ResultExt};
26use sql::ast::ObjectNamePartExt;
27use sql::parser::ParserContext;
28
29use crate::handlers::FlowServiceHandlerRef;
30
31fn flush_signature() -> Signature {
32    Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
33}
34
35#[admin_fn(
36    name = FlushFlowFunction,
37    display_name = flush_flow,
38    sig_fn = flush_signature,
39    ret = uint64
40)]
41pub(crate) async fn flush_flow(
42    flow_service_handler: &FlowServiceHandlerRef,
43    query_ctx: &QueryContextRef,
44    params: &[ValueRef<'_>],
45) -> Result<Value> {
46    let (catalog_name, flow_name) = parse_flush_flow(params, query_ctx)?;
47
48    let res = flow_service_handler
49        .flush(&catalog_name, &flow_name, query_ctx.clone())
50        .await?;
51    let affected_rows = res.affected_rows;
52
53    Ok(Value::from(affected_rows))
54}
55
56fn parse_flush_flow(
57    params: &[ValueRef<'_>],
58    query_ctx: &QueryContextRef,
59) -> Result<(String, String)> {
60    ensure!(
61        params.len() == 1,
62        InvalidFuncArgsSnafu {
63            err_msg: format!(
64                "The length of the args is not correct, expect 1, have: {}",
65                params.len()
66            ),
67        }
68    );
69
70    let ValueRef::String(flow_name) = params[0] else {
71        return UnsupportedInputDataTypeSnafu {
72            function: "flush_flow",
73            datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
74        }
75        .fail();
76    };
77    let obj_name = ParserContext::parse_table_name(flow_name, query_ctx.sql_dialect())
78        .map_err(BoxedError::new)
79        .context(ExecuteSnafu)?;
80
81    let (catalog_name, flow_name) = match &obj_name.0[..] {
82        [flow_name] => (
83            query_ctx.current_catalog().to_string(),
84            flow_name.to_string_unquoted(),
85        ),
86        [catalog, flow_name] => (catalog.to_string_unquoted(), flow_name.to_string_unquoted()),
87        _ => {
88            return InvalidFuncArgsSnafu {
89                err_msg: format!(
90                    "expect flow name to be <catalog>.<flow-name> or <flow-name>, actual: {}",
91                    obj_name
92                ),
93            }
94            .fail()
95        }
96    };
97    Ok((catalog_name, flow_name))
98}
99
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use session::context::QueryContext;

    use super::*;
    use crate::function::FunctionContext;
    use crate::function_factory::ScalarFunctionFactory;

    /// The generated `FlushFlowFunction` exposes the expected name, return type and
    /// signature.
    #[test]
    fn test_flush_flow_metadata() {
        let factory: ScalarFunctionFactory = FlushFlowFunction::factory().into();
        let f = factory.provide(FunctionContext::mock());
        assert_eq!("flush_flow", f.name());
        assert_eq!(ArrowDataType::UInt64, f.return_type(&[]).unwrap());
        let expected_signature = datafusion_expr::Signature::uniform(
            1,
            vec![ArrowDataType::Utf8],
            datafusion_expr::Volatility::Immutable,
        );
        assert_eq!(*f.signature(), expected_signature);
    }

    /// With a default `FunctionContext` (no flow service handler), invocation fails with
    /// a descriptive error instead of panicking.
    #[tokio::test]
    async fn test_missing_flow_service() {
        let factory: ScalarFunctionFactory = FlushFlowFunction::factory().into();
        let binding = factory.provide(FunctionContext::default());
        let f = binding.as_async().unwrap();

        let flow_name_array = Arc::new(arrow::array::StringArray::from(vec!["flow_name"]));

        let columnar_args = vec![datafusion_expr::ColumnarValue::Array(flow_name_array as _)];

        let func_args = datafusion::logical_expr::ScalarFunctionArgs {
            args: columnar_args,
            arg_fields: vec![Arc::new(arrow::datatypes::Field::new(
                "arg_0",
                ArrowDataType::Utf8,
                false,
            ))],
            return_field: Arc::new(arrow::datatypes::Field::new(
                "result",
                ArrowDataType::UInt64,
                true,
            )),
            number_rows: 1,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        };

        let result = f.invoke_async_with_args(func_args).await.unwrap_err();
        assert_eq!(
            "Execution error: Handler error: Missing FlowServiceHandler, not expected",
            result.to_string()
        );
    }

    /// Well-formed names resolve to `(catalog, flow)`. A bare flow name uses the
    /// session's current catalog (`greptime` for the default `QueryContext`).
    #[test]
    fn test_parse_flow_args() {
        let testcases = [
            ("flow_name", ("greptime", "flow_name")),
            ("catalog.flow_name", ("catalog", "flow_name")),
        ];
        for (input, expected) in testcases {
            let args = [ValueRef::String(input)];
            let (catalog, flow) = parse_flush_flow(&args, &QueryContext::arc()).unwrap();
            assert_eq!(expected, (catalog.as_str(), flow.as_str()), "input: {input}");
        }
    }

    /// Malformed calls are rejected with an error rather than panicking or being
    /// silently accepted.
    #[test]
    fn test_parse_flow_args_invalid() {
        let ctx = QueryContext::arc();

        // Wrong number of arguments.
        assert!(parse_flush_flow(&[], &ctx).is_err());
        assert!(
            parse_flush_flow(&[ValueRef::String("a"), ValueRef::String("b")], &ctx).is_err()
        );

        // Non-string argument.
        assert!(parse_flush_flow(&[ValueRef::Boolean(true)], &ctx).is_err());
        assert!(parse_flush_flow(&[ValueRef::Null], &ctx).is_err());

        // Names with more than `<catalog>.<flow-name>` parts.
        assert!(parse_flush_flow(&[ValueRef::String("a.b.c")], &ctx).is_err());
    }
}