// flow/transform.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Transform Substrait into execution plan
16use std::collections::BTreeMap;
17use std::sync::Arc;
18
19use common_error::ext::BoxedError;
20use common_function::function::{FunctionContext, FunctionRef};
21use datafusion_substrait::extensions::Extensions;
22use datatypes::data_type::ConcreteDataType as CDT;
23use query::QueryEngine;
24use serde::{Deserialize, Serialize};
25use snafu::ResultExt;
26/// note here we are using the `substrait_proto_df` crate from the `substrait` module and
27/// rename it to `substrait_proto`
28use substrait::substrait_proto_df as substrait_proto;
29use substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
30use substrait_proto::proto::extensions::SimpleExtensionDeclaration;
31
32use crate::adapter::FlownodeContext;
33use crate::error::{Error, NotImplementedSnafu, UnexpectedSnafu};
34use crate::expr::{TUMBLE_END, TUMBLE_START};
/// a simple macro to generate a not implemented error
///
/// Accepts `format!`-style arguments for the reason and expands to
/// `NotImplementedSnafu { reason }.fail()`, i.e. it evaluates to an `Err`
/// of the crate's `Error` type (callers typically apply `?` or `return` it).
macro_rules! not_impl_err {
    ($($arg:tt)*)  => {
        NotImplementedSnafu {
            reason: format!($($arg)*),
        }.fail()
    };
}
43
/// generate a plan error
///
/// Accepts `format!`-style arguments and expands to `PlanSnafu { reason }.fail()`.
/// NOTE(review): `PlanSnafu` is not imported at this scope; since `macro_rules!`
/// paths resolve at the expansion site, the selector must be in scope wherever
/// this macro is invoked — presumably the submodules below; verify.
macro_rules! plan_err {
    ($($arg:tt)*)  => {
        PlanSnafu {
            reason: format!($($arg)*),
        }.fail()
    };
}
52
53mod aggr;
54mod expr;
55mod literal;
56mod plan;
57
58pub(crate) use expr::from_scalar_fn_to_df_fn_impl;
59
/// In Substrait, a function can be defined by a u32 anchor, and the anchor can be mapped to a name
///
/// So in substrait plan, a ref to a function can be a single u32 anchor instead of a full name in string
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FunctionExtensions {
    // Maps a function's u32 anchor to its name; a BTreeMap keeps iteration
    // (and hence serialization) order deterministic.
    anchor_to_name: BTreeMap<u32, String>,
}
67
68impl FunctionExtensions {
69    pub fn from_iter(inner: impl IntoIterator<Item = (u32, impl ToString)>) -> Self {
70        Self {
71            anchor_to_name: inner.into_iter().map(|(k, s)| (k, s.to_string())).collect(),
72        }
73    }
74
75    /// Create a new FunctionExtensions from a list of SimpleExtensionDeclaration
76    pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result<Self, Error> {
77        let mut anchor_to_name = BTreeMap::new();
78        for e in extensions {
79            match &e.mapping_type {
80                Some(ext) => match ext {
81                    MappingType::ExtensionFunction(ext_f) => {
82                        anchor_to_name.insert(ext_f.function_anchor, ext_f.name.clone());
83                    }
84                    _ => not_impl_err!("Extension type not supported: {ext:?}")?,
85                },
86                None => not_impl_err!("Cannot parse empty extension")?,
87            }
88        }
89        Ok(Self { anchor_to_name })
90    }
91
92    /// Get the name of a function by it's anchor
93    pub fn get(&self, anchor: &u32) -> Option<&String> {
94        self.anchor_to_name.get(anchor)
95    }
96
97    pub fn to_extensions(&self) -> Extensions {
98        Extensions {
99            functions: self
100                .anchor_to_name
101                .iter()
102                .map(|(k, v)| (*k, v.clone()))
103                .collect(),
104            ..Default::default()
105        }
106    }
107}
108
109/// register flow-specific functions to the query engine
110pub fn register_function_to_query_engine(engine: &Arc<dyn QueryEngine>) {
111    let tumble_fn = Arc::new(TumbleFunction::new("tumble")) as FunctionRef;
112    let tumble_start_fn = Arc::new(TumbleFunction::new(TUMBLE_START)) as FunctionRef;
113    let tumble_end_fn = Arc::new(TumbleFunction::new(TUMBLE_END)) as FunctionRef;
114
115    engine.register_scalar_function(tumble_fn.into());
116    engine.register_scalar_function(tumble_start_fn.into());
117    engine.register_scalar_function(tumble_end_fn.into());
118}
119
/// Placeholder scalar function for the tumble window family; it only carries
/// the function name it was registered under.
#[derive(Debug)]
pub struct TumbleFunction {
    // Name as registered with the query engine (e.g. "tumble").
    name: String,
}

impl TumbleFunction {
    /// Construct a new instance wrapping the given function name.
    fn new(name: &str) -> Self {
        Self {
            name: String::from(name),
        }
    }
}
132
133impl std::fmt::Display for TumbleFunction {
134    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
135        write!(f, "{}", self.name.to_ascii_uppercase())
136    }
137}
138
139impl common_function::function::Function for TumbleFunction {
140    fn name(&self) -> &str {
141        &self.name
142    }
143
144    fn return_type(&self, _input_types: &[CDT]) -> common_query::error::Result<CDT> {
145        Ok(CDT::timestamp_millisecond_datatype())
146    }
147
148    fn signature(&self) -> common_query::prelude::Signature {
149        common_query::prelude::Signature::variadic_any(common_query::prelude::Volatility::Immutable)
150    }
151
152    fn eval(
153        &self,
154        _func_ctx: &FunctionContext,
155        _columns: &[datatypes::prelude::VectorRef],
156    ) -> common_query::error::Result<datatypes::prelude::VectorRef> {
157        UnexpectedSnafu {
158            reason: "Tumbler function is not implemented for datafusion executor",
159        }
160        .fail()
161        .map_err(BoxedError::new)
162        .context(common_query::error::ExecuteSnafu)
163    }
164}
165
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use datatypes::prelude::*;
    use datatypes::schema::Schema;
    use datatypes::timestamp::TimestampMillisecond;
    use datatypes::vectors::{TimestampMillisecondVectorBuilder, VectorRef};
    use itertools::Itertools;
    use prost::Message;
    use query::options::QueryOptions;
    use query::parser::QueryLanguageParser;
    use query::query_engine::DefaultSerializer;
    use query::QueryEngine;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
    use substrait_proto::proto;
    use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};
    use table::test_util::MemTable;

    use super::*;
    use crate::adapter::node_context::IdToNameMap;
    use crate::adapter::table_source::test::FlowDummyTableSource;
    use crate::df_optimizer::apply_df_optimizer;
    use crate::expr::GlobalId;

    /// Build a [`FlownodeContext`] with the id/name mappings the transform
    /// tests expect (`numbers` and `numbers_with_ts`) backed by a dummy source.
    pub fn create_test_ctx() -> FlownodeContext {
        let mut tri_map = IdToNameMap::new();
        // FIXME(discord9): deprecated, use `numbers_with_ts` instead since this table has no timestamp column
        {
            let gid = GlobalId::User(0);
            let name = [
                "greptime".to_string(),
                "public".to_string(),
                "numbers".to_string(),
            ];
            tri_map.insert(Some(name.clone()), Some(1024), gid);
        }

        {
            let gid = GlobalId::User(1);
            let name = [
                "greptime".to_string(),
                "public".to_string(),
                "numbers_with_ts".to_string(),
            ];
            tri_map.insert(Some(name.clone()), Some(1025), gid);
        }

        let dummy_source = FlowDummyTableSource::default();

        let mut ctx = FlownodeContext::new(Box::new(dummy_source));
        ctx.table_repr = tri_map;
        ctx.query_context = Some(Arc::new(QueryContext::with("greptime", "public")));

        ctx
    }

    /// Build a datafusion-backed query engine whose catalog contains the
    /// `numbers` table and an in-memory `numbers_with_ts` table (10 rows of
    /// `number`/`ts`), with the flow-specific tumble functions registered.
    pub fn create_test_query_engine() -> Arc<dyn QueryEngine> {
        let catalog_list = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_list.register_table_sync(req).unwrap();

        let schema = vec![
            datatypes::schema::ColumnSchema::new("number", CDT::uint32_datatype(), false),
            datatypes::schema::ColumnSchema::new(
                "ts",
                CDT::timestamp_millisecond_datatype(),
                false,
            ),
        ];
        let mut columns = vec![];
        let numbers = (1..=10).collect_vec();
        let column: VectorRef = Arc::new(<u32 as Scalar>::VectorType::from_vec(numbers));
        columns.push(column);

        let ts = (1..=10).collect_vec();
        let mut builder = TimestampMillisecondVectorBuilder::with_capacity(10);
        // Push for the side effect only; a plain loop avoids abusing
        // `map(..).count()` for iteration (clippy: suspicious_map).
        for v in ts {
            builder.push(Some(TimestampMillisecond::new(v)));
        }
        let column: VectorRef = builder.to_vector_cloned();
        columns.push(column);

        let schema = Arc::new(Schema::new(schema));
        let recordbatch = common_recordbatch::RecordBatch::new(schema, columns).unwrap();
        let table = MemTable::table("numbers_with_ts", recordbatch);

        // NOTE(review): `create_test_ctx` maps `numbers_with_ts` to table id
        // 1025 while this registers it as 1024 — confirm the mismatch is
        // intentional (the two ids live in different subsystems).
        let req_with_ts = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: "numbers_with_ts".to_string(),
            table_id: 1024,
            table,
        };
        catalog_list.register_table_sync(req_with_ts).unwrap();

        let factory = query::QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        );

        let engine = factory.query_engine();
        register_function_to_query_engine(&engine);

        assert_eq!("datafusion", engine.name());
        engine
    }

    /// Plan `sql` with `engine`, optimize it, and round-trip it through the
    /// substrait encoding so tests exercise the logical-plan-to-substrait
    /// conversion path.
    pub async fn sql_to_substrait(engine: Arc<dyn QueryEngine>, sql: &str) -> proto::Plan {
        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();
        let plan = apply_df_optimizer(plan).await.unwrap();

        // encode then decode so to rely on the impl of conversion from logical plan to substrait plan
        let bytes = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();

        proto::Plan::decode(bytes).unwrap()
    }

    /// TODO(discord9): add more illegal sql tests
    #[tokio::test]
    async fn test_missing_key_check() {
        let engine = create_test_query_engine();
        let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();
        // The optimizer is expected to reject this plan rather than produce one.
        let plan = apply_df_optimizer(plan).await;

        assert!(plan.is_err());
    }
}
322}