query/optimizer/string_normalization.rs

use arrow_schema::DataType;
use datafusion::config::ConfigOptions;
use datafusion::logical_expr::expr::Cast;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_optimizer::analyzer::AnalyzerRule;

use crate::plan::ExtractExpr;

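/// Analyzer rule that normalizes string literals in a logical plan: when a
/// string literal is cast to a timestamp type, leading/trailing whitespace is
/// trimmed and interior whitespace runs are collapsed to single spaces.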
#[derive(Debug)]
pub struct StringNormalizationRule;

impl AnalyzerRule for StringNormalizationRule {
    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
        plan.transform(|plan| match plan {
            LogicalPlan::Projection(_)
            | LogicalPlan::Filter(_)
            | LogicalPlan::Window(_)
            | LogicalPlan::Aggregate(_)
            | LogicalPlan::Sort(_)
            | LogicalPlan::Join(_)
            | LogicalPlan::Repartition(_)
            | LogicalPlan::Union(_)
            | LogicalPlan::TableScan(_)
            | LogicalPlan::EmptyRelation(_)
            | LogicalPlan::Subquery(_)
            | LogicalPlan::SubqueryAlias(_)
            | LogicalPlan::Statement(_)
            | LogicalPlan::Values(_)
            | LogicalPlan::Analyze(_)
            | LogicalPlan::Extension(_)
            | LogicalPlan::Dml(_)
            | LogicalPlan::Copy(_)
            | LogicalPlan::RecursiveQuery(_) => {
                let mut converter = StringNormalizationConverter;
                let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
                let expr = plan
                    .expressions_consider_join()
                    .into_iter()
                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
                    .collect::<Result<Vec<_>>>()?;
                if expr != plan.expressions_consider_join() {
                    plan.with_new_exprs(expr, inputs).map(Transformed::yes)
                } else {
                    Ok(Transformed::no(plan))
                }
            }
            LogicalPlan::Distinct(_)
            | LogicalPlan::Limit(_)
            | LogicalPlan::Explain(_)
            | LogicalPlan::Unnest(_)
            | LogicalPlan::Ddl(_)
            | LogicalPlan::DescribeTable(_) => Ok(Transformed::no(plan)),
        })
        .map(|x| x.data)
    }

    fn name(&self) -> &str {
        "StringNormalizationRule"
    }
}

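/// Expression rewriter that performs the literal trimming for
/// [`StringNormalizationRule`].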
struct StringNormalizationConverter;

impl TreeNodeRewriter for StringNormalizationConverter {
    type Node = Expr;

    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
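        // Only a string literal that is being cast to a timestamp type is
        // rewritten; every other expression is returned unchanged.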
        let new_expr = match expr {
            Expr::Cast(Cast { expr, data_type }) => {
                let expr = match data_type {
                    DataType::Timestamp(_, _) => match *expr {
                        Expr::Literal(value) => match value {
                            ScalarValue::Utf8(Some(s)) => trim_utf_expr(s),
                            _ => Expr::Literal(value),
                        },
                        expr => expr,
                    },
                    _ => *expr,
                };
                Expr::Cast(Cast {
                    expr: Box::new(expr),
                    data_type,
                })
            }
            expr => expr,
        };
        Ok(Transformed::yes(new_expr))
    }
}

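/// Collapses whitespace runs in `s` into single spaces, trims both ends, and
/// returns the result as a `Utf8` literal expression.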
fn trim_utf_expr(s: String) -> Expr {
    let parts: Vec<_> = s.split_whitespace().collect();
    let trimmed = parts.join(" ");
    Expr::Literal(ScalarValue::Utf8(Some(trimmed)))
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
    use arrow::datatypes::{DataType, SchemaRef};
    use arrow_schema::{Field, Schema, TimeUnit};
    use datafusion::datasource::{provider_as_source, MemTable};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{lit, Cast, Expr, LogicalPlan, LogicalPlanBuilder};
    use datafusion_optimizer::analyzer::AnalyzerRule;

    use crate::optimizer::string_normalization::StringNormalizationRule;

    #[test]
    fn test_normalization_for_string_with_extra_whitespaces_to_timestamp_cast() {
        let timestamp_str_with_whitespaces = " 2017-07-23 13:10:11 ";
        let config = &ConfigOptions::default();
        let projects = vec![
            create_timestamp_cast_project(Nanosecond, timestamp_str_with_whitespaces),
            create_timestamp_cast_project(Microsecond, timestamp_str_with_whitespaces),
            create_timestamp_cast_project(Millisecond, timestamp_str_with_whitespaces),
            create_timestamp_cast_project(Second, timestamp_str_with_whitespaces),
        ];
        for (time_unit, proj) in projects {
            let plan = create_test_plan_with_project(proj);
            let result = StringNormalizationRule.analyze(plan, config).unwrap();
            let expected = format!(
                "Projection: CAST(Utf8(\"2017-07-23 13:10:11\") AS Timestamp({:#?}, None))\n  TableScan: t",
                time_unit
            );
            assert_eq!(expected, result.to_string());
        }
    }

    #[test]
    fn test_normalization_for_non_timestamp_casts() {
        let config = &ConfigOptions::default();
        let proj_int_to_timestamp = vec![Expr::Cast(Cast::new(
            Box::new(lit(158412331400600000_i64)),
            DataType::Timestamp(Nanosecond, None),
        ))];
        let int_to_timestamp_plan = create_test_plan_with_project(proj_int_to_timestamp);
        let result = StringNormalizationRule
            .analyze(int_to_timestamp_plan, config)
            .unwrap();
        let expected = String::from(
            "Projection: CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\n  TableScan: t",
        );
        assert_eq!(expected, result.to_string());

        let proj_string_to_int = vec![Expr::Cast(Cast::new(
            Box::new(lit(" 5 ")),
            DataType::Int32,
        ))];
        let string_to_int_plan = create_test_plan_with_project(proj_string_to_int);
        let result = StringNormalizationRule
            .analyze(string_to_int_plan, &ConfigOptions::default())
            .unwrap();
        let expected = String::from("Projection: CAST(Utf8(\" 5 \") AS Int32)\n  TableScan: t");
        assert_eq!(expected, result.to_string());
    }

    fn create_test_plan_with_project(proj: Vec<Expr>) -> LogicalPlan {
        prepare_test_plan_builder()
            .project(proj)
            .unwrap()
            .build()
            .unwrap()
    }

    fn create_timestamp_cast_project(unit: TimeUnit, timestamp_str: &str) -> (TimeUnit, Vec<Expr>) {
        let proj = vec![Expr::Cast(Cast::new(
            Box::new(lit(timestamp_str)),
            DataType::Timestamp(unit, None),
        ))];
        (unit, proj)
    }

    fn prepare_test_plan_builder() -> LogicalPlanBuilder {
        let schema = Schema::new(vec![Field::new("f", DataType::Float64, false)]);
        let table = MemTable::try_new(SchemaRef::from(schema), vec![]).unwrap();
        LogicalPlanBuilder::scan("t", provider_as_source(Arc::new(table)), None).unwrap()
    }
}