query/optimizer/
constant_term.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::fmt::Formatter;
17use std::hash::{Hash, Hasher};
18use std::sync::Arc;
19
20use arrow::array::{AsArray, BooleanArray};
21use common_function::scalars::matches_term::MatchesTermFinder;
22use datafusion::config::ConfigOptions;
23use datafusion::error::Result as DfResult;
24use datafusion::physical_optimizer::PhysicalOptimizerRule;
25use datafusion::physical_plan::filter::FilterExec;
26use datafusion::physical_plan::ExecutionPlan;
27use datafusion_common::tree_node::{Transformed, TreeNode};
28use datafusion_common::ScalarValue;
29use datafusion_expr::ColumnarValue;
30use datafusion_physical_expr::expressions::Literal;
31use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
32
/// A physical expression that uses a pre-compiled term finder for the `matches_term` function.
///
/// This expression optimizes the `matches_term` function by pre-compiling the term
/// when the term is a constant value. This avoids recompiling the term for each row
/// during execution.
///
/// Equality and hashing only consider `text` and `term`; `finder` and `probes` are
/// derived from `term` when the expression is built by the optimizer.
#[derive(Debug)]
pub struct PreCompiledMatchesTermExpr {
    /// The text column expression to search in.
    text: Arc<dyn PhysicalExpr>,
    /// The constant term to search for.
    term: String,
    /// The pre-compiled term finder, built once from `term`.
    finder: MatchesTermFinder,

    /// Not used during evaluation; only shown in the plan display to illustrate roughly
    /// how an index would tokenize the term.
    /// Not precise because the column options are unknown here, but good enough for
    /// debugging purposes in most cases.
    probes: Vec<String>,
}
51
52impl fmt::Display for PreCompiledMatchesTermExpr {
53    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54        write!(
55            f,
56            "MatchesConstTerm({}, term: \"{}\", probes: {:?})",
57            self.text, self.term, self.probes
58        )
59    }
60}
61
62impl Hash for PreCompiledMatchesTermExpr {
63    fn hash<H: Hasher>(&self, state: &mut H) {
64        self.text.hash(state);
65        self.term.hash(state);
66    }
67}
68
69impl PartialEq for PreCompiledMatchesTermExpr {
70    fn eq(&self, other: &Self) -> bool {
71        self.text.eq(&other.text) && self.term.eq(&other.term)
72    }
73}
74
75impl Eq for PreCompiledMatchesTermExpr {}
76
77impl PhysicalExpr for PreCompiledMatchesTermExpr {
78    fn as_any(&self) -> &dyn std::any::Any {
79        self
80    }
81
82    fn data_type(
83        &self,
84        _input_schema: &arrow_schema::Schema,
85    ) -> datafusion::error::Result<arrow_schema::DataType> {
86        Ok(arrow_schema::DataType::Boolean)
87    }
88
89    fn nullable(&self, input_schema: &arrow_schema::Schema) -> datafusion::error::Result<bool> {
90        self.text.nullable(input_schema)
91    }
92
93    fn evaluate(
94        &self,
95        batch: &common_recordbatch::DfRecordBatch,
96    ) -> datafusion::error::Result<ColumnarValue> {
97        let num_rows = batch.num_rows();
98
99        let text_value = self.text.evaluate(batch)?;
100        let array = text_value.into_array(num_rows)?;
101        let str_array = array.as_string::<i32>();
102
103        let mut result = BooleanArray::builder(num_rows);
104        for text in str_array {
105            match text {
106                Some(text) => {
107                    result.append_value(self.finder.find(text));
108                }
109                None => {
110                    result.append_null();
111                }
112            }
113        }
114
115        Ok(ColumnarValue::Array(Arc::new(result.finish())))
116    }
117
118    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
119        vec![&self.text]
120    }
121
122    fn with_new_children(
123        self: Arc<Self>,
124        children: Vec<Arc<dyn PhysicalExpr>>,
125    ) -> datafusion::error::Result<Arc<dyn PhysicalExpr>> {
126        Ok(Arc::new(PreCompiledMatchesTermExpr {
127            text: children[0].clone(),
128            term: self.term.clone(),
129            finder: self.finder.clone(),
130            probes: self.probes.clone(),
131        }))
132    }
133
134    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
135        write!(f, "{}", self)
136    }
137}
138
/// Optimizer rule that pre-compiles the constant term in `matches_term` function calls.
///
/// This optimizer looks for `matches_term` function calls where the second argument
/// (the term to match) is a constant value. When found, it replaces the function
/// call with a specialized `PreCompiledMatchesTermExpr` that uses a pre-compiled
/// term finder.
///
/// Only predicates of `FilterExec` nodes are rewritten, and only when the term is a
/// non-null `Utf8` literal; every other expression is left untouched.
///
/// Example:
/// ```sql
/// -- Before optimization:
/// matches_term(text_column, 'constant_term')
///
/// -- After optimization (as displayed in the physical plan):
/// MatchesConstTerm(text_column, term: "constant_term", probes: ["constant_term"])
/// ```
///
/// This optimization improves performance by:
/// 1. Pre-compiling the term once instead of for each row
/// 2. Using a specialized expression that avoids function call overhead
#[derive(Debug)]
pub struct MatchesConstantTermOptimizer;
160
161impl PhysicalOptimizerRule for MatchesConstantTermOptimizer {
162    fn optimize(
163        &self,
164        plan: Arc<dyn ExecutionPlan>,
165        _config: &ConfigOptions,
166    ) -> DfResult<Arc<dyn ExecutionPlan>> {
167        let res = plan
168            .transform_down(&|plan: Arc<dyn ExecutionPlan>| {
169                if let Some(filter) = plan.as_any().downcast_ref::<FilterExec>() {
170                    let pred = filter.predicate().clone();
171                    let new_pred = pred.transform_down(&|expr: Arc<dyn PhysicalExpr>| {
172                        if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
173                            if !func.name().eq_ignore_ascii_case("matches_term") {
174                                return Ok(Transformed::no(expr));
175                            }
176                            let args = func.args();
177                            if args.len() != 2 {
178                                return Ok(Transformed::no(expr));
179                            }
180
181                            if let Some(lit) = args[1].as_any().downcast_ref::<Literal>() {
182                                if let ScalarValue::Utf8(Some(term)) = lit.value() {
183                                    let finder = MatchesTermFinder::new(term);
184
185                                    // For debugging purpose. Not really precise but enough for most cases.
186                                    let probes = term
187                                        .split(|c: char| !c.is_alphanumeric() && c != '_')
188                                        .filter(|s| !s.is_empty())
189                                        .map(|s| s.to_string())
190                                        .collect();
191
192                                    let expr = PreCompiledMatchesTermExpr {
193                                        text: args[0].clone(),
194                                        term: term.to_string(),
195                                        finder,
196                                        probes,
197                                    };
198
199                                    return Ok(Transformed::yes(Arc::new(expr)));
200                                }
201                            }
202                        }
203
204                        Ok(Transformed::no(expr))
205                    })?;
206
207                    if new_pred.transformed {
208                        let exec = FilterExec::try_new(new_pred.data, filter.input().clone())?
209                            .with_default_selectivity(filter.default_selectivity())?
210                            .with_projection(filter.projection().cloned())?;
211                        return Ok(Transformed::yes(Arc::new(exec) as _));
212                    }
213                }
214
215                Ok(Transformed::no(plan))
216            })?
217            .data;
218
219        Ok(res)
220    }
221
222    fn name(&self) -> &str {
223        "MatchesConstantTerm"
224    }
225
226    fn schema_check(&self) -> bool {
227        false
228    }
229}
230
231#[cfg(test)]
232mod tests {
233    use std::sync::Arc;
234
235    use arrow::array::{ArrayRef, StringArray};
236    use arrow::datatypes::{DataType, Field, Schema};
237    use arrow::record_batch::RecordBatch;
238    use catalog::memory::MemoryCatalogManager;
239    use catalog::RegisterTableRequest;
240    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
241    use common_function::scalars::matches_term::MatchesTermFunction;
242    use common_function::scalars::udf::create_udf;
243    use common_function::state::FunctionState;
244    use datafusion::datasource::memory::MemorySourceConfig;
245    use datafusion::datasource::source::DataSourceExec;
246    use datafusion::physical_optimizer::PhysicalOptimizerRule;
247    use datafusion::physical_plan::filter::FilterExec;
248    use datafusion::physical_plan::get_plan_string;
249    use datafusion_common::{Column, DFSchema};
250    use datafusion_expr::expr::ScalarFunction;
251    use datafusion_expr::{Expr, Literal, ScalarUDF};
252    use datafusion_physical_expr::{create_physical_expr, ScalarFunctionExpr};
253    use datatypes::prelude::ConcreteDataType;
254    use datatypes::schema::ColumnSchema;
255    use session::context::QueryContext;
256    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
257    use table::test_util::EmptyTable;
258
259    use super::*;
260    use crate::parser::QueryLanguageParser;
261    use crate::{QueryEngineFactory, QueryEngineRef};
262
263    fn create_test_batch() -> RecordBatch {
264        let schema = Schema::new(vec![Field::new("text", DataType::Utf8, true)]);
265
266        let text_array = StringArray::from(vec![
267            Some("hello world"),
268            Some("greeting"),
269            Some("hello there"),
270            None,
271        ]);
272
273        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(text_array) as ArrayRef]).unwrap()
274    }
275
276    fn create_test_engine() -> QueryEngineRef {
277        let table_name = "test".to_string();
278        let columns = vec![
279            ColumnSchema::new(
280                "text".to_string(),
281                ConcreteDataType::string_datatype(),
282                false,
283            ),
284            ColumnSchema::new(
285                "timestamp".to_string(),
286                ConcreteDataType::timestamp_millisecond_datatype(),
287                false,
288            )
289            .with_time_index(true),
290        ];
291
292        let schema = Arc::new(datatypes::schema::Schema::new(columns));
293        let table_meta = TableMetaBuilder::empty()
294            .schema(schema)
295            .primary_key_indices(vec![])
296            .value_indices(vec![0])
297            .next_column_id(2)
298            .build()
299            .unwrap();
300        let table_info = TableInfoBuilder::default()
301            .name(&table_name)
302            .meta(table_meta)
303            .build()
304            .unwrap();
305        let table = EmptyTable::from_table_info(&table_info);
306        let catalog_list = MemoryCatalogManager::with_default_setup();
307        assert!(catalog_list
308            .register_table_sync(RegisterTableRequest {
309                catalog: DEFAULT_CATALOG_NAME.to_string(),
310                schema: DEFAULT_SCHEMA_NAME.to_string(),
311                table_name,
312                table_id: 1024,
313                table,
314            })
315            .is_ok());
316        QueryEngineFactory::new(
317            catalog_list,
318            None,
319            None,
320            None,
321            None,
322            false,
323            Default::default(),
324        )
325        .query_engine()
326    }
327
328    fn matches_term_udf() -> Arc<ScalarUDF> {
329        Arc::new(create_udf(
330            Arc::new(MatchesTermFunction),
331            QueryContext::arc(),
332            Arc::new(FunctionState::default()),
333        ))
334    }
335
336    #[test]
337    fn test_matches_term_optimization() {
338        let batch = create_test_batch();
339
340        // Create a predicate with a constant pattern
341        let predicate = create_physical_expr(
342            &Expr::ScalarFunction(ScalarFunction::new_udf(
343                matches_term_udf(),
344                vec![Expr::Column(Column::from_name("text")), "hello".lit()],
345            )),
346            &DFSchema::try_from(batch.schema().clone()).unwrap(),
347            &Default::default(),
348        )
349        .unwrap();
350
351        let input = DataSourceExec::from_data_source(
352            MemorySourceConfig::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap(),
353        );
354        let filter = FilterExec::try_new(predicate, input).unwrap();
355
356        // Apply the optimizer
357        let optimizer = MatchesConstantTermOptimizer;
358        let optimized_plan = optimizer
359            .optimize(Arc::new(filter), &Default::default())
360            .unwrap();
361
362        let optimized_filter = optimized_plan
363            .as_any()
364            .downcast_ref::<FilterExec>()
365            .unwrap();
366        let predicate = optimized_filter.predicate();
367
368        // The predicate should be a PreCompiledMatchesTermExpr
369        assert!(
370            std::any::TypeId::of::<PreCompiledMatchesTermExpr>() == predicate.as_any().type_id()
371        );
372    }
373
374    #[test]
375    fn test_matches_term_no_optimization() {
376        let batch = create_test_batch();
377
378        // Create a predicate with a non-constant pattern
379        let predicate = create_physical_expr(
380            &Expr::ScalarFunction(ScalarFunction::new_udf(
381                matches_term_udf(),
382                vec![
383                    Expr::Column(Column::from_name("text")),
384                    Expr::Column(Column::from_name("text")),
385                ],
386            )),
387            &DFSchema::try_from(batch.schema().clone()).unwrap(),
388            &Default::default(),
389        )
390        .unwrap();
391
392        let input = DataSourceExec::from_data_source(
393            MemorySourceConfig::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap(),
394        );
395        let filter = FilterExec::try_new(predicate, input).unwrap();
396
397        let optimizer = MatchesConstantTermOptimizer;
398        let optimized_plan = optimizer
399            .optimize(Arc::new(filter), &Default::default())
400            .unwrap();
401
402        let optimized_filter = optimized_plan
403            .as_any()
404            .downcast_ref::<FilterExec>()
405            .unwrap();
406        let predicate = optimized_filter.predicate();
407
408        // The predicate should still be a ScalarFunctionExpr
409        assert!(std::any::TypeId::of::<ScalarFunctionExpr>() == predicate.as_any().type_id());
410    }
411
412    #[tokio::test]
413    async fn test_matches_term_optimization_from_sql() {
414        let sql = "WITH base AS (
415        SELECT text, timestamp FROM test 
416        WHERE MATCHES_TERM(text, 'hello wo_rld') 
417        AND timestamp > '2025-01-01 00:00:00'
418    ),
419    subquery1 AS (
420        SELECT * FROM base 
421        WHERE MATCHES_TERM(text, 'world')
422    ),
423    subquery2 AS (
424        SELECT * FROM test 
425        WHERE MATCHES_TERM(text, 'greeting') 
426        AND timestamp < '2025-01-02 00:00:00'
427    ),
428    union_result AS (
429        SELECT * FROM subquery1 
430        UNION ALL 
431        SELECT * FROM subquery2
432    ),
433    joined_data AS (
434        SELECT a.text, a.timestamp, b.text as other_text 
435        FROM union_result a 
436        JOIN test b ON a.timestamp = b.timestamp 
437        WHERE MATCHES_TERM(a.text, 'there')
438    )
439    SELECT text, other_text 
440    FROM joined_data 
441    WHERE MATCHES_TERM(text, '42') 
442    AND MATCHES_TERM(other_text, 'foo')";
443
444        let query_ctx = QueryContext::arc();
445
446        let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
447        let engine = create_test_engine();
448        let logical_plan = engine
449            .planner()
450            .plan(&stmt, query_ctx.clone())
451            .await
452            .unwrap();
453
454        let engine_ctx = engine.engine_context(query_ctx);
455        let state = engine_ctx.state();
456
457        let analyzed_plan = state
458            .analyzer()
459            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
460            .unwrap();
461
462        let optimized_plan = state
463            .optimizer()
464            .optimize(analyzed_plan, state, |_, _| {})
465            .unwrap();
466
467        let physical_plan = state
468            .query_planner()
469            .create_physical_plan(&optimized_plan, state)
470            .await
471            .unwrap();
472
473        let plan_str = get_plan_string(&physical_plan).join("\n");
474        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"foo\", probes: [\"foo\"]"));
475        assert!(plan_str.contains(
476            "MatchesConstTerm(text@0, term: \"hello wo_rld\", probes: [\"hello\", \"wo_rld\"]"
477        ));
478        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"world\", probes: [\"world\"]"));
479        assert!(plan_str
480            .contains("MatchesConstTerm(text@0, term: \"greeting\", probes: [\"greeting\"]"));
481        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"there\", probes: [\"there\"]"));
482        assert!(plan_str.contains("MatchesConstTerm(text@0, term: \"42\", probes: [\"42\"]"));
483        assert!(!plan_str.contains("matches_term"))
484    }
485}