query/optimizer/
constant_term.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::hash::{Hash, Hasher};
17use std::sync::Arc;
18
19use arrow::array::{AsArray, BooleanArray};
20use common_function::scalars::matches_term::MatchesTermFinder;
21use datafusion::config::ConfigOptions;
22use datafusion::error::Result as DfResult;
23use datafusion::physical_optimizer::PhysicalOptimizerRule;
24use datafusion::physical_plan::filter::FilterExec;
25use datafusion::physical_plan::ExecutionPlan;
26use datafusion_common::tree_node::{Transformed, TreeNode};
27use datafusion_common::ScalarValue;
28use datafusion_expr::ColumnarValue;
29use datafusion_physical_expr::expressions::Literal;
30use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr};
31
/// A physical expression that evaluates `matches_term` against a constant term
/// using a pre-built [`MatchesTermFinder`].
///
/// The finder is stored next to the term so that evaluation can reuse it for
/// every row instead of building a new one from the term each time.
///
/// The expression has a single child (`text`) and always produces a boolean
/// result.
///
/// NOTE(review): `Hash` and `PartialEq` only look at `text` and `term`;
/// presumably `finder` is fully determined by `term` — confirm against
/// `MatchesTermFinder`.
#[derive(Debug)]
pub struct PreCompiledMatchesTermExpr {
    /// The child expression producing the text to search in.
    text: Arc<dyn PhysicalExpr>,
    /// The constant term to search for (kept for display, hashing and equality).
    term: String,
    /// The finder built from `term`, reused across evaluations.
    finder: MatchesTermFinder,
}
46
47impl fmt::Display for PreCompiledMatchesTermExpr {
48    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
49        write!(f, "MatchesConstTerm({}, \"{}\")", self.text, self.term)
50    }
51}
52
53impl Hash for PreCompiledMatchesTermExpr {
54    fn hash<H: Hasher>(&self, state: &mut H) {
55        self.text.hash(state);
56        self.term.hash(state);
57    }
58}
59
60impl PartialEq for PreCompiledMatchesTermExpr {
61    fn eq(&self, other: &Self) -> bool {
62        self.text.eq(&other.text) && self.term.eq(&other.term)
63    }
64}
65
// Marker impl: declares this type's `PartialEq` comparison to be a full
// equivalence relation.
impl Eq for PreCompiledMatchesTermExpr {}
67
68impl PhysicalExpr for PreCompiledMatchesTermExpr {
69    fn as_any(&self) -> &dyn std::any::Any {
70        self
71    }
72
73    fn data_type(
74        &self,
75        _input_schema: &arrow_schema::Schema,
76    ) -> datafusion::error::Result<arrow_schema::DataType> {
77        Ok(arrow_schema::DataType::Boolean)
78    }
79
80    fn nullable(&self, input_schema: &arrow_schema::Schema) -> datafusion::error::Result<bool> {
81        self.text.nullable(input_schema)
82    }
83
84    fn evaluate(
85        &self,
86        batch: &common_recordbatch::DfRecordBatch,
87    ) -> datafusion::error::Result<ColumnarValue> {
88        let num_rows = batch.num_rows();
89
90        let text_value = self.text.evaluate(batch)?;
91        let array = text_value.into_array(num_rows)?;
92        let str_array = array.as_string::<i32>();
93
94        let mut result = BooleanArray::builder(num_rows);
95        for text in str_array {
96            match text {
97                Some(text) => {
98                    result.append_value(self.finder.find(text));
99                }
100                None => {
101                    result.append_null();
102                }
103            }
104        }
105
106        Ok(ColumnarValue::Array(Arc::new(result.finish())))
107    }
108
109    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
110        vec![&self.text]
111    }
112
113    fn with_new_children(
114        self: Arc<Self>,
115        children: Vec<Arc<dyn PhysicalExpr>>,
116    ) -> datafusion::error::Result<Arc<dyn PhysicalExpr>> {
117        Ok(Arc::new(PreCompiledMatchesTermExpr {
118            text: children[0].clone(),
119            term: self.term.clone(),
120            finder: self.finder.clone(),
121        }))
122    }
123}
124
/// Optimizer rule that pre-compiles constant term in `matches_term` function.
///
/// This optimizer looks for `matches_term` function calls where the second argument
/// (the term to match) is a constant value. When found, it replaces the function
/// call with a specialized `PreCompiledMatchesTermExpr` that uses a pre-compiled
/// term finder.
///
/// Scope of the rewrite, as implemented by the rule:
/// - only predicates of `FilterExec` nodes are inspected;
/// - the function name is matched case-insensitively and must take exactly two
///   arguments;
/// - the term must be a non-null `Utf8` literal; any other second argument
///   leaves the call untouched.
///
/// Example:
/// ```sql
/// -- Before optimization:
/// matches_term(text_column, 'constant_term')
///
/// -- After optimization:
/// PreCompiledMatchesTermExpr(text_column, 'constant_term')
/// ```
///
/// The intent is to build the term finder once, while the plan is being
/// optimized, rather than during execution.
#[derive(Debug)]
pub struct MatchesConstantTermOptimizer;
146
147impl PhysicalOptimizerRule for MatchesConstantTermOptimizer {
148    fn optimize(
149        &self,
150        plan: Arc<dyn ExecutionPlan>,
151        _config: &ConfigOptions,
152    ) -> DfResult<Arc<dyn ExecutionPlan>> {
153        let res = plan
154            .transform_down(&|plan: Arc<dyn ExecutionPlan>| {
155                if let Some(filter) = plan.as_any().downcast_ref::<FilterExec>() {
156                    let pred = filter.predicate().clone();
157                    let new_pred = pred.transform_down(&|expr: Arc<dyn PhysicalExpr>| {
158                        if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
159                            if !func.name().eq_ignore_ascii_case("matches_term") {
160                                return Ok(Transformed::no(expr));
161                            }
162                            let args = func.args();
163                            if args.len() != 2 {
164                                return Ok(Transformed::no(expr));
165                            }
166
167                            if let Some(lit) = args[1].as_any().downcast_ref::<Literal>() {
168                                if let ScalarValue::Utf8(Some(term)) = lit.value() {
169                                    let finder = MatchesTermFinder::new(term);
170                                    let expr = PreCompiledMatchesTermExpr {
171                                        text: args[0].clone(),
172                                        term: term.to_string(),
173                                        finder,
174                                    };
175
176                                    return Ok(Transformed::yes(Arc::new(expr)));
177                                }
178                            }
179                        }
180
181                        Ok(Transformed::no(expr))
182                    })?;
183
184                    if new_pred.transformed {
185                        let exec = FilterExec::try_new(new_pred.data, filter.input().clone())?
186                            .with_default_selectivity(filter.default_selectivity())?
187                            .with_projection(filter.projection().cloned())?;
188                        return Ok(Transformed::yes(Arc::new(exec) as _));
189                    }
190                }
191
192                Ok(Transformed::no(plan))
193            })?
194            .data;
195
196        Ok(res)
197    }
198
199    fn name(&self) -> &str {
200        "MatchesConstantTerm"
201    }
202
203    fn schema_check(&self) -> bool {
204        false
205    }
206}
207
208#[cfg(test)]
209mod tests {
210    use std::sync::Arc;
211
212    use arrow::array::{ArrayRef, StringArray};
213    use arrow::datatypes::{DataType, Field, Schema};
214    use arrow::record_batch::RecordBatch;
215    use catalog::memory::MemoryCatalogManager;
216    use catalog::RegisterTableRequest;
217    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
218    use common_function::scalars::matches_term::MatchesTermFunction;
219    use common_function::scalars::udf::create_udf;
220    use common_function::state::FunctionState;
221    use datafusion::physical_optimizer::PhysicalOptimizerRule;
222    use datafusion::physical_plan::filter::FilterExec;
223    use datafusion::physical_plan::get_plan_string;
224    use datafusion::physical_plan::memory::MemoryExec;
225    use datafusion_common::{Column, DFSchema, ScalarValue};
226    use datafusion_expr::expr::ScalarFunction;
227    use datafusion_expr::{Expr, ScalarUDF};
228    use datafusion_physical_expr::{create_physical_expr, ScalarFunctionExpr};
229    use datatypes::prelude::ConcreteDataType;
230    use datatypes::schema::ColumnSchema;
231    use session::context::QueryContext;
232    use table::metadata::{TableInfoBuilder, TableMetaBuilder};
233    use table::test_util::EmptyTable;
234
235    use super::*;
236    use crate::parser::QueryLanguageParser;
237    use crate::{QueryEngineFactory, QueryEngineRef};
238
239    fn create_test_batch() -> RecordBatch {
240        let schema = Schema::new(vec![Field::new("text", DataType::Utf8, true)]);
241
242        let text_array = StringArray::from(vec![
243            Some("hello world"),
244            Some("greeting"),
245            Some("hello there"),
246            None,
247        ]);
248
249        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(text_array) as ArrayRef]).unwrap()
250    }
251
252    fn create_test_engine() -> QueryEngineRef {
253        let table_name = "test".to_string();
254        let columns = vec![
255            ColumnSchema::new(
256                "text".to_string(),
257                ConcreteDataType::string_datatype(),
258                false,
259            ),
260            ColumnSchema::new(
261                "timestamp".to_string(),
262                ConcreteDataType::timestamp_millisecond_datatype(),
263                false,
264            )
265            .with_time_index(true),
266        ];
267
268        let schema = Arc::new(datatypes::schema::Schema::new(columns));
269        let table_meta = TableMetaBuilder::empty()
270            .schema(schema)
271            .primary_key_indices(vec![])
272            .value_indices(vec![0])
273            .next_column_id(2)
274            .build()
275            .unwrap();
276        let table_info = TableInfoBuilder::default()
277            .name(&table_name)
278            .meta(table_meta)
279            .build()
280            .unwrap();
281        let table = EmptyTable::from_table_info(&table_info);
282        let catalog_list = MemoryCatalogManager::with_default_setup();
283        assert!(catalog_list
284            .register_table_sync(RegisterTableRequest {
285                catalog: DEFAULT_CATALOG_NAME.to_string(),
286                schema: DEFAULT_SCHEMA_NAME.to_string(),
287                table_name,
288                table_id: 1024,
289                table,
290            })
291            .is_ok());
292        QueryEngineFactory::new(
293            catalog_list,
294            None,
295            None,
296            None,
297            None,
298            false,
299            Default::default(),
300        )
301        .query_engine()
302    }
303
304    fn matches_term_udf() -> Arc<ScalarUDF> {
305        Arc::new(create_udf(
306            Arc::new(MatchesTermFunction),
307            QueryContext::arc(),
308            Arc::new(FunctionState::default()),
309        ))
310    }
311
312    #[test]
313    fn test_matches_term_optimization() {
314        let batch = create_test_batch();
315
316        // Create a predicate with a constant pattern
317        let predicate = create_physical_expr(
318            &Expr::ScalarFunction(ScalarFunction::new_udf(
319                matches_term_udf(),
320                vec![
321                    Expr::Column(Column::from_name("text")),
322                    Expr::Literal(ScalarValue::Utf8(Some("hello".to_string()))),
323                ],
324            )),
325            &DFSchema::try_from(batch.schema().clone()).unwrap(),
326            &Default::default(),
327        )
328        .unwrap();
329
330        let input =
331            Arc::new(MemoryExec::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap());
332        let filter = FilterExec::try_new(predicate, input).unwrap();
333
334        // Apply the optimizer
335        let optimizer = MatchesConstantTermOptimizer;
336        let optimized_plan = optimizer
337            .optimize(Arc::new(filter), &Default::default())
338            .unwrap();
339
340        let optimized_filter = optimized_plan
341            .as_any()
342            .downcast_ref::<FilterExec>()
343            .unwrap();
344        let predicate = optimized_filter.predicate();
345
346        // The predicate should be a PreCompiledMatchesTermExpr
347        assert!(
348            std::any::TypeId::of::<PreCompiledMatchesTermExpr>() == predicate.as_any().type_id()
349        );
350    }
351
352    #[test]
353    fn test_matches_term_no_optimization() {
354        let batch = create_test_batch();
355
356        // Create a predicate with a non-constant pattern
357        let predicate = create_physical_expr(
358            &Expr::ScalarFunction(ScalarFunction::new_udf(
359                matches_term_udf(),
360                vec![
361                    Expr::Column(Column::from_name("text")),
362                    Expr::Column(Column::from_name("text")),
363                ],
364            )),
365            &DFSchema::try_from(batch.schema().clone()).unwrap(),
366            &Default::default(),
367        )
368        .unwrap();
369
370        let input =
371            Arc::new(MemoryExec::try_new(&[vec![batch.clone()]], batch.schema(), None).unwrap());
372        let filter = FilterExec::try_new(predicate, input).unwrap();
373
374        let optimizer = MatchesConstantTermOptimizer;
375        let optimized_plan = optimizer
376            .optimize(Arc::new(filter), &Default::default())
377            .unwrap();
378
379        let optimized_filter = optimized_plan
380            .as_any()
381            .downcast_ref::<FilterExec>()
382            .unwrap();
383        let predicate = optimized_filter.predicate();
384
385        // The predicate should still be a ScalarFunctionExpr
386        assert!(std::any::TypeId::of::<ScalarFunctionExpr>() == predicate.as_any().type_id());
387    }
388
389    #[tokio::test]
390    async fn test_matches_term_optimization_from_sql() {
391        let sql = "WITH base AS (
392        SELECT text, timestamp FROM test 
393        WHERE MATCHES_TERM(text, 'hello') 
394        AND timestamp > '2025-01-01 00:00:00'
395    ),
396    subquery1 AS (
397        SELECT * FROM base 
398        WHERE MATCHES_TERM(text, 'world')
399    ),
400    subquery2 AS (
401        SELECT * FROM test 
402        WHERE MATCHES_TERM(text, 'greeting') 
403        AND timestamp < '2025-01-02 00:00:00'
404    ),
405    union_result AS (
406        SELECT * FROM subquery1 
407        UNION ALL 
408        SELECT * FROM subquery2
409    ),
410    joined_data AS (
411        SELECT a.text, a.timestamp, b.text as other_text 
412        FROM union_result a 
413        JOIN test b ON a.timestamp = b.timestamp 
414        WHERE MATCHES_TERM(a.text, 'there')
415    )
416    SELECT text, other_text 
417    FROM joined_data 
418    WHERE MATCHES_TERM(text, '42') 
419    AND MATCHES_TERM(other_text, 'foo')";
420
421        let query_ctx = QueryContext::arc();
422
423        let stmt = QueryLanguageParser::parse_sql(sql, &query_ctx).unwrap();
424        let engine = create_test_engine();
425        let logical_plan = engine
426            .planner()
427            .plan(&stmt, query_ctx.clone())
428            .await
429            .unwrap();
430
431        let engine_ctx = engine.engine_context(query_ctx);
432        let state = engine_ctx.state();
433
434        let analyzed_plan = state
435            .analyzer()
436            .execute_and_check(logical_plan.clone(), state.config_options(), |_, _| {})
437            .unwrap();
438
439        let optimized_plan = state
440            .optimizer()
441            .optimize(analyzed_plan, state, |_, _| {})
442            .unwrap();
443
444        let physical_plan = state
445            .query_planner()
446            .create_physical_plan(&optimized_plan, state)
447            .await
448            .unwrap();
449
450        let plan_str = get_plan_string(&physical_plan).join("\n");
451        assert!(plan_str.contains("MatchesConstTerm"));
452        assert!(!plan_str.contains("matches_term"))
453    }
454}