query/optimizer/
string_normalization.rs1use arrow_schema::DataType;
16use datafusion::config::ConfigOptions;
17use datafusion::logical_expr::expr::Cast;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::{Expr, LogicalPlan};
21use datafusion_optimizer::analyzer::AnalyzerRule;
22
23use crate::plan::ExtractExpr;
24
/// Analyzer rule that normalizes whitespace inside string literals which are
/// cast to a timestamp type.
///
/// The rule itself is stateless; the rewriting is done by its
/// `AnalyzerRule::analyze` implementation.
#[derive(Debug)]
pub struct StringNormalizationRule;
29
30impl AnalyzerRule for StringNormalizationRule {
31 fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
32 plan.transform(|plan| match plan {
33 LogicalPlan::Projection(_)
34 | LogicalPlan::Filter(_)
35 | LogicalPlan::Window(_)
36 | LogicalPlan::Aggregate(_)
37 | LogicalPlan::Sort(_)
38 | LogicalPlan::Join(_)
39 | LogicalPlan::Repartition(_)
40 | LogicalPlan::Union(_)
41 | LogicalPlan::TableScan(_)
42 | LogicalPlan::EmptyRelation(_)
43 | LogicalPlan::Subquery(_)
44 | LogicalPlan::SubqueryAlias(_)
45 | LogicalPlan::Statement(_)
46 | LogicalPlan::Values(_)
47 | LogicalPlan::Analyze(_)
48 | LogicalPlan::Extension(_)
49 | LogicalPlan::Dml(_)
50 | LogicalPlan::Copy(_)
51 | LogicalPlan::RecursiveQuery(_) => {
52 let mut converter = StringNormalizationConverter;
53 let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
54 let expr = plan
55 .expressions_consider_join()
56 .into_iter()
57 .map(|e| e.rewrite(&mut converter).map(|x| x.data))
58 .collect::<Result<Vec<_>>>()?;
59 if expr != plan.expressions_consider_join() {
60 plan.with_new_exprs(expr, inputs).map(Transformed::yes)
61 } else {
62 Ok(Transformed::no(plan))
63 }
64 }
65 LogicalPlan::Distinct(_)
66 | LogicalPlan::Limit(_)
67 | LogicalPlan::Explain(_)
68 | LogicalPlan::Unnest(_)
69 | LogicalPlan::Ddl(_)
70 | LogicalPlan::DescribeTable(_) => Ok(Transformed::no(plan)),
71 })
72 .map(|x| x.data)
73 }
74
75 fn name(&self) -> &str {
76 "StringNormalizationRule"
77 }
78}
79
/// Stateless expression rewriter used by `StringNormalizationRule`; its
/// behavior lives in its `TreeNodeRewriter` implementation.
struct StringNormalizationConverter;
81
82impl TreeNodeRewriter for StringNormalizationConverter {
83 type Node = Expr;
84
85 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
89 let new_expr = match expr {
90 Expr::Cast(Cast { expr, data_type }) => {
91 let expr = match data_type {
92 DataType::Timestamp(_, _) => match *expr {
93 Expr::Literal(value, _) => match value {
94 ScalarValue::Utf8(Some(s)) => trim_utf_expr(s),
95 _ => Expr::Literal(value, None),
96 },
97 expr => expr,
98 },
99 _ => *expr,
100 };
101 Expr::Cast(Cast {
102 expr: Box::new(expr),
103 data_type,
104 })
105 }
106 expr => expr,
107 };
108 Ok(Transformed::yes(new_expr))
109 }
110}
111
112fn trim_utf_expr(s: String) -> Expr {
113 let parts: Vec<_> = s.split_whitespace().collect();
114 let trimmed = parts.join(" ");
115 Expr::Literal(ScalarValue::Utf8(Some(trimmed)), None)
116}
117
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
    use arrow::datatypes::{DataType, SchemaRef};
    use arrow_schema::{Field, Schema, TimeUnit};
    use datafusion::datasource::{MemTable, provider_as_source};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{Cast, Expr, LogicalPlan, LogicalPlanBuilder, lit};
    use datafusion_optimizer::analyzer::AnalyzerRule;

    use crate::optimizer::string_normalization::StringNormalizationRule;

    // NOTE(review): the expected plan strings below must match the plan's
    // `Display` output exactly, including whitespace — confirm the literals
    // against the real output if these tests fail.

    /// A whitespace-padded string literal cast to a timestamp is normalized
    /// for every time unit.
    #[test]
    fn test_normalization_for_string_with_extra_whitespaces_to_timestamp_cast() {
        let padded_timestamp = " 2017-07-23 13:10:11 ";
        let config = ConfigOptions::default();

        for &unit in &[Nanosecond, Microsecond, Millisecond, Second] {
            let (time_unit, proj) = create_timestamp_cast_project(unit, padded_timestamp);
            let analyzed = StringNormalizationRule
                .analyze(create_test_plan_with_project(proj), &config)
                .unwrap();
            let expected = format!(
                "Projection: CAST(Utf8(\"2017-07-23 13:10:11\") AS Timestamp({:#?}, None))\n TableScan: t",
                time_unit
            );
            assert_eq!(expected, analyzed.to_string());
        }
    }

    /// Casts that are not "string literal to timestamp" are left untouched.
    #[test]
    fn test_normalization_for_non_timestamp_casts() {
        let config = ConfigOptions::default();

        // Integer literal cast to a timestamp.
        let int_to_timestamp = Expr::Cast(Cast::new(
            Box::new(lit(158412331400600000_i64)),
            DataType::Timestamp(Nanosecond, None),
        ));
        let analyzed = StringNormalizationRule
            .analyze(
                create_test_plan_with_project(vec![int_to_timestamp]),
                &config,
            )
            .unwrap();
        let expected = String::from(
            "Projection: CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\n TableScan: t",
        );
        assert_eq!(expected, analyzed.to_string());

        // String literal cast to a non-timestamp type.
        let string_to_int = Expr::Cast(Cast::new(Box::new(lit(" 5 ")), DataType::Int32));
        let analyzed = StringNormalizationRule
            .analyze(create_test_plan_with_project(vec![string_to_int]), &config)
            .unwrap();
        let expected = String::from("Projection: CAST(Utf8(\" 5 \") AS Int32)\n TableScan: t");
        assert_eq!(expected, analyzed.to_string());
    }

    /// Builds a plan projecting `proj` on top of the test table scan.
    fn create_test_plan_with_project(proj: Vec<Expr>) -> LogicalPlan {
        let builder = prepare_test_plan_builder().project(proj).unwrap();
        builder.build().unwrap()
    }

    /// Returns `unit` together with a single-expression projection that casts
    /// `timestamp_str` to a timestamp of that unit.
    fn create_timestamp_cast_project(unit: TimeUnit, timestamp_str: &str) -> (TimeUnit, Vec<Expr>) {
        let target_type = DataType::Timestamp(unit, None);
        let cast = Expr::Cast(Cast::new(Box::new(lit(timestamp_str)), target_type));
        (unit, vec![cast])
    }

    /// Creates a builder scanning an in-memory table `t` with one `Float64`
    /// column `f`.
    fn prepare_test_plan_builder() -> LogicalPlanBuilder {
        let fields = vec![Field::new("f", DataType::Float64, false)];
        let schema = SchemaRef::from(Schema::new(fields));
        let table = MemTable::try_new(schema, vec![vec![]]).unwrap();
        LogicalPlanBuilder::scan("t", provider_as_source(Arc::new(table)), None).unwrap()
    }
}