query/optimizer/
transcribe_atat.rs1use std::sync::Arc;
16
17use common_function::scalars::matches_term::MatchesTermFunction;
18use common_function::scalars::udf::create_udf;
19use common_function::state::FunctionState;
20use datafusion::config::ConfigOptions;
21use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
22use datafusion_common::Result;
23use datafusion_expr::expr::ScalarFunction;
24use datafusion_expr::{Expr, LogicalPlan};
25use datafusion_optimizer::analyzer::AnalyzerRule;
26use session::context::QueryContext;
27
28use crate::plan::ExtractExpr;
29
/// Analyzer rule that rewrites the `@@` operator (`datafusion_expr::Operator::AtAt`) into a
/// call to the `matches_term` scalar function.
///
/// For example, the predicate `a @@ 'foo'` becomes `matches_term(a, 'foo')`. The rewrite is
/// applied to the expressions of every node in the logical plan, including `@@` operators
/// nested inside other expressions (e.g. `(a @@ 'x') = true`).
#[derive(Debug)]
pub struct TranscribeAtatRule;
49
50impl AnalyzerRule for TranscribeAtatRule {
51 fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
52 plan.transform(Self::do_analyze).map(|x| x.data)
53 }
54
55 fn name(&self) -> &str {
56 "TranscribeAtatRule"
57 }
58}
59
60impl TranscribeAtatRule {
61 fn do_analyze(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
62 let mut rewriter = TranscribeAtatRewriter::default();
63 let new_expr = plan
64 .expressions_consider_join()
65 .into_iter()
66 .map(|e| e.rewrite(&mut rewriter).map(|x| x.data))
67 .collect::<Result<Vec<_>>>()?;
68
69 if rewriter.transcribed {
70 let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
71 plan.with_new_exprs(new_expr, inputs).map(Transformed::yes)
72 } else {
73 Ok(Transformed::no(plan))
74 }
75 }
76}
77
78#[derive(Default)]
79struct TranscribeAtatRewriter {
80 transcribed: bool,
81}
82
83impl TreeNodeRewriter for TranscribeAtatRewriter {
84 type Node = Expr;
85
86 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
87 if let Expr::BinaryExpr(binary_expr) = &expr
88 && matches!(binary_expr.op, datafusion_expr::Operator::AtAt)
89 {
90 self.transcribed = true;
91 let scalar_udf = create_udf(
92 Arc::new(MatchesTermFunction),
93 QueryContext::arc(),
94 Arc::new(FunctionState::default()),
95 );
96 let exprs = vec![
97 binary_expr.left.as_ref().clone(),
98 binary_expr.right.as_ref().clone(),
99 ];
100 Ok(Transformed::yes(Expr::ScalarFunction(
101 ScalarFunction::new_udf(Arc::new(scalar_udf), exprs),
102 )))
103 } else {
104 Ok(Transformed::no(expr))
105 }
106 }
107}
#[cfg(test)]
mod tests {
    use arrow_schema::SchemaRef;
    use datafusion::datasource::{provider_as_source, MemTable};
    use datafusion::logical_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
    use datafusion_expr::{BinaryExpr, Operator};
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;

    /// Runs the rule under test with default configuration.
    fn optimize(plan: LogicalPlan) -> Result<LogicalPlan> {
        TranscribeAtatRule.analyze(plan, &ConfigOptions::default())
    }

    /// Builds the raw binary expression `left <op> right`.
    fn binary(left: Expr, op: Operator, right: Expr) -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(left),
            op,
            right: Box::new(right),
        })
    }

    /// Builds `left @@ right`.
    fn atat(left: Expr, right: Expr) -> Expr {
        binary(left, Operator::AtAt, right)
    }

    /// Scan over an empty in-memory table `t(a: Utf8, b: Utf8)`.
    fn prepare_test_plan_builder() -> LogicalPlanBuilder {
        let schema = SchemaRef::from(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let table = MemTable::try_new(schema, vec![]).unwrap();
        LogicalPlanBuilder::scan("t", provider_as_source(Arc::new(table)), None).unwrap()
    }

    /// Applies the rule and compares the indented plan display against `expected`.
    fn assert_optimized(plan: LogicalPlan, expected: &str) {
        let optimized_plan = optimize(plan).unwrap();
        assert_eq!(optimized_plan.to_string(), expected);
    }

    #[test]
    fn test_multiple_atat() {
        let predicate = atat(col("a"), lit("foo")).and(atat(col("b"), lit("bar")));

        let plan = prepare_test_plan_builder()
            .filter(predicate)
            .unwrap()
            .project(vec![atat(col("a"), col("b")), col("b")])
            .unwrap()
            .build()
            .unwrap();

        let expected = r#"Projection: matches_term(t.a, t.b), t.b
  Filter: matches_term(t.a, Utf8("foo")) AND matches_term(t.b, Utf8("bar"))
    TableScan: t"#;

        assert_optimized(plan, expected);
    }

    #[test]
    fn test_nested_atat() {
        // `(a @@ 'nested') = true` places an `@@` underneath another binary operator.
        let nested = binary(atat(col("a"), lit("nested")), Operator::Eq, lit(true));
        let predicate = atat(col("a"), lit("foo")).and(atat(col("b"), lit("bar")).or(nested));
        let projected = binary(
            atat(col("a"), lit("foo")),
            Operator::And,
            atat(col("b"), lit("bar")),
        );

        let plan = prepare_test_plan_builder()
            .filter(predicate)
            .unwrap()
            .project(vec![col("a"), projected])
            .unwrap()
            .build()
            .unwrap();

        let expected = r#"Projection: t.a, matches_term(t.a, Utf8("foo")) AND matches_term(t.b, Utf8("bar"))
  Filter: matches_term(t.a, Utf8("foo")) AND (matches_term(t.b, Utf8("bar")) OR matches_term(t.a, Utf8("nested")) = Boolean(true))
    TableScan: t"#;

        assert_optimized(plan, expected);
    }
}