use std::fmt;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::array::{AsArray, BooleanArray};
use common_function::scalars::matches_term::MatchesTermFinder;
use datafusion::config::ConfigOptions;
use datafusion::error::Result as DfResult;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};

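/// A `matches_term` expression whose search term is a constant and has been
/// pre-compiled into a [`MatchesTermFinder`].
///
/// Produced by [`MatchesConstantTermOptimizer`]; evaluating it only needs to
/// scan the input text, since the term was compiled once at planning time.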
#[derive(Debug)]
pub struct PreCompiledMatchesTermExpr {
    /// The expression that yields the text to search in.
    text: Arc<dyn PhysicalExpr>,
    /// The constant term to search for.
    term: String,
    /// The matcher pre-compiled from `term`.
    finder: MatchesTermFinder,
}

impl fmt::Display for PreCompiledMatchesTermExpr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "MatchesConstTerm({}, \"{}\")", self.text, self.term)
    }
}

impl Hash for PreCompiledMatchesTermExpr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.text.hash(state);
        self.term.hash(state);
    }
}

impl PartialEq for PreCompiledMatchesTermExpr {
    fn eq(&self, other: &Self) -> bool {
        self.text.eq(&other.text) && self.term.eq(&other.term)
    }
}

impl Eq for PreCompiledMatchesTermExpr {}

impl PhysicalExpr for PreCompiledMatchesTermExpr {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn data_type(
        &self,
        _input_schema: &arrow_schema::Schema,
    ) -> datafusion::error::Result<arrow_schema::DataType> {
        Ok(arrow_schema::DataType::Boolean)
    }

    fn nullable(&self, input_schema: &arrow_schema::Schema) -> datafusion::error::Result<bool> {
        self.text.nullable(input_schema)
    }

    fn evaluate(
        &self,
        batch: &common_recordbatch::DfRecordBatch,
    ) -> datafusion::error::Result<ColumnarValue> {
        let num_rows = batch.num_rows();

        // Evaluate the child expression to obtain the text column.
        let text_value = self.text.evaluate(batch)?;
        let array = text_value.into_array(num_rows)?;
        let str_array = array.as_string::<i32>();

        // Run the pre-compiled finder on each row, preserving nulls.
        let mut result = BooleanArray::builder(num_rows);
        for text in str_array {
            match text {
                Some(text) => {
                    result.append_value(self.finder.find(text));
                }
                None => {
                    result.append_null();
                }
            }
        }

        Ok(ColumnarValue::Array(Arc::new(result.finish())))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.text]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion::error::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(PreCompiledMatchesTermExpr {
            text: children[0].clone(),
            term: self.term.clone(),
            finder: self.finder.clone(),
        }))
    }
}

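/// Physical optimizer rule that rewrites `matches_term(text, term)` calls whose
/// `term` argument is a string literal into [`PreCompiledMatchesTermExpr`].
///
/// The rule walks `FilterExec` predicates and, for every matching call, compiles
/// the constant term into a [`MatchesTermFinder`] once at planning time so the
/// compiled matcher can be reused for every row instead of being rebuilt during
/// execution. The rewritten expression shows up as `MatchesConstTerm(...)` in
/// physical plans.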
#[derive(Debug)]
pub struct MatchesConstantTermOptimizer;

impl PhysicalOptimizerRule for MatchesConstantTermOptimizer {
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> DfResult<Arc<dyn ExecutionPlan>> {
        let res = plan
            .transform_down(&|plan: Arc<dyn ExecutionPlan>| {
                if let Some(filter) = plan.as_any().downcast_ref::<FilterExec>() {
                    let pred = filter.predicate().clone();
                    let new_pred = pred.transform_down(&|expr: Arc<dyn PhysicalExpr>| {
                        if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
                            // Only rewrite two-argument `matches_term` calls.
                            if !func.name().eq_ignore_ascii_case("matches_term") {
                                return Ok(Transformed::no(expr));
                            }
                            let args = func.args();
                            if args.len() != 2 {
                                return Ok(Transformed::no(expr));
                            }

                            // The term must be a non-null UTF-8 literal so it can be pre-compiled.
                            if let Some(lit) = args[1].as_any().downcast_ref::<Literal>() {
                                if let ScalarValue::Utf8(Some(term)) = lit.value() {
                                    let finder = MatchesTermFinder::new(term);
                                    let expr = PreCompiledMatchesTermExpr {
                                        text: args[0].clone(),
                                        term: term.to_string(),
                                        finder,
                                    };

                                    return Ok(Transformed::yes(Arc::new(expr)));
                                }
                            }
                        }

                        Ok(Transformed::no(expr))
                    })?;

                    if new_pred.transformed {
                        // Rebuild the filter with the rewritten predicate, keeping its
                        // selectivity and projection settings.
                        let exec = FilterExec::try_new(new_pred.data, filter.input().clone())?
                            .with_default_selectivity(filter.default_selectivity())?
                            .with_projection(filter.projection().cloned())?;
                        return Ok(Transformed::yes(Arc::new(exec) as _));
                    }
                }

                Ok(Transformed::no(plan))
            })?
            .data;

        Ok(res)
    }

    fn name(&self) -> &str {
        "MatchesConstantTerm"
    }

    fn schema_check(&self) -> bool {
        // The rewrite only swaps boolean predicate expressions, so the plan schema
        // is unchanged and no schema check is needed.
        false
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use catalog::memory::MemoryCatalogManager;
    use catalog::RegisterTableRequest;
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
    use common_function::scalars::matches_term::MatchesTermFunction;
    use common_function::scalars::udf::create_udf;
    use common_function::state::FunctionState;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::filter::FilterExec;
    use datafusion::physical_plan::get_plan_string;
    use datafusion::physical_plan::memory::MemoryExec;
    use datafusion_common::{Column, DFSchema, ScalarValue};
    use datafusion_expr::expr::ScalarFunction;
    use datafusion_expr::{Expr, ScalarUDF};
    use datafusion_physical_expr::{create_physical_expr, ScalarFunctionExpr};
    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;
    use session::context::QueryContext;
    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
    use table::test_util::EmptyTable;

    use super::*;
    use crate::parser::QueryLanguageParser;
    use crate::{QueryEngineFactory, QueryEngineRef};

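    /// Builds a record batch with a single nullable `text` column, including a
    /// NULL row, as input for the optimizer tests.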
    fn create_test_batch() -> RecordBatch {
        let schema = Schema::new(vec![Field::new("text", DataType::Utf8, true)]);

        let text_array = StringArray::from(vec![
            Some("hello world"),
            Some("greeting"),
            Some("hello there"),
            None,
        ]);

        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(text_array) as ArrayRef]).unwrap()
    }

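    /// Builds a query engine over an in-memory catalog containing an empty `test`
    /// table with a `text` column and a `timestamp` time index.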
    fn create_test_engine() -> QueryEngineRef {
        let table_name = "test".to_string();
        let columns = vec![
            ColumnSchema::new(
                "text".to_string(),
                ConcreteDataType::string_datatype(),
                false,
            ),
            ColumnSchema::new(
                "timestamp".to_string(),
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
        ];

        let schema = Arc::new(datatypes::schema::Schema::new(columns));
        let table_meta = TableMetaBuilder::empty()
            .schema(schema)
            .primary_key_indices(vec![])
            .value_indices(vec![0])
            .next_column_id(2)
            .build()
            .unwrap();
        let table_info = TableInfoBuilder::default()
            .name(&table_name)
            .meta(table_meta)
            .build()
            .unwrap();
        let table = EmptyTable::from_table_info(&table_info);
        let catalog_list = MemoryCatalogManager::with_default_setup();
        assert!(catalog_list
            .register_table_sync(RegisterTableRequest {
                catalog: DEFAULT_CATALOG_NAME.to_string(),
                schema: DEFAULT_SCHEMA_NAME.to_string(),
                table_name,
                table_id: 1024,
                table,
            })
            .is_ok());
        QueryEngineFactory::new(
            catalog_list,
            None,
            None,
            None,
            None,
            false,
            Default::default(),
        )
        .query_engine()
    }

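    /// Wraps [`MatchesTermFunction`] in a DataFusion scalar UDF so the tests can
    /// build `matches_term` expressions.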
    fn matches_term_udf() -> Arc<ScalarUDF> {
        Arc::new(create_udf(
            Arc::new(MatchesTermFunction),
            QueryContext::arc(),
            Arc::new(FunctionState::default()),
        ))
    }

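    // A `matches_term` call with a literal term should be rewritten into the
    // pre-compiled expression.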
    #[test]
    fn test_matches_term_optimization() {
        let batch = create_test_batch();

        // Predicate: matches_term(text, 'hello') with a constant term argument.
        let predicate = create_physical_expr(
            &Expr::ScalarFunction(ScalarFunction::new_udf(
                matches_term_udf(),
                vec![
                    Expr::Column(Column::from_name("text")),
                    Expr::Literal(ScalarValue::Utf8(Some("hello".to_string()))),
                ],
            )),
            &DFSchema::try_from(batch.schema().clone()).unwrap(),
            &Default::default(),
        )
        .unwrap();

        let input =
            Arc::new(MemoryExec::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap());
        let filter = FilterExec::try_new(predicate, input).unwrap();

        let optimizer = MatchesConstantTermOptimizer;
        let optimized_plan = optimizer
            .optimize(Arc::new(filter), &Default::default())
            .unwrap();

        let optimized_filter = optimized_plan
            .as_any()
            .downcast_ref::<FilterExec>()
            .unwrap();
        let predicate = optimized_filter.predicate();

        assert!(
            std::any::TypeId::of::<PreCompiledMatchesTermExpr>() == predicate.as_any().type_id()
        );
    }

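    // When the term argument is not a literal, the predicate must stay a plain
    // `ScalarFunctionExpr`.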
    #[test]
    fn test_matches_term_no_optimization() {
        let batch = create_test_batch();

        // Predicate: matches_term(text, text), where the term is a column rather
        // than a literal.
        let predicate = create_physical_expr(
            &Expr::ScalarFunction(ScalarFunction::new_udf(
                matches_term_udf(),
                vec![
                    Expr::Column(Column::from_name("text")),
                    Expr::Column(Column::from_name("text")),
                ],
            )),
            &DFSchema::try_from(batch.schema().clone()).unwrap(),
            &Default::default(),
        )
        .unwrap();

        let input =
            Arc::new(MemoryExec::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap());
        let filter = FilterExec::try_new(predicate, input).unwrap();

        let optimizer = MatchesConstantTermOptimizer;
        let optimized_plan = optimizer
            .optimize(Arc::new(filter), &Default::default())
            .unwrap();

        let optimized_filter = optimized_plan
            .as_any()
            .downcast_ref::<FilterExec>()
            .unwrap();
        let predicate = optimized_filter.predicate();

        assert!(std::any::TypeId::of::<ScalarFunctionExpr>() == predicate.as_any().type_id());
    }

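    // End-to-end check: every constant-term MATCHES_TERM in the SQL should be
    // rewritten in the physical plan, across CTEs, unions, and joins.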
    #[tokio::test]
    async fn test_matches_term_optimization_from_sql() {
        let sql = "WITH base AS (
            SELECT text, timestamp FROM test
            WHERE MATCHES_TERM(text, 'hello')
            AND timestamp > '2025-01-01 00:00:00'
        ),
        subquery1 AS (
            SELECT * FROM base
            WHERE MATCHES_TERM(text, 'world')
        ),
        subquery2 AS (
            SELECT * FROM test
            WHERE MATCHES_TERM(text, 'greeting')
            AND timestamp < '2025-01-02 00:00:00'
        ),
        union_result AS (
            SELECT * FROM subquery1
            UNION ALL
            SELECT * FROM subquery2
        ),
        joined_data AS (
            SELECT a.text, a.timestamp, b.text as other_text
            FROM union_result a
            JOIN test b ON a.timestamp = b.timestamp
            WHERE MATCHES_TERM(a.text, 'there')
        )
        SELECT text, other_text
        FROM joined_data
        WHERE MATCHES_TERM(text, '42')
        AND MATCHES_TERM(other_text, 'foo')";

        let query_ctx = QueryContext::arc();

        let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
        let engine = create_test_engine();
        let logical_plan = engine
            .planner()
            .plan(&stmt, query_ctx.clone())
            .await
            .unwrap();

        let engine_ctx = engine.engine_context(query_ctx);
        let state = engine_ctx.state();

        let analyzed_plan = state
            .analyzer()
            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
            .unwrap();

        let optimized_plan = state
            .optimizer()
            .optimize(analyzed_plan, state, |_, _| {})
            .unwrap();

        let physical_plan = state
            .query_planner()
            .create_physical_plan(&optimized_plan, state)
            .await
            .unwrap();

        // All terms are constants, so only the pre-compiled expression should remain.
        let plan_str = get_plan_string(&physical_plan).join("\n");
        assert!(plan_str.contains("MatchesConstTerm"));
        assert!(!plan_str.contains("matches_term"));
    }
}