1use std::fmt;
16use std::fmt::Formatter;
17use std::hash::{Hash, Hasher};
18use std::sync::Arc;
19
20use arrow::array::{AsArray, BooleanArray};
21use common_function::scalars::matches_term::MatchesTermFinder;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DfResult;
24use datafusion::physical_optimizer::PhysicalOptimizerRule;
25use datafusion::physical_plan::ExecutionPlan;
26use datafusion::physical_plan::filter::{FilterExec, FilterExecBuilder};
27use datafusion_common::ScalarValue;
28use datafusion_common::tree_node::{Transformed, TreeNode};
29use datafusion_expr::ColumnarValue;
30use datafusion_physical_expr::expressions::Literal;
31use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
32
/// A pre-compiled replacement for `matches_term(text, <constant term>)`.
///
/// Holds a [`MatchesTermFinder`] built once from the constant term so the
/// term does not need to be re-parsed for every row. Instances are created
/// by [`MatchesConstantTermOptimizer`].
#[derive(Debug)]
pub struct PreCompiledMatchesTermExpr {
    /// Expression producing the text to search in.
    text: Arc<dyn PhysicalExpr>,
    /// The original constant term; used for display, hashing and equality.
    term: String,
    /// Matcher pre-compiled from `term`, applied per row in `evaluate`.
    finder: MatchesTermFinder,

    /// Alphanumeric/underscore words extracted from `term`; surfaced in the
    /// `Display` output (e.g. in `EXPLAIN` plans).
    probes: Vec<String>,
}
51
52impl fmt::Display for PreCompiledMatchesTermExpr {
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 write!(
55 f,
56 "MatchesConstTerm({}, term: \"{}\", probes: {:?})",
57 self.text, self.term, self.probes
58 )
59 }
60}
61
// Only `text` and `term` participate in hashing and equality: `finder` and
// `probes` are derived deterministically from `term` by the optimizer, so
// including them would be redundant. Hashing and equality cover the same
// fields, preserving the `a == b => hash(a) == hash(b)` invariant.
impl Hash for PreCompiledMatchesTermExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.text.hash(state);
        self.term.hash(state);
    }
}

impl PartialEq for PreCompiledMatchesTermExpr {
    /// Two expressions are equal when they match the same constant term
    /// against the same text expression (derived fields ignored).
    fn eq(&self, other: &Self) -> bool {
        self.text.eq(&other.text) && self.term.eq(&other.term)
    }
}

impl Eq for PreCompiledMatchesTermExpr {}
76
impl PhysicalExpr for PreCompiledMatchesTermExpr {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// A term match always yields a boolean.
    fn data_type(
        &self,
        _input_schema: &arrow_schema::Schema,
    ) -> datafusion::error::Result<arrow_schema::DataType> {
        Ok(arrow_schema::DataType::Boolean)
    }

    /// The output is null exactly where the input text is null, so
    /// nullability follows the text expression.
    fn nullable(&self, input_schema: &arrow_schema::Schema) -> datafusion::error::Result<bool> {
        self.text.nullable(input_schema)
    }

    /// Evaluates the match row by row: non-null strings are tested with the
    /// pre-compiled `finder`, nulls are propagated as nulls.
    fn evaluate(
        &self,
        batch: &common_recordbatch::DfRecordBatch,
    ) -> datafusion::error::Result<ColumnarValue> {
        let num_rows = batch.num_rows();

        let text_value = self.text.evaluate(batch)?;
        // Materialize a possible scalar result into a full-length array so a
        // single row-wise loop handles both cases.
        let array = text_value.into_array(num_rows)?;
        // NOTE(review): assumes the text expression produces a Utf8
        // (i32-offset) string array; `as_string::<i32>` panics on other
        // layouts (LargeUtf8, Utf8View, Dictionary). The optimizer only
        // builds this expression for `matches_term` arguments — confirm that
        // guarantees plain Utf8 here.
        let str_array = array.as_string::<i32>();

        let mut result = BooleanArray::builder(num_rows);
        for text in str_array {
            match text {
                Some(text) => {
                    result.append_value(self.finder.find(text));
                }
                None => {
                    result.append_null();
                }
            }
        }

        Ok(ColumnarValue::Array(Arc::new(result.finish())))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.text]
    }

    /// Rebuilds the expression around a replacement text child.
    ///
    /// NOTE(review): indexes `children[0]` without a length check; relies on
    /// DataFusion passing exactly as many children as `children()` returned.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion::error::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(PreCompiledMatchesTermExpr {
            text: children[0].clone(),
            term: self.term.clone(),
            finder: self.finder.clone(),
            probes: self.probes.clone(),
        }))
    }

    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self)
    }
}
138
/// Physical optimizer rule that rewrites `matches_term(text, <utf8 literal>)`
/// calls inside `FilterExec` predicates into [`PreCompiledMatchesTermExpr`],
/// compiling the constant term once instead of once per row.
#[derive(Debug)]
pub struct MatchesConstantTermOptimizer;
160
impl PhysicalOptimizerRule for MatchesConstantTermOptimizer {
    /// Walks the plan top-down; for every `FilterExec`, rewrites each
    /// eligible `matches_term` call in the predicate and rebuilds the filter
    /// when anything changed.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        let res = plan
            .transform_down(&|plan: Arc<dyn ExecutionPlan>| {
                if let Some(filter) = plan.as_any().downcast_ref::<FilterExec>() {
                    let pred = filter.predicate().clone();
                    // Rewrite the predicate's expression tree in place.
                    let new_pred = pred.transform_down(&|expr: Arc<dyn PhysicalExpr>| {
                        if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
                            // Only `matches_term(text, term)` with exactly two
                            // arguments is eligible for pre-compilation.
                            if !func.name().eq_ignore_ascii_case("matches_term") {
                                return Ok(Transformed::no(expr));
                            }
                            let args = func.args();
                            if args.len() != 2 {
                                return Ok(Transformed::no(expr));
                            }

                            // The term must be a non-null Utf8 literal so the
                            // matcher can be built ahead of execution.
                            if let Some(lit) = args[1].as_any().downcast_ref::<Literal>()
                                && let ScalarValue::Utf8(Some(term)) = lit.value()
                            {
                                let finder = MatchesTermFinder::new(term);

                                // Words of the term (runs of alphanumerics and
                                // '_'); shown in the expression's display output.
                                let probes = term
                                    .split(|c: char| !c.is_alphanumeric() && c != '_')
                                    .filter(|s| !s.is_empty())
                                    .map(|s| s.to_string())
                                    .collect();

                                let expr = PreCompiledMatchesTermExpr {
                                    text: args[0].clone(),
                                    term: term.clone(),
                                    finder,
                                    probes,
                                };

                                return Ok(Transformed::yes(Arc::new(expr)));
                            }
                        }

                        Ok(Transformed::no(expr))
                    })?;

                    if new_pred.transformed {
                        // Rebuild the filter with the rewritten predicate while
                        // preserving the original selectivity and projection.
                        let exec = FilterExecBuilder::new(new_pred.data, filter.input().clone())
                            .with_default_selectivity(filter.default_selectivity())
                            .apply_projection_by_ref(filter.projection().as_ref())
                            .and_then(|x| x.build())?;
                        return Ok(Transformed::yes(Arc::new(exec) as _));
                    }
                }

                Ok(Transformed::no(plan))
            })?
            .data;

        Ok(res)
    }

    fn name(&self) -> &str {
        "MatchesConstantTerm"
    }

    /// Schema checking is disabled: the rewrite swaps one boolean-valued
    /// predicate expression for another (`PreCompiledMatchesTermExpr`
    /// reports `Boolean`), so the plan schema is not expected to change.
    fn schema_check(&self) -> bool {
        false
    }
}
231
232#[cfg(test)]
233mod tests {
234 use std::sync::Arc;
235
236 use arrow::array::{ArrayRef, StringArray};
237 use arrow::datatypes::{DataType, Field, Schema};
238 use arrow::record_batch::RecordBatch;
239 use catalog::RegisterTableRequest;
240 use catalog::memory::MemoryCatalogManager;
241 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
242 use common_function::scalars::matches_term::MatchesTermFunction;
243 use common_function::scalars::udf::create_udf;
244 use datafusion::datasource::memory::MemorySourceConfig;
245 use datafusion::datasource::source::DataSourceExec;
246 use datafusion::physical_optimizer::PhysicalOptimizerRule;
247 use datafusion::physical_plan::filter::FilterExec;
248 use datafusion::physical_plan::get_plan_string;
249 use datafusion_common::{Column, DFSchema};
250 use datafusion_expr::expr::ScalarFunction;
251 use datafusion_expr::{Expr, Literal, ScalarUDF};
252 use datafusion_physical_expr::{ScalarFunctionExpr, create_physical_expr};
253 use datatypes::prelude::ConcreteDataType;
254 use datatypes::schema::ColumnSchema;
255 use session::context::QueryContext;
256 use table::metadata::{TableInfoBuilder, TableMetaBuilder};
257 use table::test_util::EmptyTable;
258
259 use super::*;
260 use crate::parser::QueryLanguageParser;
261 use crate::{QueryEngineFactory, QueryEngineRef};
262
263 fn create_test_batch() -> RecordBatch {
264 let schema = Schema::new(vec![Field::new("text", DataType::Utf8, true)]);
265
266 let text_array = StringArray::from(vec![
267 Some("hello world"),
268 Some("greeting"),
269 Some("hello there"),
270 None,
271 ]);
272
273 RecordBatch::try_new(Arc::new(schema), vec![Arc::new(text_array) as ArrayRef]).unwrap()
274 }
275
    /// Builds a query engine backed by an in-memory catalog holding a single
    /// empty table `test(text: string, timestamp: millisecond time index)`,
    /// used by the SQL-level optimizer test.
    fn create_test_engine() -> QueryEngineRef {
        let table_name = "test".to_string();
        let columns = vec![
            ColumnSchema::new(
                "text".to_string(),
                ConcreteDataType::string_datatype(),
                false,
            ),
            ColumnSchema::new(
                "timestamp".to_string(),
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
        ];

        let schema = Arc::new(datatypes::schema::Schema::new(columns));
        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![])
            .value_indices(vec![0])
            .next_column_id(2)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::default()
            .name(&table_name)
            .meta(table_meta)
            .build()
            .unwrap();
        let table = EmptyTable::from_table_info(&table_info);
        // Register under the default catalog/schema so plain `test` resolves
        // in SQL.
        let catalog_list = MemoryCatalogManager::with_default_setup();
        assert!(
            catalog_list
                .register_table_sync(RegisterTableRequest {
                    catalog: DEFAULT_CATALOG_NAME.to_string(),
                    schema: DEFAULT_SCHEMA_NAME.to_string(),
                    table_name,
                    table_id: 1024,
                    table,
                })
                .is_ok()
        );
        QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            Default::default(),
        )
        .query_engine()
    }
329
330 fn matches_term_udf() -> Arc<ScalarUDF> {
331 Arc::new(create_udf(Arc::new(MatchesTermFunction::default())))
332 }
333
334 #[test]
335 fn test_matches_term_optimization() {
336 let batch = create_test_batch();
337
338 let predicate = create_physical_expr(
340 &Expr::ScalarFunction(ScalarFunction::new_udf(
341 matches_term_udf(),
342 vec![Expr::Column(Column::from_name("text")), "hello".lit()],
343 )),
344 &DFSchema::try_from(batch.schema().clone()).unwrap(),
345 &Default::default(),
346 )
347 .unwrap();
348
349 let input = DataSourceExec::from_data_source(
350 MemorySourceConfig::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap(),
351 );
352 let filter = FilterExec::try_new(predicate, input).unwrap();
353
354 let optimizer = MatchesConstantTermOptimizer;
356 let optimized_plan = optimizer
357 .optimize(Arc::new(filter), &Default::default())
358 .unwrap();
359
360 let optimized_filter = optimized_plan
361 .as_any()
362 .downcast_ref::<FilterExec>()
363 .unwrap();
364 let predicate = optimized_filter.predicate();
365
366 assert!(
368 std::any::TypeId::of::<PreCompiledMatchesTermExpr>() == predicate.as_any().type_id()
369 );
370 }
371
372 #[test]
373 fn test_matches_term_no_optimization() {
374 let batch = create_test_batch();
375
376 let predicate = create_physical_expr(
378 &Expr::ScalarFunction(ScalarFunction::new_udf(
379 matches_term_udf(),
380 vec![
381 Expr::Column(Column::from_name("text")),
382 Expr::Column(Column::from_name("text")),
383 ],
384 )),
385 &DFSchema::try_from(batch.schema().clone()).unwrap(),
386 &Default::default(),
387 )
388 .unwrap();
389
390 let input = DataSourceExec::from_data_source(
391 MemorySourceConfig::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap(),
392 );
393 let filter = FilterExec::try_new(predicate, input).unwrap();
394
395 let optimizer = MatchesConstantTermOptimizer;
396 let optimized_plan = optimizer
397 .optimize(Arc::new(filter), &Default::default())
398 .unwrap();
399
400 let optimized_filter = optimized_plan
401 .as_any()
402 .downcast_ref::<FilterExec>()
403 .unwrap();
404 let predicate = optimized_filter.predicate();
405
406 assert!(std::any::TypeId::of::<ScalarFunctionExpr>() == predicate.as_any().type_id());
408 }
409
    /// End-to-end check: every constant-term `MATCHES_TERM` call in a complex
    /// SQL query (CTEs, UNION ALL, JOIN) must appear in the physical plan as
    /// `MatchesConstTerm(...)`, and no plain `matches_term` call may remain.
    #[tokio::test]
    async fn test_matches_term_optimization_from_sql() {
        let sql = "WITH base AS (
            SELECT text, timestamp FROM test
            WHERE MATCHES_TERM(text, 'hello wo_rld')
            AND timestamp > '2025-01-01 00:00:00'
        ),
        subquery1 AS (
            SELECT * FROM base
            WHERE MATCHES_TERM(text, 'world')
        ),
        subquery2 AS (
            SELECT * FROM test
            WHERE MATCHES_TERM(text, 'greeting')
            AND timestamp < '2025-01-02 00:00:00'
        ),
        union_result AS (
            SELECT * FROM subquery1
            UNION ALL
            SELECT * FROM subquery2
        ),
        joined_data AS (
            SELECT a.text, a.timestamp, b.text as other_text
            FROM union_result a
            JOIN test b ON a.timestamp = b.timestamp
            WHERE MATCHES_TERM(a.text, 'there')
        )
        SELECT text, other_text
        FROM joined_data
        WHERE MATCHES_TERM(text, '42')
        AND MATCHES_TERM(other_text, 'foo')";

        let query_ctx = QueryContext::arc();

        // Run the full pipeline: parse -> logical plan -> analyze ->
        // logical optimize -> physical plan (which applies the rule under
        // test as a physical optimizer pass).
        let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
        let engine = create_test_engine();
        let logical_plan = engine
            .planner()
            .plan(&stmt, query_ctx.clone())
            .await
            .unwrap();

        let engine_ctx = engine.engine_context(query_ctx);
        let state = engine_ctx.state();

        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
            .unwrap();

        let optimized_plan = state
            .optimizer()
            .optimize(analyzed_plan, state, |_, _| {})
            .unwrap();

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await
            .unwrap();

        // Each constant term must show up as a pre-compiled expression; the
        // probes list also checks the word-splitting ('_' is kept, so
        // 'hello wo_rld' splits into ["hello", "wo_rld"]).
        let plan_str = get_plan_string(&physical_plan).join("\n");
        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"foo\", probes: [\"foo\"]"));
        assert!(plan_str.contains(
            "MatchesConstTerm(text@0, term: \"hello wo_rld\", probes: [\"hello\", \"wo_rld\"]"
        ));
        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"world\", probes: [\"world\"]"));
        assert!(
            plan_str
                .contains("MatchesConstTerm(text@0, term: \"greeting\", probes: [\"greeting\"]")
        );
        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"there\", probes: [\"there\"]"));
        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"42\", probes: [\"42\"]"));
        // No un-rewritten scalar call may survive anywhere in the plan.
        assert!(!plan_str.contains("matches_term"))
    }
485}