1use std::collections::BTreeMap;
17use std::sync::Arc;
18
19use common_error::ext::BoxedError;
20use common_function::function::{FunctionContext, FunctionRef};
21use datafusion::arrow::datatypes::{DataType, TimeUnit};
22use datafusion_expr::{Signature, Volatility};
23use datafusion_substrait::extensions::Extensions;
24use query::QueryEngine;
25use serde::{Deserialize, Serialize};
26use snafu::ResultExt;
27use substrait::substrait_proto_df as substrait_proto;
30use substrait_proto::proto::extensions::SimpleExtensionDeclaration;
31use substrait_proto::proto::extensions::simple_extension_declaration::MappingType;
32
33use crate::adapter::FlownodeContext;
34use crate::error::{Error, NotImplementedSnafu, UnexpectedSnafu};
35use crate::expr::{TUMBLE_END, TUMBLE_START};
/// Returns early with `Err(Error::NotImplemented { .. })`, building the
/// `reason` from `format!`-style arguments.
///
/// Expands to `NotImplementedSnafu { .. }.fail()`, so the expansion site must
/// have `NotImplementedSnafu` in scope (imported at the top of this file) and
/// an error type convertible from [`Error`].
macro_rules! not_impl_err {
    ($($arg:tt)*) => {
        NotImplementedSnafu {
            reason: format!($($arg)*),
        }.fail()
    };
}
44
/// Returns early with a planning error, building the `reason` from
/// `format!`-style arguments.
///
/// NOTE: `PlanSnafu` is *not* imported in this file; since macro paths resolve
/// at the expansion site, any module invoking `plan_err!` must import
/// `PlanSnafu` itself (presumably from `crate::error` — confirm against the
/// submodules below).
macro_rules! plan_err {
    ($($arg:tt)*) => {
        PlanSnafu {
            reason: format!($($arg)*),
        }.fail()
    };
}
53
// Substrait-to-flow transformation submodules, split by the kind of node
// being handled (aggregates, scalar expressions, literals, whole plans).
mod aggr;
mod expr;
mod literal;
mod plan;

// Re-exported so other crate modules can convert scalar functions without
// reaching into `expr` directly.
pub(crate) use expr::from_scalar_fn_to_df_fn_impl;
60
/// Maps Substrait function anchors (the numeric ids a plan assigns to each
/// referenced function) to function names, so references inside the plan body
/// can be resolved by name.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct FunctionExtensions {
    // Anchor -> function name. `BTreeMap` keeps iteration (and therefore
    // serialization and `Ord`/`Hash`) order deterministic.
    anchor_to_name: BTreeMap<u32, String>,
}
68
69impl FunctionExtensions {
70 pub fn from_iter(inner: impl IntoIterator<Item = (u32, impl ToString)>) -> Self {
71 Self {
72 anchor_to_name: inner.into_iter().map(|(k, s)| (k, s.to_string())).collect(),
73 }
74 }
75
76 pub fn try_from_proto(extensions: &[SimpleExtensionDeclaration]) -> Result<Self, Error> {
78 let mut anchor_to_name = BTreeMap::new();
79 for e in extensions {
80 match &e.mapping_type {
81 Some(ext) => match ext {
82 MappingType::ExtensionFunction(ext_f) => {
83 anchor_to_name.insert(ext_f.function_anchor, ext_f.name.clone());
84 }
85 _ => not_impl_err!("Extension type not supported: {ext:?}")?,
86 },
87 None => not_impl_err!("Cannot parse empty extension")?,
88 }
89 }
90 Ok(Self { anchor_to_name })
91 }
92
93 pub fn get(&self, anchor: &u32) -> Option<&String> {
95 self.anchor_to_name.get(anchor)
96 }
97
98 pub fn to_extensions(&self) -> Extensions {
99 Extensions {
100 functions: self
101 .anchor_to_name
102 .iter()
103 .map(|(k, v)| (*k, v.clone()))
104 .collect(),
105 ..Default::default()
106 }
107 }
108}
109
110pub fn register_function_to_query_engine(engine: &Arc<dyn QueryEngine>) {
112 let tumble_fn = Arc::new(TumbleFunction::new("tumble")) as FunctionRef;
113 let tumble_start_fn = Arc::new(TumbleFunction::new(TUMBLE_START)) as FunctionRef;
114 let tumble_end_fn = Arc::new(TumbleFunction::new(TUMBLE_END)) as FunctionRef;
115
116 engine.register_scalar_function(tumble_fn.into());
117 engine.register_scalar_function(tumble_start_fn.into());
118 engine.register_scalar_function(tumble_end_fn.into());
119}
120
/// Placeholder scalar UDF for the tumble-window family of functions.
///
/// Exists only so the query planner can resolve the name; the `Function`
/// implementation below rejects actual evaluation.
#[derive(Debug)]
pub struct TumbleFunction {
    name: String,
}

impl TumbleFunction {
    /// Creates a placeholder function registered under `name`.
    fn new(name: &str) -> Self {
        Self {
            name: name.to_owned(),
        }
    }
}

impl std::fmt::Display for TumbleFunction {
    /// Displays the function name in upper case, e.g. `tumble` -> `TUMBLE`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(&self.name.to_ascii_uppercase())
    }
}
139
140impl common_function::function::Function for TumbleFunction {
141 fn name(&self) -> &str {
142 &self.name
143 }
144
145 fn return_type(&self, _: &[DataType]) -> common_query::error::Result<DataType> {
146 Ok(DataType::Timestamp(TimeUnit::Millisecond, None))
147 }
148
149 fn signature(&self) -> Signature {
150 Signature::variadic_any(Volatility::Immutable)
151 }
152
153 fn eval(
154 &self,
155 _func_ctx: &FunctionContext,
156 _columns: &[datatypes::prelude::VectorRef],
157 ) -> common_query::error::Result<datatypes::prelude::VectorRef> {
158 UnexpectedSnafu {
159 reason: "Tumbler function is not implemented for datafusion executor",
160 }
161 .fail()
162 .map_err(BoxedError::new)
163 .context(common_query::error::ExecuteSnafu)
164 }
165}
166
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID};
    use datatypes::data_type::ConcreteDataType as CDT;
    use datatypes::prelude::*;
    use datatypes::schema::Schema;
    use datatypes::timestamp::TimestampMillisecond;
    use datatypes::vectors::{TimestampMillisecondVectorBuilder, VectorRef};
    use itertools::Itertools;
    use prost::Message;
    use query::QueryEngine;
    use query::options::QueryOptions;
    use query::parser::QueryLanguageParser;
    use query::query_engine::DefaultSerializer;
    use session::context::QueryContext;
    use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
    use substrait_proto::proto;
    use table::table::numbers::{NUMBERS_TABLE_NAME, NumbersTable};
    use table::test_util::MemTable;

    use super::*;
    use crate::adapter::node_context::IdToNameMap;
    use crate::adapter::table_source::test::FlowDummyTableSource;
    use crate::df_optimizer::apply_df_optimizer;
    use crate::expr::GlobalId;

    /// Builds a `FlownodeContext` backed by a dummy table source, with two
    /// tables pre-registered in the name/id/global-id mapping:
    /// `greptime.public.numbers` (id 1024, `GlobalId::User(0)`) and
    /// `greptime.public.numbers_with_ts` (id 1025, `GlobalId::User(1)`).
    pub fn create_test_ctx() -> FlownodeContext {
        let mut tri_map = IdToNameMap::new();
        // `numbers`: GlobalId::User(0) <-> table id 1024.
        {
            let gid = GlobalId::User(0);
            let name = [
                "greptime".to_string(),
                "public".to_string(),
                "numbers".to_string(),
            ];
            tri_map.insert(Some(name.clone()), Some(1024), gid);
        }

        // `numbers_with_ts`: GlobalId::User(1) <-> table id 1025.
        {
            let gid = GlobalId::User(1);
            let name = [
                "greptime".to_string(),
                "public".to_string(),
                "numbers_with_ts".to_string(),
            ];
            tri_map.insert(Some(name.clone()), Some(1025), gid);
        }

        let dummy_source = FlowDummyTableSource::default();

        let mut ctx = FlownodeContext::new(Box::new(dummy_source));
        ctx.table_repr = tri_map;
        // Default catalog/schema so test SQL can use bare table names.
        ctx.query_context = Some(Arc::new(QueryContext::with("greptime", "public")));

        ctx
    }

    /// Builds a DataFusion-backed query engine over an in-memory catalog
    /// containing the standard `numbers` table plus a `numbers_with_ts`
    /// MemTable (rows 1..=10 with millisecond timestamps 1..=10), and
    /// registers the tumble placeholder UDFs on it.
    pub fn create_test_query_engine() -> Arc<dyn QueryEngine> {
        let catalog_list = catalog::memory::new_memory_catalog_manager().unwrap();
        let req = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: NUMBERS_TABLE_NAME.to_string(),
            table_id: NUMBERS_TABLE_ID,
            table: NumbersTable::table(NUMBERS_TABLE_ID),
        };
        catalog_list.register_table_sync(req).unwrap();

        // Schema for `numbers_with_ts`: (number: u32, ts: timestamp-ms),
        // both non-nullable.
        let schema = vec![
            datatypes::schema::ColumnSchema::new("number", CDT::uint32_datatype(), false),
            datatypes::schema::ColumnSchema::new(
                "ts",
                CDT::timestamp_millisecond_datatype(),
                false,
            ),
        ];
        let mut columns = vec![];
        let numbers = (1..=10).collect_vec();
        let column: VectorRef = Arc::new(<u32 as Scalar>::VectorType::from_vec(numbers));
        columns.push(column);

        let ts = (1..=10).collect_vec();
        let mut builder = TimestampMillisecondVectorBuilder::with_capacity(10);
        // `count()` merely drives the iterator; `push` runs for its side effect.
        ts.into_iter()
            .map(|v| builder.push(Some(TimestampMillisecond::new(v))))
            .count();
        let column: VectorRef = builder.to_vector_cloned();
        columns.push(column);

        let schema = Arc::new(Schema::new(schema));
        let recordbatch = common_recordbatch::RecordBatch::new(schema, columns).unwrap();
        let table = MemTable::table("numbers_with_ts", recordbatch);

        // NOTE(review): `create_test_ctx` maps `numbers_with_ts` to table id
        // 1025, but here it is registered with id 1024 — presumably harmless
        // because the catalog and the flow context resolve ids independently
        // in these tests; confirm if the ids ever need to agree.
        let req_with_ts = RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: "numbers_with_ts".to_string(),
            table_id: 1024,
            table,
        };
        catalog_list.register_table_sync(req_with_ts).unwrap();

        let factory = query::QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            QueryOptions::default(),
        );

        let engine = factory.query_engine();
        register_function_to_query_engine(&engine);

        // Sanity-check we got the DataFusion-backed engine.
        assert_eq!("datafusion", engine.name());
        engine
    }

    /// Parses `sql`, plans and optimizes it, then round-trips the logical
    /// plan through the Substrait encoder, returning the decoded raw
    /// `proto::Plan` for inspection by transform tests.
    pub async fn sql_to_substrait(engine: Arc<dyn QueryEngine>, sql: &str) -> proto::Plan {
        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();
        let plan = apply_df_optimizer(plan, &QueryContext::arc())
            .await
            .unwrap();

        // Encode to Substrait bytes, then decode back into the raw proto.
        let bytes = DFLogicalSubstraitConvertor {}
            .encode(&plan, DefaultSerializer)
            .unwrap();

        proto::Plan::decode(bytes).unwrap()
    }

    /// A tumble query grouping by an extra non-window column is expected to
    /// be rejected during flow-side optimization (missing-key check).
    #[tokio::test]
    async fn test_missing_key_check() {
        let engine = create_test_query_engine();
        let sql = "SELECT avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";

        let stmt = QueryLanguageParser::parse_sql(sql, &QueryContext::arc()).unwrap();
        let plan = engine
            .planner()
            .plan(&stmt, QueryContext::arc())
            .await
            .unwrap();
        let plan = apply_df_optimizer(plan, &QueryContext::arc()).await;

        // Planning succeeds; the optimizer is the layer that must fail.
        assert!(plan.is_err());
    }
}
326}