query/optimizer/
string_normalization.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use arrow_schema::DataType;
16use datafusion::config::ConfigOptions;
17use datafusion::logical_expr::expr::Cast;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::{Result, ScalarValue};
20use datafusion_expr::{Expr, LogicalPlan};
21use datafusion_optimizer::analyzer::AnalyzerRule;
22
23use crate::plan::ExtractExpr;
24
25/// StringNormalizationRule normalizes(trims) string values in logical plan.
26/// Mainly used for timestamp trimming
27#[derive(Debug)]
28pub struct StringNormalizationRule;
29
30impl AnalyzerRule for StringNormalizationRule {
31    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
32        plan.transform(|plan| match plan {
33            LogicalPlan::Projection(_)
34            | LogicalPlan::Filter(_)
35            | LogicalPlan::Window(_)
36            | LogicalPlan::Aggregate(_)
37            | LogicalPlan::Sort(_)
38            | LogicalPlan::Join(_)
39            | LogicalPlan::Repartition(_)
40            | LogicalPlan::Union(_)
41            | LogicalPlan::TableScan(_)
42            | LogicalPlan::EmptyRelation(_)
43            | LogicalPlan::Subquery(_)
44            | LogicalPlan::SubqueryAlias(_)
45            | LogicalPlan::Statement(_)
46            | LogicalPlan::Values(_)
47            | LogicalPlan::Analyze(_)
48            | LogicalPlan::Extension(_)
49            | LogicalPlan::Dml(_)
50            | LogicalPlan::Copy(_)
51            | LogicalPlan::RecursiveQuery(_) => {
52                let mut converter = StringNormalizationConverter;
53                let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
54                let expr = plan
55                    .expressions_consider_join()
56                    .into_iter()
57                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
58                    .collect::<Result<Vec<_>>>()?;
59                if expr != plan.expressions_consider_join() {
60                    plan.with_new_exprs(expr, inputs).map(Transformed::yes)
61                } else {
62                    Ok(Transformed::no(plan))
63                }
64            }
65            LogicalPlan::Distinct(_)
66            | LogicalPlan::Limit(_)
67            | LogicalPlan::Explain(_)
68            | LogicalPlan::Unnest(_)
69            | LogicalPlan::Ddl(_)
70            | LogicalPlan::DescribeTable(_) => Ok(Transformed::no(plan)),
71        })
72        .map(|x| x.data)
73    }
74
75    fn name(&self) -> &str {
76        "StringNormalizationRule"
77    }
78}
79
80struct StringNormalizationConverter;
81
82impl TreeNodeRewriter for StringNormalizationConverter {
83    type Node = Expr;
84
85    /// remove extra whitespaces from the String value when
86    /// there is a CAST from a String to Timestamp.
87    /// Otherwise - no modifications applied
88    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
89        let new_expr = match expr {
90            Expr::Cast(Cast { expr, data_type }) => {
91                let expr = match data_type {
92                    DataType::Timestamp(_, _) => match *expr {
93                        Expr::Literal(value, _) => match value {
94                            ScalarValue::Utf8(Some(s)) => trim_utf_expr(s),
95                            _ => Expr::Literal(value, None),
96                        },
97                        expr => expr,
98                    },
99                    _ => *expr,
100                };
101                Expr::Cast(Cast {
102                    expr: Box::new(expr),
103                    data_type,
104                })
105            }
106            expr => expr,
107        };
108        Ok(Transformed::yes(new_expr))
109    }
110}
111
112fn trim_utf_expr(s: String) -> Expr {
113    let parts: Vec<_> = s.split_whitespace().collect();
114    let trimmed = parts.join(" ");
115    Expr::Literal(ScalarValue::Utf8(Some(trimmed)), None)
116}
117
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
    use arrow::datatypes::{DataType, SchemaRef};
    use arrow_schema::{Field, Schema, TimeUnit};
    use datafusion::datasource::{MemTable, provider_as_source};
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::{Cast, Expr, LogicalPlan, LogicalPlanBuilder, lit};
    use datafusion_optimizer::analyzer::AnalyzerRule;

    use crate::optimizer::string_normalization::StringNormalizationRule;

    /// A padded string literal cast to a timestamp must be normalized for
    /// every time unit.
    #[test]
    fn test_normalization_for_string_with_extra_whitespaces_to_timestamp_cast() {
        let padded = "    2017-07-23    13:10:11   ";
        let config = ConfigOptions::default();
        for unit in [Nanosecond, Microsecond, Millisecond, Second] {
            let (time_unit, proj) = create_timestamp_cast_project(unit, padded);
            let analyzed = StringNormalizationRule
                .analyze(create_test_plan_with_project(proj), &config)
                .unwrap();
            let expected = format!(
                "Projection: CAST(Utf8(\"2017-07-23 13:10:11\") AS Timestamp({:#?}, None))\n  TableScan: t",
                time_unit
            );
            assert_eq!(expected, analyzed.to_string());
        }
    }

    /// Casts that are not string-to-timestamp must pass through the rule
    /// without any change.
    #[test]
    fn test_normalization_for_non_timestamp_casts() {
        let config = ConfigOptions::default();

        // Integer literal cast to a timestamp: nothing to trim.
        let int_to_timestamp = Expr::Cast(Cast::new(
            Box::new(lit(158412331400600000_i64)),
            DataType::Timestamp(Nanosecond, None),
        ));
        let analyzed = StringNormalizationRule
            .analyze(create_test_plan_with_project(vec![int_to_timestamp]), &config)
            .unwrap();
        let expected = String::from(
            "Projection: CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\n  TableScan: t",
        );
        assert_eq!(expected, analyzed.to_string());

        // Padded string cast to a non-timestamp type: padding must survive.
        let string_to_int = Expr::Cast(Cast::new(Box::new(lit("  5   ")), DataType::Int32));
        let analyzed = StringNormalizationRule
            .analyze(create_test_plan_with_project(vec![string_to_int]), &config)
            .unwrap();
        let expected = String::from("Projection: CAST(Utf8(\"  5   \") AS Int32)\n  TableScan: t");
        assert_eq!(expected, analyzed.to_string());
    }

    /// Builds a plan that projects `proj` on top of the test table scan.
    fn create_test_plan_with_project(proj: Vec<Expr>) -> LogicalPlan {
        let builder = prepare_test_plan_builder().project(proj).unwrap();
        builder.build().unwrap()
    }

    /// Returns `unit` together with a single-expression projection that casts
    /// the string literal `timestamp_str` to a timestamp of that unit.
    fn create_timestamp_cast_project(unit: TimeUnit, timestamp_str: &str) -> (TimeUnit, Vec<Expr>) {
        let literal = Box::new(lit(timestamp_str));
        let cast = Expr::Cast(Cast::new(literal, DataType::Timestamp(unit, None)));
        (unit, vec![cast])
    }

    /// Creates a plan builder scanning an in-memory table `t` that has a
    /// single non-nullable `Float64` column `f`.
    fn prepare_test_plan_builder() -> LogicalPlanBuilder {
        let fields = vec![Field::new("f", DataType::Float64, false)];
        let schema = SchemaRef::from(Schema::new(fields));
        let table = MemTable::try_new(schema, vec![vec![]]).unwrap();
        let source = provider_as_source(Arc::new(table));
        LogicalPlanBuilder::scan("t", source, None).unwrap()
    }
}