1use std::fmt;
16use std::fmt::Formatter;
17use std::hash::{Hash, Hasher};
18use std::sync::Arc;
19
20use arrow::array::{AsArray, BooleanArray};
21use common_function::scalars::matches_term::MatchesTermFinder;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DfResult;
24use datafusion::physical_optimizer::PhysicalOptimizerRule;
25use datafusion::physical_plan::ExecutionPlan;
26use datafusion::physical_plan::filter::FilterExec;
27use datafusion_common::ScalarValue;
28use datafusion_common::tree_node::{Transformed, TreeNode};
29use datafusion_expr::ColumnarValue;
30use datafusion_physical_expr::expressions::Literal;
31use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
32
/// Physical expression for `matches_term(text, <constant term>)` whose matcher
/// is built once, when the expression is created, instead of per evaluation.
///
/// In this file, instances are created by [`MatchesConstantTermOptimizer`] when
/// the second argument of `matches_term` is a non-null `Utf8` literal.
///
/// Equality and hashing only consider `text` and `term`; `finder` and `probes`
/// are both derived from `term` at construction time.
#[derive(Debug)]
pub struct PreCompiledMatchesTermExpr {
    /// Child expression producing the text values that are searched.
    text: Arc<dyn PhysicalExpr>,
    /// The constant term to look for, exactly as written in the literal.
    term: String,
    /// Matcher built from `term`; `evaluate` applies it to every non-null row.
    finder: MatchesTermFinder,

    /// Non-empty pieces of `term` obtained by splitting on every character that
    /// is neither alphanumeric nor `_`. Within this file they are only surfaced
    /// through the `Display` output.
    // NOTE(review): presumably consumed by index pruning elsewhere — confirm.
    probes: Vec<String>,
}
51
52impl fmt::Display for PreCompiledMatchesTermExpr {
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 write!(
55 f,
56 "MatchesConstTerm({}, term: \"{}\", probes: {:?})",
57 self.text, self.term, self.probes
58 )
59 }
60}
61
62impl Hash for PreCompiledMatchesTermExpr {
63 fn hash<H: Hasher>(&self, state: &mut H) {
64 self.text.hash(state);
65 self.term.hash(state);
66 }
67}
68
69impl PartialEq for PreCompiledMatchesTermExpr {
70 fn eq(&self, other: &Self) -> bool {
71 self.text.eq(&other.text) && self.term.eq(&other.term)
72 }
73}
74
75impl Eq for PreCompiledMatchesTermExpr {}
76
77impl PhysicalExpr for PreCompiledMatchesTermExpr {
78 fn as_any(&self) -> &dyn std::any::Any {
79 self
80 }
81
82 fn data_type(
83 &self,
84 _input_schema: &arrow_schema::Schema,
85 ) -> datafusion::error::Result<arrow_schema::DataType> {
86 Ok(arrow_schema::DataType::Boolean)
87 }
88
89 fn nullable(&self, input_schema: &arrow_schema::Schema) -> datafusion::error::Result<bool> {
90 self.text.nullable(input_schema)
91 }
92
93 fn evaluate(
94 &self,
95 batch: &common_recordbatch::DfRecordBatch,
96 ) -> datafusion::error::Result<ColumnarValue> {
97 let num_rows = batch.num_rows();
98
99 let text_value = self.text.evaluate(batch)?;
100 let array = text_value.into_array(num_rows)?;
101 let str_array = array.as_string::<i32>();
102
103 let mut result = BooleanArray::builder(num_rows);
104 for text in str_array {
105 match text {
106 Some(text) => {
107 result.append_value(self.finder.find(text));
108 }
109 None => {
110 result.append_null();
111 }
112 }
113 }
114
115 Ok(ColumnarValue::Array(Arc::new(result.finish())))
116 }
117
118 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
119 vec![&self.text]
120 }
121
122 fn with_new_children(
123 self: Arc<Self>,
124 children: Vec<Arc<dyn PhysicalExpr>>,
125 ) -> datafusion::error::Result<Arc<dyn PhysicalExpr>> {
126 Ok(Arc::new(PreCompiledMatchesTermExpr {
127 text: children[0].clone(),
128 term: self.term.clone(),
129 finder: self.finder.clone(),
130 probes: self.probes.clone(),
131 }))
132 }
133
134 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
135 write!(f, "{}", self)
136 }
137}
138
/// Physical optimizer rule that pre-compiles `matches_term` calls with a
/// constant term.
///
/// For every `FilterExec` in the plan, each two-argument `matches_term` call in
/// its predicate (name compared ASCII case-insensitively) whose second argument
/// is a non-null `Utf8` literal is replaced by a [`PreCompiledMatchesTermExpr`].
/// Calls whose second argument is not such a literal are left untouched, as are
/// predicates of plan nodes other than `FilterExec`.
#[derive(Debug)]
pub struct MatchesConstantTermOptimizer;
160
161impl PhysicalOptimizerRule for MatchesConstantTermOptimizer {
162 fn optimize(
163 &self,
164 plan: Arc<dyn ExecutionPlan>,
165 _config: &ConfigOptions,
166 ) -> DfResult<Arc<dyn ExecutionPlan>> {
167 let res = plan
168 .transform_down(&|plan: Arc<dyn ExecutionPlan>| {
169 if let Some(filter) = plan.as_any().downcast_ref::<FilterExec>() {
170 let pred = filter.predicate().clone();
171 let new_pred = pred.transform_down(&|expr: Arc<dyn PhysicalExpr>| {
172 if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
173 if !func.name().eq_ignore_ascii_case("matches_term") {
174 return Ok(Transformed::no(expr));
175 }
176 let args = func.args();
177 if args.len() != 2 {
178 return Ok(Transformed::no(expr));
179 }
180
181 if let Some(lit) = args[1].as_any().downcast_ref::<Literal>()
182 && let ScalarValue::Utf8(Some(term)) = lit.value()
183 {
184 let finder = MatchesTermFinder::new(term);
185
186 let probes = term
188 .split(|c: char| !c.is_alphanumeric() && c != '_')
189 .filter(|s| !s.is_empty())
190 .map(|s| s.to_string())
191 .collect();
192
193 let expr = PreCompiledMatchesTermExpr {
194 text: args[0].clone(),
195 term: term.clone(),
196 finder,
197 probes,
198 };
199
200 return Ok(Transformed::yes(Arc::new(expr)));
201 }
202 }
203
204 Ok(Transformed::no(expr))
205 })?;
206
207 if new_pred.transformed {
208 let exec = FilterExec::try_new(new_pred.data, filter.input().clone())?
209 .with_default_selectivity(filter.default_selectivity())?
210 .with_projection(filter.projection().cloned())?;
211 return Ok(Transformed::yes(Arc::new(exec) as _));
212 }
213 }
214
215 Ok(Transformed::no(plan))
216 })?
217 .data;
218
219 Ok(res)
220 }
221
222 fn name(&self) -> &str {
223 "MatchesConstantTerm"
224 }
225
226 fn schema_check(&self) -> bool {
227 false
228 }
229}
230
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use catalog::RegisterTableRequest;
    use catalog::memory::MemoryCatalogManager;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
    use common_function::scalars::matches_term::MatchesTermFunction;
    use common_function::scalars::udf::create_udf;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::filter::FilterExec;
    use datafusion::physical_plan::get_plan_string;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{Expr, Literal, ScalarUDF};
    use datafusion_physical_expr::{ScalarFunctionExpr, create_physical_expr};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use session::context::QueryContext;
    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
    use table::test_util::EmptyTable;

    use super::*;
    use crate::parser::QueryLanguageParser;
    use crate::{QueryEngineFactory, QueryEngineRef};

    /// Builds a one-column batch (`text`, nullable `Utf8`) holding three
    /// strings and one null.
    fn create_test_batch() -> RecordBatch {
        let rows = vec![
            Some("hello world"),
            Some("greeting"),
            Some("hello there"),
            None,
        ];
        let text_column: ArrayRef = Arc::new(StringArray::from(rows));
        let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, true)]));

        RecordBatch::try_new(schema, vec![text_column]).unwrap()
    }

    /// Creates a query engine whose default catalog contains an empty table
    /// `test(text STRING, timestamp TIMESTAMP(ms) TIME INDEX)`.
    fn create_test_engine() -> QueryEngineRef {
        let table_name = "test".to_string();

        let text_column = ColumnSchema::new(
            "text".to_string(),
            ConcreteDataType::string_datatype(),
            false,
        );
        let timestamp_column = ColumnSchema::new(
            "timestamp".to_string(),
            ConcreteDataType::timestamp_millisecond_datatype(),
            false,
        )
        .with_time_index(true);
        let schema = Arc::new(datatypes::schema::Schema::new(vec![
            text_column,
            timestamp_column,
        ]));

        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![])
            .value_indices(vec![0])
            .next_column_id(2)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::default()
            .name(&table_name)
            .meta(table_meta)
            .build()
            .unwrap();
        let table = EmptyTable::from_table_info(&table_info);

        let catalog_list = MemoryCatalogManager::with_default_setup();
        let registered = catalog_list.register_table_sync(RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name,
            table_id: 1024,
            table,
        });
        assert!(registered.is_ok());

        QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            Default::default(),
        )
        .query_engine()
    }

    /// Wraps `MatchesTermFunction` as a DataFusion scalar UDF.
    fn matches_term_udf() -> Arc<ScalarUDF> {
        Arc::new(create_udf(Arc::new(MatchesTermFunction::default())))
    }

    /// Plans `matches_term(text, <term_arg>)` as the predicate of a `FilterExec`
    /// over the test batch, runs the optimizer on it and returns the result.
    fn optimize_matches_term_filter(term_arg: Expr) -> Arc<dyn ExecutionPlan> {
        let batch = create_test_batch();

        let call = Expr::ScalarFunction(ScalarFunction::new_udf(
            matches_term_udf(),
            vec![Expr::Column(Column::from_name("text")), term_arg],
        ));
        let predicate = create_physical_expr(
            &call,
            &DFSchema::try_from(batch.schema().clone()).unwrap(),
            &Default::default(),
        )
        .unwrap();

        let input = DataSourceExec::from_data_source(
            MemorySourceConfig::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap(),
        );
        let filter = FilterExec::try_new(predicate, input).unwrap();

        MatchesConstantTermOptimizer
            .optimize(Arc::new(filter), &Default::default())
            .unwrap()
    }

    /// A literal term must be replaced by the pre-compiled expression.
    #[test]
    fn test_matches_term_optimization() {
        let plan = optimize_matches_term_filter("hello".lit());
        let filter = plan.as_any().downcast_ref::<FilterExec>().unwrap();

        assert!(
            filter
                .predicate()
                .as_any()
                .is::<PreCompiledMatchesTermExpr>()
        );
    }

    /// A non-literal term (here: a column) must leave the function call as is.
    #[test]
    fn test_matches_term_no_optimization() {
        let plan = optimize_matches_term_filter(Expr::Column(Column::from_name("text")));
        let filter = plan.as_any().downcast_ref::<FilterExec>().unwrap();

        assert!(filter.predicate().as_any().is::<ScalarFunctionExpr>());
    }

    /// Plans a query with `MATCHES_TERM` calls spread over CTEs, a union and a
    /// join, and checks that every call shows up pre-compiled in the physical
    /// plan.
    #[tokio::test]
    async fn test_matches_term_optimization_from_sql() {
        let sql = "WITH base AS (
            SELECT text, timestamp FROM test
            WHERE MATCHES_TERM(text, 'hello wo_rld')
            AND timestamp > '2025-01-01 00:00:00'
        ),
        subquery1 AS (
            SELECT * FROM base
            WHERE MATCHES_TERM(text, 'world')
        ),
        subquery2 AS (
            SELECT * FROM test
            WHERE MATCHES_TERM(text, 'greeting')
            AND timestamp < '2025-01-02 00:00:00'
        ),
        union_result AS (
            SELECT * FROM subquery1
            UNION ALL
            SELECT * FROM subquery2
        ),
        joined_data AS (
            SELECT a.text, a.timestamp, b.text as other_text
            FROM union_result a
            JOIN test b ON a.timestamp = b.timestamp
            WHERE MATCHES_TERM(a.text, 'there')
        )
        SELECT text, other_text
        FROM joined_data
        WHERE MATCHES_TERM(text, '42')
        AND MATCHES_TERM(other_text, 'foo')";

        let query_ctx = QueryContext::arc();
        let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();

        let engine = create_test_engine();
        let logical_plan = engine
            .planner()
            .plan(&stmt, query_ctx.clone())
            .await
            .unwrap();

        let engine_ctx = engine.engine_context(query_ctx);
        let state = engine_ctx.state();

        // Logical analysis -> logical optimization -> physical planning.
        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan, state.config_options(), |_, _| {})
            .unwrap();
        let optimized_plan = state
            .optimizer()
            .optimize(analyzed_plan, state, |_, _| {})
            .unwrap();
        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await
            .unwrap();

        let plan_str = get_plan_string(&physical_plan).join("\n");

        let expected_fragments = [
            "MatchesConstTerm(text@0, term: \"foo\", probes: [\"foo\"]",
            "MatchesConstTerm(text@0, term: \"hello wo_rld\", probes: [\"hello\", \"wo_rld\"]",
            "MatchesConstTerm(text@0, term: \"world\", probes: [\"world\"]",
            "MatchesConstTerm(text@0, term: \"greeting\", probes: [\"greeting\"]",
            "MatchesConstTerm(text@0, term: \"there\", probes: [\"there\"]",
            "MatchesConstTerm(text@0, term: \"42\", probes: [\"42\"]",
        ];
        for fragment in expected_fragments {
            assert!(
                plan_str.contains(fragment),
                "missing `{fragment}` in plan:\n{plan_str}"
            );
        }

        // No un-optimized call may remain anywhere in the plan.
        assert!(!plan_str.contains("matches_term"));
    }
}