1use std::fmt;
16use std::fmt::Formatter;
17use std::hash::{Hash, Hasher};
18use std::sync::Arc;
19
20use arrow::array::{AsArray, BooleanArray};
21use common_function::scalars::matches_term::MatchesTermFinder;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DfResult;
24use datafusion::physical_optimizer::PhysicalOptimizerRule;
25use datafusion::physical_plan::filter::FilterExec;
26use datafusion::physical_plan::ExecutionPlan;
27use datafusion_common::tree_node::{Transformed, TreeNode};
28use datafusion_common::ScalarValue;
29use datafusion_expr::ColumnarValue;
30use datafusion_physical_expr::expressions::Literal;
31use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
32
/// Physical expression equivalent to `matches_term(text, '<term>')` where the
/// term is a constant string.
///
/// The [`MatchesTermFinder`] for the term is built once, when the optimizer
/// rewrites the plan, and then reused for every evaluated batch.
#[derive(Debug)]
pub struct PreCompiledMatchesTermExpr {
    /// Expression producing the text to search in.
    text: Arc<dyn PhysicalExpr>,
    /// The constant term, exactly as it appeared in the query.
    term: String,
    /// Matcher compiled from `term`.
    finder: MatchesTermFinder,

    /// Maximal runs of alphanumeric / `_` characters extracted from `term`.
    /// NOTE(review): in the visible code this is only read by the `Display`
    /// impl (and cloned in `with_new_children`); confirm whether other code
    /// (e.g. index pruning) is meant to consume it.
    probes: Vec<String>,
}
51
52impl fmt::Display for PreCompiledMatchesTermExpr {
53 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54 write!(
55 f,
56 "MatchesConstTerm({}, term: \"{}\", probes: {:?})",
57 self.text, self.term, self.probes
58 )
59 }
60}
61
62impl Hash for PreCompiledMatchesTermExpr {
63 fn hash<H: Hasher>(&self, state: &mut H) {
64 self.text.hash(state);
65 self.term.hash(state);
66 }
67}
68
69impl PartialEq for PreCompiledMatchesTermExpr {
70 fn eq(&self, other: &Self) -> bool {
71 self.text.eq(&other.text) && self.term.eq(&other.term)
72 }
73}
74
// `PartialEq` compares only `text` and `term`. This is an equivalence relation
// as long as equality on `dyn PhysicalExpr` is itself reflexive (assumed here).
impl Eq for PreCompiledMatchesTermExpr {}
76
77impl PhysicalExpr for PreCompiledMatchesTermExpr {
78 fn as_any(&self) -> &dyn std::any::Any {
79 self
80 }
81
82 fn data_type(
83 &self,
84 _input_schema: &arrow_schema::Schema,
85 ) -> datafusion::error::Result<arrow_schema::DataType> {
86 Ok(arrow_schema::DataType::Boolean)
87 }
88
89 fn nullable(&self, input_schema: &arrow_schema::Schema) -> datafusion::error::Result<bool> {
90 self.text.nullable(input_schema)
91 }
92
93 fn evaluate(
94 &self,
95 batch: &common_recordbatch::DfRecordBatch,
96 ) -> datafusion::error::Result<ColumnarValue> {
97 let num_rows = batch.num_rows();
98
99 let text_value = self.text.evaluate(batch)?;
100 let array = text_value.into_array(num_rows)?;
101 let str_array = array.as_string::<i32>();
102
103 let mut result = BooleanArray::builder(num_rows);
104 for text in str_array {
105 match text {
106 Some(text) => {
107 result.append_value(self.finder.find(text));
108 }
109 None => {
110 result.append_null();
111 }
112 }
113 }
114
115 Ok(ColumnarValue::Array(Arc::new(result.finish())))
116 }
117
118 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
119 vec![&self.text]
120 }
121
122 fn with_new_children(
123 self: Arc<Self>,
124 children: Vec<Arc<dyn PhysicalExpr>>,
125 ) -> datafusion::error::Result<Arc<dyn PhysicalExpr>> {
126 Ok(Arc::new(PreCompiledMatchesTermExpr {
127 text: children[0].clone(),
128 term: self.term.clone(),
129 finder: self.finder.clone(),
130 probes: self.probes.clone(),
131 }))
132 }
133
134 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
135 write!(f, "{}", self)
136 }
137}
138
/// Physical optimizer rule that finds `matches_term(<expr>, '<constant>')`
/// calls inside `FilterExec` predicates and replaces them with
/// [`PreCompiledMatchesTermExpr`], whose term matcher is built once at
/// optimization time.
#[derive(Debug)]
pub struct MatchesConstantTermOptimizer;
160
161impl PhysicalOptimizerRule for MatchesConstantTermOptimizer {
162 fn optimize(
163 &self,
164 plan: Arc<dyn ExecutionPlan>,
165 _config: &ConfigOptions,
166 ) -> DfResult<Arc<dyn ExecutionPlan>> {
167 let res = plan
168 .transform_down(&|plan: Arc<dyn ExecutionPlan>| {
169 if let Some(filter) = plan.as_any().downcast_ref::<FilterExec>() {
170 let pred = filter.predicate().clone();
171 let new_pred = pred.transform_down(&|expr: Arc<dyn PhysicalExpr>| {
172 if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
173 if !func.name().eq_ignore_ascii_case("matches_term") {
174 return Ok(Transformed::no(expr));
175 }
176 let args = func.args();
177 if args.len() != 2 {
178 return Ok(Transformed::no(expr));
179 }
180
181 if let Some(lit) = args[1].as_any().downcast_ref::<Literal>() {
182 if let ScalarValue::Utf8(Some(term)) = lit.value() {
183 let finder = MatchesTermFinder::new(term);
184
185 let probes = term
187 .split(|c: char| !c.is_alphanumeric() && c != '_')
188 .filter(|s| !s.is_empty())
189 .map(|s| s.to_string())
190 .collect();
191
192 let expr = PreCompiledMatchesTermExpr {
193 text: args[0].clone(),
194 term: term.to_string(),
195 finder,
196 probes,
197 };
198
199 return Ok(Transformed::yes(Arc::new(expr)));
200 }
201 }
202 }
203
204 Ok(Transformed::no(expr))
205 })?;
206
207 if new_pred.transformed {
208 let exec = FilterExec::try_new(new_pred.data, filter.input().clone())?
209 .with_default_selectivity(filter.default_selectivity())?
210 .with_projection(filter.projection().cloned())?;
211 return Ok(Transformed::yes(Arc::new(exec) as _));
212 }
213 }
214
215 Ok(Transformed::no(plan))
216 })?
217 .data;
218
219 Ok(res)
220 }
221
222 fn name(&self) -> &str {
223 "MatchesConstantTerm"
224 }
225
226 fn schema_check(&self) -> bool {
227 false
228 }
229}
230
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use catalog::memory::MemoryCatalogManager;
    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
    use common_function::scalars::matches_term::MatchesTermFunction;
    use common_function::scalars::udf::create_udf;
    use common_function::state::FunctionState;
    use datafusion::datasource::memory::MemorySourceConfig;
    use datafusion::datasource::source::DataSourceExec;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::filter::FilterExec;
    use datafusion::physical_plan::get_plan_string;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{Expr, Literal, ScalarUDF};
    use datafusion_physical_expr::{create_physical_expr, ScalarFunctionExpr};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use session::context::QueryContext;
    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
    use table::test_util::EmptyTable;

    use super::*;
    use crate::parser::QueryLanguageParser;
    use crate::{QueryEngineFactory, QueryEngineRef};

    /// Single nullable Utf8 column `text` with a match, a non-match, another
    /// match and a null row (for the term "hello").
    fn create_test_batch() -> RecordBatch {
        let schema = Schema::new(vec![Field::new("text", DataType::Utf8, true)]);

        let text_array = StringArray::from(vec![
            Some("hello world"),
            Some("greeting"),
            Some("hello there"),
            None,
        ]);

        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(text_array) as ArrayRef]).unwrap()
    }

    fn create_test_engine() -> QueryEngineRef {
        let table_name = "test".to_string();
        let columns = vec![
            ColumnSchema::new(
                "text".to_string(),
                ConcreteDataType::string_datatype(),
                false,
            ),
            ColumnSchema::new(
                "timestamp".to_string(),
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
        ];

        let schema = Arc::new(datatypes::schema::Schema::new(columns));
        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![])
            .value_indices(vec![0])
            .next_column_id(2)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::default()
            .name(&table_name)
            .meta(table_meta)
            .build()
            .unwrap();
        let table = EmptyTable::from_table_info(&table_info);
        let catalog_list = MemoryCatalogManager::with_default_setup();
        assert!(catalog_list
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name,
                table_id: 1024,
                table,
            })
            .is_ok());
        QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            Default::default(),
        )
        .query_engine()
    }

    fn matches_term_udf() -> Arc<ScalarUDF> {
        Arc::new(create_udf(
            Arc::new(MatchesTermFunction),
            QueryContext::arc(),
            Arc::new(FunctionState::default()),
        ))
    }

    /// Plans `matches_term(<args>)` as a `FilterExec` over `batch`, runs
    /// [`MatchesConstantTermOptimizer`] on it and returns the resulting
    /// filter predicate.
    fn optimized_predicate(batch: &RecordBatch, args: Vec<Expr>) -> Arc<dyn PhysicalExpr> {
        let predicate = create_physical_expr(
            &Expr::ScalarFunction(ScalarFunction::new_udf(matches_term_udf(), args)),
            &DFSchema::try_from(batch.schema().clone()).unwrap(),
            &Default::default(),
        )
        .unwrap();

        let input = DataSourceExec::from_data_source(
            MemorySourceConfig::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap(),
        );
        let filter = FilterExec::try_new(predicate, input).unwrap();

        let optimized_plan = MatchesConstantTermOptimizer
            .optimize(Arc::new(filter), &Default::default())
            .unwrap();
        let optimized_filter = optimized_plan
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("the optimizer should keep the FilterExec at the root");
        optimized_filter.predicate().clone()
    }

    #[test]
    fn test_matches_term_optimization() {
        let batch = create_test_batch();
        let predicate = optimized_predicate(
            &batch,
            vec![Expr::Column(Column::from_name("text")), "hello".lit()],
        );

        // A constant term must be rewritten into the precompiled expression.
        assert!(predicate.as_any().is::<PreCompiledMatchesTermExpr>());
    }

    #[test]
    fn test_matches_term_evaluate() {
        let batch = create_test_batch();
        let predicate = optimized_predicate(
            &batch,
            vec![Expr::Column(Column::from_name("text")), "hello".lit()],
        );
        assert!(predicate.as_any().is::<PreCompiledMatchesTermExpr>());

        let result = predicate
            .evaluate(&batch)
            .unwrap()
            .into_array(batch.num_rows())
            .unwrap();

        // "hello" is a whole word in rows 0 and 2, absent in row 1, and the
        // null text in row 3 must stay null.
        let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true), None]);
        assert_eq!(result.as_boolean(), &expected);
    }

    #[test]
    fn test_matches_term_no_optimization() {
        let batch = create_test_batch();
        let predicate = optimized_predicate(
            &batch,
            vec![
                Expr::Column(Column::from_name("text")),
                Expr::Column(Column::from_name("text")),
            ],
        );

        // A non-literal term must leave the original function call in place.
        assert!(predicate.as_any().is::<ScalarFunctionExpr>());
    }

    #[tokio::test]
    async fn test_matches_term_optimization_from_sql() {
        let sql = "WITH base AS (
            SELECT text, timestamp FROM test
            WHERE MATCHES_TERM(text, 'hello wo_rld')
            AND timestamp > '2025-01-01 00:00:00'
        ),
        subquery1 AS (
            SELECT * FROM base
            WHERE MATCHES_TERM(text, 'world')
        ),
        subquery2 AS (
            SELECT * FROM test
            WHERE MATCHES_TERM(text, 'greeting')
            AND timestamp < '2025-01-02 00:00:00'
        ),
        union_result AS (
            SELECT * FROM subquery1
            UNION ALL
            SELECT * FROM subquery2
        ),
        joined_data AS (
            SELECT a.text, a.timestamp, b.text as other_text
            FROM union_result a
            JOIN test b ON a.timestamp = b.timestamp
            WHERE MATCHES_TERM(a.text, 'there')
        )
        SELECT text, other_text
        FROM joined_data
        WHERE MATCHES_TERM(text, '42')
        AND MATCHES_TERM(other_text, 'foo')";

        let query_ctx = QueryContext::arc();

        let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
        let engine = create_test_engine();
        let logical_plan = engine
            .planner()
            .plan(&stmt, query_ctx.clone())
            .await
            .unwrap();

        let engine_ctx = engine.engine_context(query_ctx);
        let state = engine_ctx.state();

        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
            .unwrap();

        let optimized_plan = state
            .optimizer()
            .optimize(analyzed_plan, state, |_, _| {})
            .unwrap();

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await
            .unwrap();

        let plan_str = get_plan_string(&physical_plan).join("\n");
        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"foo\", probes: [\"foo\"]"));
        assert!(plan_str.contains(
            "MatchesConstTerm(text@0, term: \"hello wo_rld\", probes: [\"hello\", \"wo_rld\"]"
        ));
        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"world\", probes: [\"world\"]"));
        assert!(plan_str
            .contains("MatchesConstTerm(text@0, term: \"greeting\", probes: [\"greeting\"]"));
        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"there\", probes: [\"there\"]"));
        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"42\", probes: [\"42\"]"));
        // Every matches_term call in the query had a constant term, so none
        // should survive the rewrite.
        assert!(!plan_str.contains("matches_term"));
    }
}