query/optimizer/
string_normalization.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use arrow_schema::DataType;
16use datafusion::config::ConfigOptions;
17use datafusion::logical_expr::expr::Cast;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::{Expr, LogicalPlan};
21use datafusion_optimizer::analyzer::AnalyzerRule;
22
23use crate::plan::ExtractExpr;
24
/// Analyzer rule that normalizes whitespace in string literals of a logical plan.
///
/// The rule walks the plan and rewrites string literals that are the direct
/// operand of a `CAST` to a timestamp type: leading and trailing whitespace is
/// removed and every internal run of whitespace is collapsed to a single space.
/// All other expressions are left as they are.
#[derive(Debug)]
pub struct StringNormalizationRule;
29
30impl AnalyzerRule for StringNormalizationRule {
31    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
32        plan.transform(|plan| match plan {
33            LogicalPlan::Projection(_)
34            | LogicalPlan::Filter(_)
35            | LogicalPlan::Window(_)
36            | LogicalPlan::Aggregate(_)
37            | LogicalPlan::Sort(_)
38            | LogicalPlan::Join(_)
39            | LogicalPlan::Repartition(_)
40            | LogicalPlan::Union(_)
41            | LogicalPlan::TableScan(_)
42            | LogicalPlan::EmptyRelation(_)
43            | LogicalPlan::Subquery(_)
44            | LogicalPlan::SubqueryAlias(_)
45            | LogicalPlan::Statement(_)
46            | LogicalPlan::Values(_)
47            | LogicalPlan::Analyze(_)
48            | LogicalPlan::Extension(_)
49            | LogicalPlan::Dml(_)
50            | LogicalPlan::Copy(_)
51            | LogicalPlan::RecursiveQuery(_) => {
52                let mut converter = StringNormalizationConverter;
53                let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
54                let expr = plan
55                    .expressions_consider_join()
56                    .into_iter()
57                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
58                    .collect::<Result<Vec<_>>>()?;
59                if expr != plan.expressions_consider_join() {
60                    plan.with_new_exprs(expr, inputs).map(Transformed::yes)
61                } else {
62                    Ok(Transformed::no(plan))
63                }
64            }
65            LogicalPlan::Distinct(_)
66            | LogicalPlan::Limit(_)
67            | LogicalPlan::Explain(_)
68            | LogicalPlan::Unnest(_)
69            | LogicalPlan::Ddl(_)
70            | LogicalPlan::DescribeTable(_) => Ok(Transformed::no(plan)),
71        })
72        .map(|x| x.data)
73    }
74
75    fn name(&self) -> &str {
76        "StringNormalizationRule"
77    }
78}
79
/// Expression rewriter that normalizes the whitespace of string literals
/// which are cast to a timestamp type.
struct StringNormalizationConverter;
81
82impl TreeNodeRewriter for StringNormalizationConverter {
83    type Node = Expr;
84
85    /// remove extra whitespaces from the String value when
86    /// there is a CAST from a String to Timestamp.
87    /// Otherwise - no modifications applied
88    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
89        let new_expr = match expr {
90            Expr::Cast(Cast { expr, data_type }) => {
91                let expr = match data_type {
92                    DataType::Timestamp(_, _) => match *expr {
93                        Expr::Literal(value) => match value {
94                            ScalarValue::Utf8(Some(s)) => trim_utf_expr(s),
95                            _ => Expr::Literal(value),
96                        },
97                        expr => expr,
98                    },
99                    _ => *expr,
100                };
101                Expr::Cast(Cast {
102                    expr: Box::new(expr),
103                    data_type,
104                })
105            }
106            expr => expr,
107        };
108        Ok(Transformed::yes(new_expr))
109    }
110}
111
112fn trim_utf_expr(s: String) -> Expr {
113    let parts: Vec<_> = s.split_whitespace().collect();
114    let trimmed = parts.join(" ");
115    Expr::Literal(ScalarValue::Utf8(Some(trimmed)))
116}
117
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
    use arrow::datatypes::{DataType, SchemaRef};
    use arrow_schema::{Field, Schema, TimeUnit};
    use datafusion::datasource::{provider_as_source, MemTable};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{lit, Cast, Expr, LogicalPlan, LogicalPlanBuilder};
    use datafusion_optimizer::analyzer::AnalyzerRule;

    use crate::optimizer::string_normalization::StringNormalizationRule;

    /// A padded timestamp string cast to any timestamp unit must come out of
    /// the rule with its whitespace normalized.
    #[test]
    fn test_normalization_for_string_with_extra_whitespaces_to_timestamp_cast() {
        let padded_timestamp = "    2017-07-23    13:10:11   ";
        let config = ConfigOptions::default();

        for unit in [Nanosecond, Microsecond, Millisecond, Second] {
            let (time_unit, proj) = create_timestamp_cast_project(unit, padded_timestamp);
            let analyzed = StringNormalizationRule
                .analyze(create_test_plan_with_project(proj), &config)
                .unwrap();
            let expected = format!(
                "Projection: CAST(Utf8(\"2017-07-23 13:10:11\") AS Timestamp({:#?}, None))\n  TableScan: t",
                time_unit
            );
            assert_eq!(expected, analyzed.to_string());
        }
    }

    /// Casts that are not "string literal to timestamp" must be left alone.
    #[test]
    fn test_normalization_for_non_timestamp_casts() {
        // Integer literal cast to a timestamp: not a string, so untouched.
        let int_to_timestamp = Expr::Cast(Cast::new(
            Box::new(lit(158412331400600000_i64)),
            DataType::Timestamp(Nanosecond, None),
        ));
        assert_analyzed_plan(
            vec![int_to_timestamp],
            "Projection: CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\n  TableScan: t",
        );

        // String literal cast to a non-timestamp type: padding is preserved.
        let string_to_int = Expr::Cast(Cast::new(Box::new(lit("  5   ")), DataType::Int32));
        assert_analyzed_plan(
            vec![string_to_int],
            "Projection: CAST(Utf8(\"  5   \") AS Int32)\n  TableScan: t",
        );
    }

    /// Runs the rule over a projection of `proj` and compares the rendered
    /// plan with `expected`.
    fn assert_analyzed_plan(proj: Vec<Expr>, expected: &str) {
        let plan = create_test_plan_with_project(proj);
        let analyzed = StringNormalizationRule
            .analyze(plan, &ConfigOptions::default())
            .unwrap();
        assert_eq!(String::from(expected), analyzed.to_string());
    }

    /// Builds a plan that projects `proj` on top of the test table scan.
    fn create_test_plan_with_project(proj: Vec<Expr>) -> LogicalPlan {
        let builder = prepare_test_plan_builder().project(proj).unwrap();
        builder.build().unwrap()
    }

    /// Returns `unit` together with a one-element projection that casts the
    /// string literal `timestamp_str` to a timestamp of that unit.
    fn create_timestamp_cast_project(unit: TimeUnit, timestamp_str: &str) -> (TimeUnit, Vec<Expr>) {
        let literal = Box::new(lit(timestamp_str));
        let cast = Cast::new(literal, DataType::Timestamp(unit, None));
        (unit, vec![Expr::Cast(cast)])
    }

    /// Creates a plan builder scanning an empty in-memory table `t` with a
    /// single non-nullable `Float64` column `f`.
    fn prepare_test_plan_builder() -> LogicalPlanBuilder {
        let fields = vec![Field::new("f", DataType::Float64, false)];
        let schema: SchemaRef = Arc::new(Schema::new(fields));
        let table = MemTable::try_new(schema, vec![]).unwrap();
        LogicalPlanBuilder::scan("t", provider_as_source(Arc::new(table)), None).unwrap()
    }
}