query/optimizer/
transcribe_atat.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_function::scalars::matches_term::MatchesTermFunction;
18use common_function::scalars::udf::create_udf;
19use common_function::state::FunctionState;
20use datafusion::config::ConfigOptions;
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_common::Result;
23use datafusion_expr::expr::ScalarFunction;
24use datafusion_expr::{Expr, LogicalPlan};
25use datafusion_optimizer::analyzer::AnalyzerRule;
26use session::context::QueryContext;
27
28use crate::plan::ExtractExpr;
29
/// TranscribeAtatRule is an analyzer rule that transcribes the `@@` operator
/// into a call to the `matches_term` function.
///
/// Every binary expression `lhs @@ rhs` is replaced by `matches_term(lhs, rhs)`:
/// the left operand becomes the first argument and the right operand the second.
///
/// Example:
/// ```sql
/// SELECT 'cat!' @@ 'cat' as result;
///
/// SELECT `log_message` @@ '/start' as `matches_start` FROM t;
/// ```
///
/// is rewritten to
///
/// ```sql
/// SELECT matches_term('cat!', 'cat') as result;
///
/// SELECT matches_term(`log_message`, '/start') as `matches_start` FROM t;
/// ```
#[derive(Debug)]
pub struct TranscribeAtatRule;
49
50impl AnalyzerRule for TranscribeAtatRule {
51    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
52        plan.transform(Self::do_analyze).map(|x| x.data)
53    }
54
55    fn name(&self) -> &str {
56        "TranscribeAtatRule"
57    }
58}
59
60impl TranscribeAtatRule {
61    fn do_analyze(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
62        let mut rewriter = TranscribeAtatRewriter::default();
63        let new_expr = plan
64            .expressions_consider_join()
65            .into_iter()
66            .map(|e| e.rewrite(&mut rewriter).map(|x| x.data))
67            .collect::<Result<Vec<_>>>()?;
68
69        if rewriter.transcribed {
70            let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
71            plan.with_new_exprs(new_expr, inputs).map(Transformed::yes)
72        } else {
73            Ok(Transformed::no(plan))
74        }
75    }
76}
77
/// Expression rewriter that replaces `@@` binary expressions with
/// `matches_term` function calls.
#[derive(Default)]
struct TranscribeAtatRewriter {
    /// Starts as `false` and becomes `true` once any `@@` expression is rewritten.
    transcribed: bool,
}
82
83impl TreeNodeRewriter for TranscribeAtatRewriter {
84    type Node = Expr;
85
86    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
87        if let Expr::BinaryExpr(binary_expr) = &expr
88            && matches!(binary_expr.op, datafusion_expr::Operator::AtAt)
89        {
90            self.transcribed = true;
91            let scalar_udf = create_udf(
92                Arc::new(MatchesTermFunction),
93                QueryContext::arc(),
94                Arc::new(FunctionState::default()),
95            );
96            let exprs = vec![
97                binary_expr.left.as_ref().clone(),
98                binary_expr.right.as_ref().clone(),
99            ];
100            Ok(Transformed::yes(Expr::ScalarFunction(
101                ScalarFunction::new_udf(Arc::new(scalar_udf), exprs),
102            )))
103        } else {
104            Ok(Transformed::no(expr))
105        }
106    }
107}
#[cfg(test)]
mod tests {

    use arrow_schema::SchemaRef;
    use datafusion::datasource::{provider_as_source, MemTable};
    use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
    use datafusion_expr::{BinaryExpr, Operator};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Runs only `TranscribeAtatRule` over `plan` with the default config.
    fn analyze(plan: LogicalPlan) -> Result<LogicalPlan> {
        TranscribeAtatRule.analyze(plan, &ConfigOptions::default())
    }

    /// Builds the raw binary expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    /// Builds the raw `left @@ right` expression the rule should rewrite.
    fn atat(left: Expr, right: Expr) -> Expr {
        binary(left, Operator::AtAt, right)
    }

    /// Starts a plan scanning an empty in-memory table `t(a Utf8, b Utf8)`.
    fn scan_t() -> LogicalPlanBuilder {
        let fields = vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ];
        let table = MemTable::try_new(SchemaRef::from(Schema::new(fields)), vec![]).unwrap();
        LogicalPlanBuilder::scan("t", provider_as_source(Arc::new(table)), None).unwrap()
    }

    #[test]
    fn test_multiple_atat() {
        let filter_expr = atat(col("a"), lit("foo")).and(atat(col("b"), lit("bar")));
        let projection = vec![atat(col("a"), col("b")), col("b")];

        let plan = scan_t()
            .filter(filter_expr)
            .unwrap()
            .project(projection)
            .unwrap()
            .build()
            .unwrap();

        let expected = r#"Projection: matches_term(t.a, t.b), t.b
  Filter: matches_term(t.a, Utf8("foo")) AND matches_term(t.b, Utf8("bar"))
    TableScan: t"#;

        assert_eq!(analyze(plan).unwrap().to_string(), expected);
    }

    #[test]
    fn test_nested_atat() {
        // Nested case: an `@@` used as the operand of another binary expression.
        let nested = binary(atat(col("a"), lit("nested")), Operator::Eq, lit(true));
        let filter_expr =
            atat(col("a"), lit("foo")).and(atat(col("b"), lit("bar")).or(nested));

        // Complex nested expression with multiple @@ operators.
        let projection = vec![
            col("a"),
            binary(
                atat(col("a"), lit("foo")),
                Operator::And,
                atat(col("b"), lit("bar")),
            ),
        ];

        let plan = scan_t()
            .filter(filter_expr)
            .unwrap()
            .project(projection)
            .unwrap()
            .build()
            .unwrap();

        let expected = r#"Projection: t.a, matches_term(t.a, Utf8("foo")) AND matches_term(t.b, Utf8("bar"))
  Filter: matches_term(t.a, Utf8("foo")) AND (matches_term(t.b, Utf8("bar")) OR matches_term(t.a, Utf8("nested")) = Boolean(true))
    TableScan: t"#;

        assert_eq!(analyze(plan).unwrap().to_string(), expected);
    }
}
228        assert_eq!(formatted, expected);
229    }
230}