1use std::fmt;
16use std::fmt::Formatter;
17use std::hash::{Hash, Hasher};
18use std::sync::Arc;
19
20use arrow::array::{AsArray, BooleanArray};
21use common_function::scalars::matches_term::MatchesTermFinder;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DfResult;
24use datafusion::physical_optimizer::PhysicalOptimizerRule;
25use datafusion::physical_plan::ExecutionPlan;
26use datafusion::physical_plan::filter::FilterExec;
27use datafusion_common::ScalarValue;
28use datafusion_common::tree_node::{Transformed, TreeNode};
29use datafusion_expr::ColumnarValue;
30use datafusion_physical_expr::expressions::Literal;
31use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
32
/// A physical expression that matches the rows of a text expression against a
/// single constant term, using a [`MatchesTermFinder`] that is built once and
/// then reused for every evaluated row.
#[derive(Debug)]
pub struct PreCompiledMatchesTermExpr {
    /// Child expression producing the text values to search in.
    text: Arc<dyn PhysicalExpr>,
    /// The constant term to search for; part of this expression's identity
    /// (used by `Hash` and `PartialEq`).
    term: String,
    /// Finder constructed from `term`; applied to every non-null text value.
    finder: MatchesTermFinder,

    /// Non-empty pieces of `term` obtained by splitting on every character that
    /// is neither alphanumeric nor `_`. Within this file they are only rendered
    /// by `Display`.
    // NOTE(review): presumably consumed elsewhere (e.g. for index probing) — confirm.
    probes: Vec<String>,
}
51
52impl fmt::Display for PreCompiledMatchesTermExpr {
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 write!(
55 f,
56 "MatchesConstTerm({}, term: \"{}\", probes: {:?})",
57 self.text, self.term, self.probes
58 )
59 }
60}
61
62impl Hash for PreCompiledMatchesTermExpr {
63 fn hash<H: Hasher>(&self, state: &mut H) {
64 self.text.hash(state);
65 self.term.hash(state);
66 }
67}
68
69impl PartialEq for PreCompiledMatchesTermExpr {
70 fn eq(&self, other: &Self) -> bool {
71 self.text.eq(&other.text) && self.term.eq(&other.term)
72 }
73}
74
75impl Eq for PreCompiledMatchesTermExpr {}
76
77impl PhysicalExpr for PreCompiledMatchesTermExpr {
78 fn as_any(&self) -> &dyn std::any::Any {
79 self
80 }
81
82 fn data_type(
83 &self,
84 _input_schema: &arrow_schema::Schema,
85 ) -> datafusion::error::Result<arrow_schema::DataType> {
86 Ok(arrow_schema::DataType::Boolean)
87 }
88
89 fn nullable(&self, input_schema: &arrow_schema::Schema) -> datafusion::error::Result<bool> {
90 self.text.nullable(input_schema)
91 }
92
93 fn evaluate(
94 &self,
95 batch: &common_recordbatch::DfRecordBatch,
96 ) -> datafusion::error::Result<ColumnarValue> {
97 let num_rows = batch.num_rows();
98
99 let text_value = self.text.evaluate(batch)?;
100 let array = text_value.into_array(num_rows)?;
101 let str_array = array.as_string::<i32>();
102
103 let mut result = BooleanArray::builder(num_rows);
104 for text in str_array {
105 match text {
106 Some(text) => {
107 result.append_value(self.finder.find(text));
108 }
109 None => {
110 result.append_null();
111 }
112 }
113 }
114
115 Ok(ColumnarValue::Array(Arc::new(result.finish())))
116 }
117
118 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
119 vec![&self.text]
120 }
121
122 fn with_new_children(
123 self: Arc<Self>,
124 children: Vec<Arc<dyn PhysicalExpr>>,
125 ) -> datafusion::error::Result<Arc<dyn PhysicalExpr>> {
126 Ok(Arc::new(PreCompiledMatchesTermExpr {
127 text: children[0].clone(),
128 term: self.term.clone(),
129 finder: self.finder.clone(),
130 probes: self.probes.clone(),
131 }))
132 }
133
134 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
135 write!(f, "{}", self)
136 }
137}
138
/// Physical optimizer rule that looks at the predicate of every `FilterExec`
/// and replaces each two-argument `matches_term` call whose second argument is
/// a non-null `Utf8` literal with a [`PreCompiledMatchesTermExpr`].
#[derive(Debug)]
pub struct MatchesConstantTermOptimizer;
160
161impl PhysicalOptimizerRule for MatchesConstantTermOptimizer {
162 fn optimize(
163 &self,
164 plan: Arc<dyn ExecutionPlan>,
165 _config: &ConfigOptions,
166 ) -> DfResult<Arc<dyn ExecutionPlan>> {
167 let res = plan
168 .transform_down(&|plan: Arc<dyn ExecutionPlan>| {
169 if let Some(filter) = plan.as_any().downcast_ref::<FilterExec>() {
170 let pred = filter.predicate().clone();
171 let new_pred = pred.transform_down(&|expr: Arc<dyn PhysicalExpr>| {
172 if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
173 if !func.name().eq_ignore_ascii_case("matches_term") {
174 return Ok(Transformed::no(expr));
175 }
176 let args = func.args();
177 if args.len() != 2 {
178 return Ok(Transformed::no(expr));
179 }
180
181 if let Some(lit) = args[1].as_any().downcast_ref::<Literal>()
182 && let ScalarValue::Utf8(Some(term)) = lit.value()
183 {
184 let finder = MatchesTermFinder::new(term);
185
186 let probes = term
188 .split(|c: char| !c.is_alphanumeric() && c != '_')
189 .filter(|s| !s.is_empty())
190 .map(|s| s.to_string())
191 .collect();
192
193 let expr = PreCompiledMatchesTermExpr {
194 text: args[0].clone(),
195 term: term.to_string(),
196 finder,
197 probes,
198 };
199
200 return Ok(Transformed::yes(Arc::new(expr)));
201 }
202 }
203
204 Ok(Transformed::no(expr))
205 })?;
206
207 if new_pred.transformed {
208 let exec = FilterExec::try_new(new_pred.data, filter.input().clone())?
209 .with_default_selectivity(filter.default_selectivity())?
210 .with_projection(filter.projection().cloned())?;
211 return Ok(Transformed::yes(Arc::new(exec) as _));
212 }
213 }
214
215 Ok(Transformed::no(plan))
216 })?
217 .data;
218
219 Ok(res)
220 }
221
222 fn name(&self) -> &str {
223 "MatchesConstantTerm"
224 }
225
226 fn schema_check(&self) -> bool {
227 false
228 }
229}
230
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use catalog::RegisterTableRequest;
    use catalog::memory::MemoryCatalogManager;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
    use common_function::scalars::matches_term::MatchesTermFunction;
    use common_function::scalars::udf::create_udf;
    use common_function::state::FunctionState;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::filter::FilterExec;
    use datafusion::physical_plan::get_plan_string;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{Expr, Literal, ScalarUDF};
    use datafusion_physical_expr::{ScalarFunctionExpr, create_physical_expr};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use session::context::QueryContext;
    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
    use table::test_util::EmptyTable;

    use super::*;
    use crate::parser::QueryLanguageParser;
    use crate::{QueryEngineFactory, QueryEngineRef};

    /// A four-row batch with a single nullable `text` column (the last row is null).
    fn create_test_batch() -> RecordBatch {
        let rows = vec![
            Some("hello world"),
            Some("greeting"),
            Some("hello there"),
            None,
        ];
        let column: ArrayRef = Arc::new(StringArray::from(rows));
        let schema = Schema::new(vec![Field::new("text", DataType::Utf8, true)]);

        RecordBatch::try_new(Arc::new(schema), vec![column]).unwrap()
    }

    /// A query engine whose default catalog holds an empty table `test` with a
    /// `text` column and a `timestamp` time index.
    fn create_test_engine() -> QueryEngineRef {
        let table_name = "test".to_string();

        let text_column = ColumnSchema::new(
            "text".to_string(),
            ConcreteDataType::string_datatype(),
            false,
        );
        let timestamp_column = ColumnSchema::new(
            "timestamp".to_string(),
            ConcreteDataType::timestamp_millisecond_datatype(),
            false,
        )
        .with_time_index(true);
        let schema = Arc::new(datatypes::schema::Schema::new(vec![
            text_column,
            timestamp_column,
        ]));

        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![])
            .value_indices(vec![0])
            .next_column_id(2)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::default()
            .name(&table_name)
            .meta(table_meta)
            .build()
            .unwrap();
        let table = EmptyTable::from_table_info(&table_info);

        let catalog_list = MemoryCatalogManager::with_default_setup();
        let registration = catalog_list.register_table_sync(RegisterTableRequest {
            catalog: DEFAULT_CATALOG_NAME.to_string(),
            schema: DEFAULT_SCHEMA_NAME.to_string(),
            table_name,
            table_id: 1024,
            table,
        });
        assert!(registration.is_ok());

        QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            Default::default(),
        )
        .query_engine()
    }

    /// The `matches_term` function wrapped as a DataFusion UDF.
    fn matches_term_udf() -> Arc<ScalarUDF> {
        let udf = create_udf(
            Arc::new(MatchesTermFunction),
            QueryContext::arc(),
            Arc::new(FunctionState::default()),
        );
        Arc::new(udf)
    }

    /// Puts `matches_term(text, <term_arg>)` into a filter over the test batch,
    /// runs the optimizer on it and returns the resulting plan.
    fn optimize_filter_with(term_arg: Expr) -> Arc<dyn ExecutionPlan> {
        let batch = create_test_batch();

        let call = Expr::ScalarFunction(ScalarFunction::new_udf(
            matches_term_udf(),
            vec![Expr::Column(Column::from_name("text")), term_arg],
        ));
        let predicate = create_physical_expr(
            &call,
            &DFSchema::try_from(batch.schema().clone()).unwrap(),
            &Default::default(),
        )
        .unwrap();

        let input = DataSourceExec::from_data_source(
            MemorySourceConfig::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap(),
        );
        let filter = FilterExec::try_new(predicate, input).unwrap();

        MatchesConstantTermOptimizer
            .optimize(Arc::new(filter), &Default::default())
            .unwrap()
    }

    /// A literal term is replaced by the pre-compiled expression.
    #[test]
    fn test_matches_term_optimization() {
        let optimized_plan = optimize_filter_with("hello".lit());

        let optimized_filter = optimized_plan
            .as_any()
            .downcast_ref::<FilterExec>()
            .unwrap();
        let predicate = optimized_filter.predicate();

        assert!(predicate.as_any().is::<PreCompiledMatchesTermExpr>());
    }

    /// A non-literal term (a column) leaves the original function call in place.
    #[test]
    fn test_matches_term_no_optimization() {
        let optimized_plan = optimize_filter_with(Expr::Column(Column::from_name("text")));

        let optimized_filter = optimized_plan
            .as_any()
            .downcast_ref::<FilterExec>()
            .unwrap();
        let predicate = optimized_filter.predicate();

        assert!(predicate.as_any().is::<ScalarFunctionExpr>());
    }

    /// Plans a SQL query with several `MATCHES_TERM` calls through the full
    /// engine pipeline and checks that every call shows up pre-compiled in the
    /// physical plan.
    #[tokio::test]
    async fn test_matches_term_optimization_from_sql() {
        let sql = "WITH base AS (
 SELECT text, timestamp FROM test
 WHERE MATCHES_TERM(text, 'hello wo_rld')
 AND timestamp > '2025-01-01 00:00:00'
 ),
 subquery1 AS (
 SELECT * FROM base
 WHERE MATCHES_TERM(text, 'world')
 ),
 subquery2 AS (
 SELECT * FROM test
 WHERE MATCHES_TERM(text, 'greeting')
 AND timestamp < '2025-01-02 00:00:00'
 ),
 union_result AS (
 SELECT * FROM subquery1
 UNION ALL
 SELECT * FROM subquery2
 ),
 joined_data AS (
 SELECT a.text, a.timestamp, b.text as other_text
 FROM union_result a
 JOIN test b ON a.timestamp = b.timestamp
 WHERE MATCHES_TERM(a.text, 'there')
 )
 SELECT text, other_text
 FROM joined_data
 WHERE MATCHES_TERM(text, '42')
 AND MATCHES_TERM(other_text, 'foo')";

        let query_ctx = QueryContext::arc();

        let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
        let engine = create_test_engine();
        let logical_plan = engine
            .planner()
            .plan(&stmt, query_ctx.clone())
            .await
            .unwrap();

        let engine_ctx = engine.engine_context(query_ctx);
        let state = engine_ctx.state();

        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
            .unwrap();

        let optimized_plan = state
            .optimizer()
            .optimize(analyzed_plan, state, |_, _| {})
            .unwrap();

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await
            .unwrap();

        let plan_str = get_plan_string(&physical_plan).join("\n");

        // Every term from the query must appear in its pre-compiled form.
        let expected_fragments = [
            "MatchesConstTerm(text@0, term: \"foo\", probes: [\"foo\"]",
            "MatchesConstTerm(text@0, term: \"hello wo_rld\", probes: [\"hello\", \"wo_rld\"]",
            "MatchesConstTerm(text@0, term: \"world\", probes: [\"world\"]",
            "MatchesConstTerm(text@0, term: \"greeting\", probes: [\"greeting\"]",
            "MatchesConstTerm(text@0, term: \"there\", probes: [\"there\"]",
            "MatchesConstTerm(text@0, term: \"42\", probes: [\"42\"]",
        ];
        for fragment in expected_fragments {
            assert!(plan_str.contains(fragment));
        }

        // No un-optimized call may be left behind.
        assert!(!plan_str.contains("matches_term"))
    }
}