1use std::collections::BTreeMap;
17use std::sync::Arc;
18
19use common_error::ext::BoxedError;
20use common_function::function::{FunctionContext, FunctionRef};
21use datafusion_substrait::extensions::Extensions;
22use datatypes::data_type::ConcreteDataType as CDT;
23use query::QueryEngine;
24use serde::{Deserialize, Serialize};
25use snafu::ResultExt;
26use substrait::substrait_proto_df as substrait_proto;
29use substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
30use substrait_proto::proto::extensions::SimpleExtensionDeclaration;
31
32use crate::adapter::FlownodeContext;
33use crate::error::{Error, NotImplementedSnafu, UnexpectedSnafu};
34use crate::expr::{TUMBLE_END, TUMBLE_START};
/// Build a `NotImplementedSnafu` error with a `format!`-style message and
/// immediately return it as `Err` (callers typically append `?`),
/// e.g. `not_impl_err!("unsupported: {x:?}")?`.
macro_rules! not_impl_err {
    ($($arg:tt)*) => {
        NotImplementedSnafu {
            reason: format!($($arg)*),
        }.fail()
    };
}
43
/// Build a `PlanSnafu` error with a `format!`-style message and return it as
/// `Err`, mirroring `not_impl_err!`.
/// NOTE(review): `PlanSnafu` is not imported at the top of this file, so this
/// macro only compiles where the expansion site has `PlanSnafu` in scope
/// (e.g. inside the child modules declared below) — confirm usage sites.
macro_rules! plan_err {
    ($($arg:tt)*) => {
        PlanSnafu {
            reason: format!($($arg)*),
        }.fail()
    };
}
52
53mod aggr;
54mod expr;
55mod literal;
56mod plan;
57
58pub(crate) use expr::from_scalar_fn_to_df_fn_impl;
59
/// Mapping from Substrait function anchors (`u32`) to function names, as
/// declared in a plan's extension section.
/// Backed by a `BTreeMap` so iteration order is deterministic, which also
/// keeps the derived `Ord`/`Hash` stable across runs.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FunctionExtensions {
    // anchor -> function name, populated from `SimpleExtensionDeclaration`s
    anchor_to_name: BTreeMap<u32, String>,
}
67
68impl FunctionExtensions {
69 pub fn from_iter(inner: impl IntoIterator<Item = (u32, impl ToString)>) -> Self {
70 Self {
71 anchor_to_name: inner.into_iter().map(|(k, s)| (k, s.to_string())).collect(),
72 }
73 }
74
75 pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result<Self, Error> {
77 let mut anchor_to_name = BTreeMap::new();
78 for e in extensions {
79 match &e.mapping_type {
80 Some(ext) => match ext {
81 MappingType::ExtensionFunction(ext_f) => {
82 anchor_to_name.insert(ext_f.function_anchor, ext_f.name.clone());
83 }
84 _ => not_impl_err!("Extension type not supported: {ext:?}")?,
85 },
86 None => not_impl_err!("Cannot parse empty extension")?,
87 }
88 }
89 Ok(Self { anchor_to_name })
90 }
91
92 pub fn get(&self, anchor: &u32) -> Option<&String> {
94 self.anchor_to_name.get(anchor)
95 }
96
97 pub fn to_extensions(&self) -> Extensions {
98 Extensions {
99 functions: self
100 .anchor_to_name
101 .iter()
102 .map(|(k, v)| (*k, v.clone()))
103 .collect(),
104 ..Default::default()
105 }
106 }
107}
108
109pub fn register_function_to_query_engine(engine: &Arc<dyn QueryEngine>) {
111 let tumble_fn = Arc::new(TumbleFunction::new("tumble")) as FunctionRef;
112 let tumble_start_fn = Arc::new(TumbleFunction::new(TUMBLE_START)) as FunctionRef;
113 let tumble_end_fn = Arc::new(TumbleFunction::new(TUMBLE_END)) as FunctionRef;
114
115 engine.register_scalar_function(tumble_fn.into());
116 engine.register_scalar_function(tumble_start_fn.into());
117 engine.register_scalar_function(tumble_end_fn.into());
118}
119
/// Placeholder scalar function for the flow `tumble` window family.
/// It is registered so planning succeeds; actually evaluating it through the
/// datafusion executor is rejected with an error (see its `Function::eval`).
#[derive(Debug)]
pub struct TumbleFunction {
    // Name under which this instance was registered
    // ("tumble", TUMBLE_START or TUMBLE_END).
    name: String,
}
124
125impl TumbleFunction {
126 fn new(name: &str) -> Self {
127 Self {
128 name: name.to_string(),
129 }
130 }
131}
132
133impl std::fmt::Display for TumbleFunction {
134 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
135 write!(f, "{}", self.name.to_ascii_uppercase())
136 }
137}
138
139impl common_function::function::Function for TumbleFunction {
140 fn name(&self) -> &str {
141 &self.name
142 }
143
144 fn return_type(&self, _input_types: &[CDT]) -> common_query::error::Result<CDT> {
145 Ok(CDT::timestamp_millisecond_datatype())
146 }
147
148 fn signature(&self) -> common_query::prelude::Signature {
149 common_query::prelude::Signature::variadic_any(common_query::prelude::Volatility::Immutable)
150 }
151
152 fn eval(
153 &self,
154 _func_ctx: &FunctionContext,
155 _columns: &[datatypes::prelude::VectorRef],
156 ) -> common_query::error::Result<datatypes::prelude::VectorRef> {
157 UnexpectedSnafu {
158 reason: "Tumbler function is not implemented for datafusion executor",
159 }
160 .fail()
161 .map_err(BoxedError::new)
162 .context(common_query::error::ExecuteSnafu)
163 }
164}
165
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use datatypes::prelude::*;
    use datatypes::schema::Schema;
    use datatypes::timestamp::TimestampMillisecond;
    use datatypes::vectors::{TimestampMillisecondVectorBuilder, VectorRef};
    use itertools::Itertools;
    use prost::Message;
    use query::options::QueryOptions;
    use query::parser::QueryLanguageParser;
    use query::query_engine::DefaultSerializer;
    use query::QueryEngine;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
    use substrait_proto::proto;
    use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};
    use table::test_util::MemTable;

    use super::*;
    use crate::adapter::node_context::IdToNameMap;
    use crate::adapter::table_source::test::FlowDummyTableSource;
    use crate::df_optimizer::apply_df_optimizer;
    use crate::expr::GlobalId;

    /// Build a `FlownodeContext` for tests with two tables mapped in the
    /// id/name table: `greptime.public.numbers` (table id 1024) and
    /// `greptime.public.numbers_with_ts` (table id 1025).
    pub fn create_test_ctx() -> FlownodeContext {
        let mut tri_map = IdToNameMap::new();
        {
            let gid = GlobalId::User(0);
            let name = [
                "greptime".to_string(),
                "public".to_string(),
                "numbers".to_string(),
            ];
            // `name` is not used afterwards, so move it instead of cloning.
            tri_map.insert(Some(name), Some(1024), gid);
        }

        {
            let gid = GlobalId::User(1);
            let name = [
                "greptime".to_string(),
                "public".to_string(),
                "numbers_with_ts".to_string(),
            ];
            tri_map.insert(Some(name), Some(1025), gid);
        }

        let dummy_source = FlowDummyTableSource::default();

        let mut ctx = FlownodeContext::new(Box::new(dummy_source));
        ctx.table_repr = tri_map;
        ctx.query_context = Some(Arc::new(QueryContext::with("greptime", "public")));

        ctx
    }

    /// Build a datafusion-backed query engine with the built-in `numbers`
    /// table, a `numbers_with_ts` MemTable holding rows
    /// `(number, ts) = (1..=10, 1..=10 ms)`, and the tumble functions
    /// registered.
    pub fn create_test_query_engine() -> Arc<dyn QueryEngine> {
        let catalog_list = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_list.register_table_sync(req).unwrap();

        let schema = vec![
            datatypes::schema::ColumnSchema::new("number", CDT::uint32_datatype(), false),
            datatypes::schema::ColumnSchema::new(
                "ts",
                CDT::timestamp_millisecond_datatype(),
                false,
            ),
        ];
        let mut columns = vec![];
        let numbers = (1..=10).collect_vec();
        let column: VectorRef = Arc::new(<u32 as Scalar>::VectorType::from_vec(numbers));
        columns.push(column);

        let mut builder = TimestampMillisecondVectorBuilder::with_capacity(10);
        // Push side effects via a plain loop; the previous
        // `.map(|v| builder.push(..)).count()` form is the clippy
        // `suspicious_map` anti-pattern (map used only for side effects).
        for v in 1..=10 {
            builder.push(Some(TimestampMillisecond::new(v)));
        }
        let column: VectorRef = builder.to_vector_cloned();
        columns.push(column);

        let schema = Arc::new(Schema::new(schema));
        let recordbatch = common_recordbatch::RecordBatch::new(schema, columns).unwrap();
        let table = MemTable::table("numbers_with_ts", recordbatch);

        // NOTE(review): table_id 1024 here differs from the 1025 mapped to
        // `numbers_with_ts` in `create_test_ctx` — confirm this mismatch is
        // intentional (the catalog id and the flow id map are independent).
        let req_with_ts = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: "numbers_with_ts".to_string(),
            table_id: 1024,
            table,
        };
        catalog_list.register_table_sync(req_with_ts).unwrap();

        let factory = query::QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        );

        let engine = factory.query_engine();
        register_function_to_query_engine(&engine);

        assert_eq!("datafusion", engine.name());
        engine
    }

    /// Parse `sql`, plan and optimize it with the given engine, then
    /// round-trip the logical plan through substrait encoding and return the
    /// decoded proto plan.
    pub async fn sql_to_substrait(engine: Arc<dyn QueryEngine>, sql: &str) -> proto::Plan {
        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();
        let plan = apply_df_optimizer(plan).await.unwrap();

        let bytes = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();

        proto::Plan::decode(bytes).unwrap()
    }

    /// A tumble query grouped this way must be rejected by the flow
    /// optimizer: `apply_df_optimizer` is expected to return an error.
    #[tokio::test]
    async fn test_missing_key_check() {
        let engine = create_test_query_engine();
        let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();
        let plan = apply_df_optimizer(plan).await;

        assert!(plan.is_err());
    }
}