flow/
transform.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Transform Substrait into execution plan
16use std::collections::BTreeMap;
17use std::sync::Arc;
18
19use common_error::ext::BoxedError;
20use common_function::function::{FunctionContext, FunctionRef};
21use datafusion::arrow::datatypes::{DataType, TimeUnit};
22use datafusion_expr::{Signature, Volatility};
23use datafusion_substrait::extensions::Extensions;
24use query::QueryEngine;
25use serde::{Deserialize, Serialize};
26use snafu::ResultExt;
27/// note here we are using the `substrait_proto_df` crate from the `substrait` module and
28/// rename it to `substrait_proto`
29use substrait::substrait_proto_df as substrait_proto;
30use substrait_proto::proto::extensions::SimpleExtensionDeclaration;
31use substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
32
33use crate::adapter::FlownodeContext;
34use crate::error::{Error, NotImplementedSnafu, UnexpectedSnafu};
35use crate::expr::{TUMBLE_END, TUMBLE_START};
/// a simple macro to generate a not implemented error
///
/// Expands to `NotImplementedSnafu { reason: format!(...) }.fail()`, i.e. an
/// expression of type `Result<T, Error>`; the macro arguments are forwarded
/// verbatim to `format!` to build the `reason` message.
macro_rules! not_impl_err {
    ($($arg:tt)*)  => {
        NotImplementedSnafu {
            reason: format!($($arg)*),
        }.fail()
    };
}
44
/// generate a plan error
///
/// Expands to `PlanSnafu { reason: format!(...) }.fail()`, mirroring
/// `not_impl_err!` above.
///
/// NOTE(review): `PlanSnafu` is not imported in this module; since item names
/// inside `macro_rules!` bodies resolve at the expansion site, every caller
/// must have `PlanSnafu` in scope — confirm the submodules invoking this
/// macro import it.
macro_rules! plan_err {
    ($($arg:tt)*)  => {
        PlanSnafu {
            reason: format!($($arg)*),
        }.fail()
    };
}
53
54mod aggr;
55mod expr;
56mod literal;
57mod plan;
58
59pub(crate) use expr::from_scalar_fn_to_df_fn_impl;
60
/// In Substrait, a function can be defined by an u32 anchor, and the anchor can be mapped to a name
///
/// So in substrait plan, a ref to a function can be a single u32 anchor instead of a full name in string
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FunctionExtensions {
    // Maps a function's u32 anchor to its name. `BTreeMap` keeps iteration
    // (and hence serialization) order deterministic.
    anchor_to_name: BTreeMap<u32, String>,
}
68
69impl FunctionExtensions {
70    pub fn from_iter(inner: impl IntoIterator<Item = (u32, impl ToString)>) -> Self {
71        Self {
72            anchor_to_name: inner.into_iter().map(|(k, s)| (k, s.to_string())).collect(),
73        }
74    }
75
76    /// Create a new FunctionExtensions from a list of SimpleExtensionDeclaration
77    pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result<Self, Error> {
78        let mut anchor_to_name = BTreeMap::new();
79        for e in extensions {
80            match &e.mapping_type {
81                Some(ext) => match ext {
82                    MappingType::ExtensionFunction(ext_f) => {
83                        anchor_to_name.insert(ext_f.function_anchor, ext_f.name.clone());
84                    }
85                    _ => not_impl_err!("Extension type not supported: {ext:?}")?,
86                },
87                None => not_impl_err!("Cannot parse empty extension")?,
88            }
89        }
90        Ok(Self { anchor_to_name })
91    }
92
93    /// Get the name of a function by it's anchor
94    pub fn get(&self, anchor: &u32) -> Option<&String> {
95        self.anchor_to_name.get(anchor)
96    }
97
98    pub fn to_extensions(&self) -> Extensions {
99        Extensions {
100            functions: self
101                .anchor_to_name
102                .iter()
103                .map(|(k, v)| (*k, v.clone()))
104                .collect(),
105            ..Default::default()
106        }
107    }
108}
109
110/// register flow-specific functions to the query engine
111pub fn register_function_to_query_engine(engine: &Arc<dyn QueryEngine>) {
112    let tumble_fn = Arc::new(TumbleFunction::new("tumble")) as FunctionRef;
113    let tumble_start_fn = Arc::new(TumbleFunction::new(TUMBLE_START)) as FunctionRef;
114    let tumble_end_fn = Arc::new(TumbleFunction::new(TUMBLE_END)) as FunctionRef;
115
116    engine.register_scalar_function(tumble_fn.into());
117    engine.register_scalar_function(tumble_start_fn.into());
118    engine.register_scalar_function(tumble_end_fn.into());
119}
120
/// Planning-time placeholder for the tumble window functions; it carries only
/// the registered function name and cannot be evaluated by the datafusion
/// executor (its `eval` always returns an error).
#[derive(Debug)]
pub struct TumbleFunction {
    // Name under which this function is registered, e.g. "tumble".
    name: String,
}
125
126impl TumbleFunction {
127    fn new(name: &str) -> Self {
128        Self {
129            name: name.to_string(),
130        }
131    }
132}
133
134impl std::fmt::Display for TumbleFunction {
135    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
136        write!(f, "{}", self.name.to_ascii_uppercase())
137    }
138}
139
140impl common_function::function::Function for TumbleFunction {
141    fn name(&self) -> &str {
142        &self.name
143    }
144
145    fn return_type(&self, _: &[DataType]) -> common_query::error::Result<DataType> {
146        Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
147    }
148
149    fn signature(&self) -> Signature {
150        Signature::variadic_any(Volatility::Immutable)
151    }
152
153    fn eval(
154        &self,
155        _func_ctx: &FunctionContext,
156        _columns: &[datatypes::prelude::VectorRef],
157    ) -> common_query::error::Result<datatypes::prelude::VectorRef> {
158        UnexpectedSnafu {
159            reason: "Tumbler function is not implemented for datafusion executor",
160        }
161        .fail()
162        .map_err(BoxedError::new)
163        .context(common_query::error::ExecuteSnafu)
164    }
165}
166
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use datatypes::data_type::ConcreteDataType as CDT;
    use datatypes::prelude::*;
    use datatypes::schema::Schema;
    use datatypes::timestamp::TimestampMillisecond;
    use datatypes::vectors::{TimestampMillisecondVectorBuilder, VectorRef};
    use itertools::Itertools;
    use prost::Message;
    use query::QueryEngine;
    use query::options::QueryOptions;
    use query::parser::QueryLanguageParser;
    use query::query_engine::DefaultSerializer;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
    use substrait_proto::proto;
    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};
    use table::test_util::MemTable;

    use super::*;
    use crate::adapter::node_context::IdToNameMap;
    use crate::adapter::table_source::test::FlowDummyTableSource;
    use crate::df_optimizer::apply_df_optimizer;
    use crate::expr::GlobalId;

    /// Build a `FlownodeContext` whose table map knows two test tables:
    /// `greptime.public.numbers` (global id 0, table id 1024) and
    /// `greptime.public.numbers_with_ts` (global id 1, table id 1025).
    pub fn create_test_ctx() -> FlownodeContext {
        let mut tri_map = IdToNameMap::new();
        // FIXME(discord9): deprecated, use `numbers_with_ts` instead since this table has no timestamp column
        {
            let gid = GlobalId::User(0);
            let name = [
                "greptime".to_string(),
                "public".to_string(),
                "numbers".to_string(),
            ];
            tri_map.insert(Some(name.clone()), Some(1024), gid);
        }

        {
            let gid = GlobalId::User(1);
            let name = [
                "greptime".to_string(),
                "public".to_string(),
                "numbers_with_ts".to_string(),
            ];
            tri_map.insert(Some(name.clone()), Some(1025), gid);
        }

        // A dummy table source suffices: these tests only exercise planning,
        // never actual data flow through the tables.
        let dummy_source = FlowDummyTableSource::default();

        let mut ctx = FlownodeContext::new(Box::new(dummy_source));
        ctx.table_repr = tri_map;
        ctx.query_context = Some(Arc::new(QueryContext::with("greptime", "public")));

        ctx
    }

    /// Build an in-memory query engine with the stock `numbers` table and a
    /// ten-row `numbers_with_ts` MemTable registered, plus the flow-specific
    /// tumble functions from `register_function_to_query_engine`.
    pub fn create_test_query_engine() -> Arc<dyn QueryEngine> {
        let catalog_list = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_list.register_table_sync(req).unwrap();

        // `numbers_with_ts` schema: `number: u32 NOT NULL`, `ts: timestamp(ms) NOT NULL`.
        let schema = vec![
            datatypes::schema::ColumnSchema::new("number", CDT::uint32_datatype(), false),
            datatypes::schema::ColumnSchema::new(
                "ts",
                CDT::timestamp_millisecond_datatype(),
                false,
            ),
        ];
        let mut columns = vec![];
        // `number` column: 1..=10.
        let numbers = (1..=10).collect_vec();
        let column: VectorRef = Arc::new(<u32 as Scalar>::VectorType::from_vec(numbers));
        columns.push(column);

        // `ts` column: millisecond timestamps 1..=10; `count()` merely drives
        // the iterator — the pushes into `builder` are the point.
        let ts = (1..=10).collect_vec();
        let mut builder = TimestampMillisecondVectorBuilder::with_capacity(10);
        ts.into_iter()
            .map(|v| builder.push(Some(TimestampMillisecond::new(v))))
            .count();
        let column: VectorRef = builder.to_vector_cloned();
        columns.push(column);

        let schema = Arc::new(Schema::new(schema));
        let recordbatch = common_recordbatch::RecordBatch::new(schema, columns).unwrap();
        let table = MemTable::table("numbers_with_ts", recordbatch);

        // NOTE(review): `numbers_with_ts` is registered here with table_id
        // 1024, while `create_test_ctx` maps it to 1025 (and maps `numbers`
        // to 1024) — confirm the catalog table id does not need to agree with
        // the tri-map id used by the flow context.
        let req_with_ts = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: "numbers_with_ts".to_string(),
            table_id: 1024,
            table,
        };
        catalog_list.register_table_sync(req_with_ts).unwrap();

        let factory = query::QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        );

        let engine = factory.query_engine();
        register_function_to_query_engine(&engine);

        assert_eq!("datafusion", engine.name());
        engine
    }

    /// Parse `sql`, plan and optimize it, then round-trip the logical plan
    /// through the substrait encoding so callers test the real
    /// plan-to-substrait conversion path rather than a hand-built proto.
    pub async fn sql_to_substrait(engine: Arc<dyn QueryEngine>, sql: &str) -> proto::Plan {
        // let engine = create_test_query_engine();
        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();
        let plan = apply_df_optimizer(plan, &QueryContext::arc())
            .await
            .unwrap();

        // encode then decode so to rely on the impl of conversion from logical plan to substrait plan
        let bytes = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();

        proto::Plan::decode(bytes).unwrap()
    }

    /// TODO(discord9): add more illegal sql tests
    #[tokio::test]
    async fn test_missing_key_check() {
        let engine = create_test_query_engine();
        // `apply_df_optimizer` is expected to reject this query — presumably
        // the missing-key check, per the test name; planning itself succeeds.
        let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();
        let plan = apply_df_optimizer(plan, &QueryContext::arc()).await;

        assert!(plan.is_err());
    }
}
326}