query/optimizer/
transcribe_atat.rs1use std::sync::Arc;
16
17use common_function::scalars::matches_term::MatchesTermFunction;
18use common_function::scalars::udf::create_udf;
19use datafusion::config::ConfigOptions;
20use datafusion_common::Result;
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_expr::expr::ScalarFunction;
23use datafusion_expr::{Expr, LogicalPlan};
24use datafusion_optimizer::analyzer::AnalyzerRule;
25
26use crate::plan::ExtractExpr;
27
/// Analyzer rule that replaces the `@@` binary operator with a call to the
/// `matches_term` scalar function.
///
/// Every binary expression whose operator is `Operator::AtAt` is rewritten
/// into a `ScalarFunction` built from `MatchesTermFunction`, taking the
/// operator's left and right operands (in that order) as its two arguments.
#[derive(Debug)]
pub struct TranscribeAtatRule;
47
48impl AnalyzerRule for TranscribeAtatRule {
49 fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
50 plan.transform(Self::do_analyze).map(|x| x.data)
51 }
52
53 fn name(&self) -> &str {
54 "TranscribeAtatRule"
55 }
56}
57
58impl TranscribeAtatRule {
59 fn do_analyze(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
60 let mut rewriter = TranscribeAtatRewriter::default();
61 let new_expr = plan
62 .expressions_consider_join()
63 .into_iter()
64 .map(|e| e.rewrite(&mut rewriter).map(|x| x.data))
65 .collect::<Result<Vec<_>>>()?;
66
67 if rewriter.transcribed {
68 let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
69 plan.with_new_exprs(new_expr, inputs).map(Transformed::yes)
70 } else {
71 Ok(Transformed::no(plan))
72 }
73 }
74}
75
/// Expression rewriter that turns `@@` binary expressions into
/// `matches_term` function calls.
#[derive(Default)]
struct TranscribeAtatRewriter {
    /// Set to `true` once at least one `@@` expression has been rewritten;
    /// starts out `false`.
    transcribed: bool,
}
80
81impl TreeNodeRewriter for TranscribeAtatRewriter {
82 type Node = Expr;
83
84 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
85 if let Expr::BinaryExpr(binary_expr) = &expr
86 && matches!(binary_expr.op, datafusion_expr::Operator::AtAt)
87 {
88 self.transcribed = true;
89 let scalar_udf = create_udf(Arc::new(MatchesTermFunction::default()));
90 let exprs = vec![
91 binary_expr.left.as_ref().clone(),
92 binary_expr.right.as_ref().clone(),
93 ];
94 Ok(Transformed::yes(Expr::ScalarFunction(
95 ScalarFunction::new_udf(Arc::new(scalar_udf), exprs),
96 )))
97 } else {
98 Ok(Transformed::no(expr))
99 }
100 }
101}
#[cfg(test)]
mod tests {

    use arrow_schema::SchemaRef;
    use datafusion::datasource::{MemTable, provider_as_source};
    use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder, col, lit};
    use datafusion_expr::{BinaryExpr, Operator};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Runs `TranscribeAtatRule` over `plan` with default config options.
    fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
        let config = ConfigOptions::default();
        TranscribeAtatRule.analyze(plan, &config)
    }

    /// Starts a plan that scans an in-memory table `t` with two non-nullable
    /// `Utf8` columns, `a` and `b`.
    fn prepare_test_plan_builder() -> LogicalPlanBuilder {
        let fields = vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ];
        let schema = SchemaRef::from(Schema::new(fields));
        let table = MemTable::try_new(schema, vec![vec![]]).unwrap();
        let source = provider_as_source(Arc::new(table));
        LogicalPlanBuilder::scan("t", source, None).unwrap()
    }

    /// Builds the binary expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    /// Builds the binary expression `left @@ right`.
    fn atat(left: Expr, right: Expr) -> Expr {
        binary(left, Operator::AtAt, right)
    }

    /// Analyzes `plan` and compares its textual form with `expected_lines`
    /// joined by newlines (leading spaces in each line are significant).
    fn assert_analyzed(plan: LogicalPlan, expected_lines: &[&str]) {
        let analyzed = optimize(plan).unwrap();
        assert_eq!(analyzed.to_string(), expected_lines.join("\n"));
    }

    #[test]
    fn test_multiple_atat() {
        let predicate = atat(col("a"), lit("foo")).and(atat(col("b"), lit("bar")));
        let projection = vec![atat(col("a"), col("b")), col("b")];

        let plan = prepare_test_plan_builder()
            .filter(predicate)
            .unwrap()
            .project(projection)
            .unwrap()
            .build()
            .unwrap();

        assert_analyzed(
            plan,
            &[
                r#"Projection: matches_term(t.a, t.b), t.b"#,
                r#"  Filter: matches_term(t.a, Utf8("foo")) AND matches_term(t.b, Utf8("bar"))"#,
                r#"    TableScan: t"#,
            ],
        );
    }

    #[test]
    fn test_nested_atat() {
        // `@@` used as the left operand of another binary expression.
        let nested_eq_true = binary(atat(col("a"), lit("nested")), Operator::Eq, lit(true));
        let predicate =
            atat(col("a"), lit("foo")).and(atat(col("b"), lit("bar")).or(nested_eq_true));

        // `@@` on both sides of an explicit `AND` inside the projection.
        let both_match = binary(
            atat(col("a"), lit("foo")),
            Operator::And,
            atat(col("b"), lit("bar")),
        );
        let projection = vec![col("a"), both_match];

        let plan = prepare_test_plan_builder()
            .filter(predicate)
            .unwrap()
            .project(projection)
            .unwrap()
            .build()
            .unwrap();

        assert_analyzed(
            plan,
            &[
                r#"Projection: t.a, matches_term(t.a, Utf8("foo")) AND matches_term(t.b, Utf8("bar"))"#,
                r#"  Filter: matches_term(t.a, Utf8("foo")) AND (matches_term(t.b, Utf8("bar")) OR matches_term(t.a, Utf8("nested")) = Boolean(true))"#,
                r#"    TableScan: t"#,
            ],
        );
    }
}