query/optimizer/
transcribe_atat.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_function::scalars::matches_term::MatchesTermFunction;
18use common_function::scalars::udf::create_udf;
19use datafusion::config::ConfigOptions;
20use datafusion_common::Result;
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_expr::expr::ScalarFunction;
23use datafusion_expr::{Expr, LogicalPlan};
24use datafusion_optimizer::analyzer::AnalyzerRule;
25
26use crate::plan::ExtractExpr;
27
/// TranscribeAtatRule is an analyzer rule that transcribes the `@@` operator
/// into a call to the `matches_term` function.
///
/// Example:
/// ```sql
/// SELECT 'cat!' @@ 'cat' as result;
///
/// SELECT `log_message` @@ '/start' as `matches_start` FROM t;
/// ```
///
/// to
///
/// ```sql
/// SELECT matches_term('cat!', 'cat') as result;
///
/// SELECT matches_term(`log_message`, '/start') as `matches_start` FROM t;
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct TranscribeAtatRule;
47
48impl AnalyzerRule for TranscribeAtatRule {
49    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
50        plan.transform(Self::do_analyze).map(|x| x.data)
51    }
52
53    fn name(&self) -> &str {
54        "TranscribeAtatRule"
55    }
56}
57
58impl TranscribeAtatRule {
59    fn do_analyze(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
60        let mut rewriter = TranscribeAtatRewriter::default();
61        let new_expr = plan
62            .expressions_consider_join()
63            .into_iter()
64            .map(|e| e.rewrite(&mut rewriter).map(|x| x.data))
65            .collect::<Result<Vec<_>>>()?;
66
67        if rewriter.transcribed {
68            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
69            plan.with_new_exprs(new_expr, inputs).map(Transformed::yes)
70        } else {
71            Ok(Transformed::no(plan))
72        }
73    }
74}
75
76#[derive(Default)]
77struct TranscribeAtatRewriter {
78    transcribed: bool,
79}
80
81impl TreeNodeRewriter for TranscribeAtatRewriter {
82    type Node = Expr;
83
84    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
85        if let Expr::BinaryExpr(binary_expr) = &expr
86            && matches!(binary_expr.op, datafusion_expr::Operator::AtAt)
87        {
88            self.transcribed = true;
89            let scalar_udf = create_udf(Arc::new(MatchesTermFunction::default()));
90            let exprs = vec![
91                binary_expr.left.as_ref().clone(),
92                binary_expr.right.as_ref().clone(),
93            ];
94            Ok(Transformed::yes(Expr::ScalarFunction(
95                ScalarFunction::new_udf(Arc::new(scalar_udf), exprs),
96            )))
97        } else {
98            Ok(Transformed::no(expr))
99        }
100    }
101}
#[cfg(test)]
mod tests {

    use arrow_schema::SchemaRef;
    use datafusion::datasource::{MemTable, provider_as_source};
    use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder, col, lit};
    use datafusion_expr::{BinaryExpr, Operator};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Builds the binary expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    /// Builds the expression `left @@ right`.
    fn atat(left: Expr, right: Expr) -> Expr {
        binary(left, Operator::AtAt, right)
    }

    /// Applies the rule under test with default configuration options.
    fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
        let config = ConfigOptions::default();
        TranscribeAtatRule.analyze(plan, &config)
    }

    /// Starts a plan that scans an in-memory table `t` with two non-nullable
    /// `Utf8` columns, `a` and `b`.
    fn prepare_test_plan_builder() -> LogicalPlanBuilder {
        let fields = vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ];
        let schema = SchemaRef::from(Schema::new(fields));
        let table = MemTable::try_new(schema, vec![vec![]]).unwrap();
        let source = provider_as_source(Arc::new(table));
        LogicalPlanBuilder::scan("t", source, None).unwrap()
    }

    /// `@@` operators in both the filter and the projection are transcribed.
    #[test]
    fn test_multiple_atat() {
        let predicate = atat(col("a"), lit("foo")).and(atat(col("b"), lit("bar")));
        let projection = vec![atat(col("a"), col("b")), col("b")];

        let plan = prepare_test_plan_builder()
            .filter(predicate)
            .unwrap()
            .project(projection)
            .unwrap()
            .build()
            .unwrap();

        let expected = r#"Projection: matches_term(t.a, t.b), t.b
  Filter: matches_term(t.a, Utf8("foo")) AND matches_term(t.b, Utf8("bar"))
    TableScan: t"#;

        let optimized_plan = optimize(plan).unwrap();
        assert_eq!(optimized_plan.to_string(), expected);
    }

    /// `@@` operators nested inside other expressions are transcribed too.
    #[test]
    fn test_nested_atat() {
        // An `@@` expression used as the left operand of an equality comparison.
        let nested_eq = binary(
            atat(col("a"), lit("nested")),
            Operator::Eq,
            lit(true),
        );
        let predicate =
            atat(col("a"), lit("foo")).and(atat(col("b"), lit("bar")).or(nested_eq));

        // A conjunction whose two operands are both `@@` expressions.
        let conjunction = binary(
            atat(col("a"), lit("foo")),
            Operator::And,
            atat(col("b"), lit("bar")),
        );
        let projection = vec![col("a"), conjunction];

        let plan = prepare_test_plan_builder()
            .filter(predicate)
            .unwrap()
            .project(projection)
            .unwrap()
            .build()
            .unwrap();

        let expected = r#"Projection: t.a, matches_term(t.a, Utf8("foo")) AND matches_term(t.b, Utf8("bar"))
  Filter: matches_term(t.a, Utf8("foo")) AND (matches_term(t.b, Utf8("bar")) OR matches_term(t.a, Utf8("nested")) = Boolean(true))
    TableScan: t"#;

        let optimized_plan = optimize(plan).unwrap();
        assert_eq!(optimized_plan.to_string(), expected);
    }
}