query/optimizer/
type_conversion.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use common_time::timestamp::{TimeUnit, Timestamp};
16use common_time::Timezone;
17use datafusion::config::ConfigOptions;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
20use datafusion_expr::expr::InList;
21use datafusion_expr::{
22    Between, BinaryExpr, Expr, ExprSchemable, Filter, LogicalPlan, Operator, TableScan,
23};
24use datatypes::arrow::compute;
25use datatypes::arrow::datatypes::DataType;
26use session::context::QueryContextRef;
27
28use crate::optimizer::ExtensionAnalyzerRule;
29use crate::plan::ExtractExpr;
30use crate::QueryEngineContext;
31
/// TypeConversionRule converts some literal values in a logical plan to other types according
/// to the data type of the columns they are compared with.
///
/// Conversion applies to comparisons (`=`, `!=`, `<`, `<=`, `>`, `>=`), `BETWEEN` and `IN`
/// lists where exactly one operand is a timestamp or boolean column and the other is a literal:
/// - a string literal compared with a timestamp column is parsed into a timestamp literal that
///   keeps the precision of the parsed value (e.g. `Expr::Literal(ScalarValue::TimestampSecond)`);
///   strings without an explicit offset are interpreted in the session timezone;
/// - a string literal compared with a boolean column is converted to
///   `Expr::Literal(ScalarValue::Boolean)` (`"true"` / `"false"`, case-insensitive);
/// - any other literal is cast to the column type with Arrow's cast kernel.
///
/// Independently, every non-null timestamp literal (second, milli-, micro- or nanosecond) in the
/// rewritten expressions is normalized to `ScalarValue::TimestampMillisecond` without timezone.
pub struct TypeConversionRule;
38
39impl ExtensionAnalyzerRule for TypeConversionRule {
40    fn analyze(
41        &self,
42        plan: LogicalPlan,
43        ctx: &QueryEngineContext,
44        _config: &ConfigOptions,
45    ) -> Result<LogicalPlan> {
46        plan.transform(&|plan| match plan {
47            LogicalPlan::Filter(filter) => {
48                let mut converter = TypeConverter {
49                    schema: filter.input.schema().clone(),
50                    query_ctx: ctx.query_ctx(),
51                };
52                let rewritten = filter.predicate.clone().rewrite(&mut converter)?.data;
53                Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
54                    rewritten,
55                    filter.input,
56                )?)))
57            }
58            LogicalPlan::TableScan(TableScan {
59                table_name,
60                source,
61                projection,
62                projected_schema,
63                filters,
64                fetch,
65            }) => {
66                let mut converter = TypeConverter {
67                    schema: projected_schema.clone(),
68                    query_ctx: ctx.query_ctx(),
69                };
70                let rewrite_filters = filters
71                    .into_iter()
72                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
73                    .collect::<Result<Vec<_>>>()?;
74                Ok(Transformed::yes(LogicalPlan::TableScan(TableScan {
75                    table_name: table_name.clone(),
76                    source: source.clone(),
77                    projection,
78                    projected_schema,
79                    filters: rewrite_filters,
80                    fetch,
81                })))
82            }
83            LogicalPlan::Projection { .. }
84            | LogicalPlan::Window { .. }
85            | LogicalPlan::Aggregate { .. }
86            | LogicalPlan::Repartition { .. }
87            | LogicalPlan::Extension { .. }
88            | LogicalPlan::Sort { .. }
89            | LogicalPlan::Union { .. }
90            | LogicalPlan::Join { .. }
91            | LogicalPlan::Values { .. }
92            | LogicalPlan::Analyze { .. } => {
93                let mut converter = TypeConverter {
94                    schema: plan.schema().clone(),
95                    query_ctx: ctx.query_ctx(),
96                };
97                let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
98                let expr = plan
99                    .expressions_consider_join()
100                    .into_iter()
101                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
102                    .collect::<Result<Vec<_>>>()?;
103
104                plan.with_new_exprs(expr, inputs).map(Transformed::yes)
105            }
106
107            LogicalPlan::Distinct { .. }
108            | LogicalPlan::Limit { .. }
109            | LogicalPlan::Subquery { .. }
110            | LogicalPlan::Explain { .. }
111            | LogicalPlan::SubqueryAlias { .. }
112            | LogicalPlan::EmptyRelation(_)
113            | LogicalPlan::Dml(_)
114            | LogicalPlan::DescribeTable(_)
115            | LogicalPlan::Unnest(_)
116            | LogicalPlan::Statement(_)
117            | LogicalPlan::Ddl(_)
118            | LogicalPlan::Copy(_)
119            | LogicalPlan::RecursiveQuery(_) => Ok(Transformed::no(plan)),
120        })
121        .map(|x| x.data)
122    }
123}
124
125struct TypeConverter {
126    query_ctx: QueryContextRef,
127    schema: DFSchemaRef,
128}
129
130impl TypeConverter {
131    fn column_type(&self, expr: &Expr) -> Option<DataType> {
132        if let Expr::Column(_) = expr {
133            if let Ok(v) = expr.get_type(&self.schema) {
134                return Some(v);
135            }
136        }
137        None
138    }
139
140    fn cast_scalar_value(
141        &self,
142        value: &ScalarValue,
143        target_type: &DataType,
144    ) -> Result<ScalarValue> {
145        match (target_type, value) {
146            (DataType::Timestamp(_, _), ScalarValue::Utf8(Some(v))) => {
147                string_to_timestamp_ms(v, Some(&self.query_ctx.timezone()))
148            }
149            (DataType::Boolean, ScalarValue::Utf8(Some(v))) => match v.to_lowercase().as_str() {
150                "true" => Ok(ScalarValue::Boolean(Some(true))),
151                "false" => Ok(ScalarValue::Boolean(Some(false))),
152                _ => Ok(ScalarValue::Boolean(None)),
153            },
154            (target_type, value) => {
155                let value_arr = value.to_array()?;
156                let arr = compute::cast(&value_arr, target_type)
157                    .map_err(|e| DataFusionError::ArrowError(e, None))?;
158
159                ScalarValue::try_from_array(
160                    &arr,
161                    0, // index: Converts a value in `array` at `index` into a ScalarValue
162                )
163            }
164        }
165    }
166
167    fn convert_type<'b>(&self, left: &'b Expr, right: &'b Expr) -> Result<(Expr, Expr)> {
168        let left_type = self.column_type(left);
169        let right_type = self.column_type(right);
170
171        let target_type = match (&left_type, &right_type) {
172            (Some(v), None) => v,
173            (None, Some(v)) => v,
174            _ => return Ok((left.clone(), right.clone())),
175        };
176
177        // only try to convert timestamp or boolean types
178        if !matches!(target_type, DataType::Timestamp(_, _) | DataType::Boolean) {
179            return Ok((left.clone(), right.clone()));
180        }
181
182        match (left, right) {
183            (Expr::Column(col), Expr::Literal(value)) => {
184                let casted_right = self.cast_scalar_value(value, target_type)?;
185                if casted_right.is_null() {
186                    return Err(DataFusionError::Plan(format!(
187                        "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
188                    )));
189                }
190                Ok((left.clone(), Expr::Literal(casted_right)))
191            }
192            (Expr::Literal(value), Expr::Column(col)) => {
193                let casted_left = self.cast_scalar_value(value, target_type)?;
194                if casted_left.is_null() {
195                    return Err(DataFusionError::Plan(format!(
196                        "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
197                    )));
198                }
199                Ok((Expr::Literal(casted_left), right.clone()))
200            }
201            _ => Ok((left.clone(), right.clone())),
202        }
203    }
204}
205
206impl TreeNodeRewriter for TypeConverter {
207    type Node = Expr;
208
209    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
210        let new_expr = match expr {
211            Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
212                Operator::Eq
213                | Operator::NotEq
214                | Operator::Lt
215                | Operator::LtEq
216                | Operator::Gt
217                | Operator::GtEq => {
218                    let (left, right) = self.convert_type(&left, &right)?;
219                    Expr::BinaryExpr(BinaryExpr {
220                        left: Box::new(left),
221                        op,
222                        right: Box::new(right),
223                    })
224                }
225                _ => Expr::BinaryExpr(BinaryExpr { left, op, right }),
226            },
227            Expr::Between(Between {
228                expr,
229                negated,
230                low,
231                high,
232            }) => {
233                let (expr, low) = self.convert_type(&expr, &low)?;
234                let (expr, high) = self.convert_type(&expr, &high)?;
235                Expr::Between(Between {
236                    expr: Box::new(expr),
237                    negated,
238                    low: Box::new(low),
239                    high: Box::new(high),
240                })
241            }
242            Expr::InList(InList {
243                expr,
244                list,
245                negated,
246            }) => {
247                let mut list_expr = Vec::with_capacity(list.len());
248                for e in list {
249                    let (_, expr_conversion) = self.convert_type(&expr, &e)?;
250                    list_expr.push(expr_conversion);
251                }
252                Expr::InList(InList {
253                    expr,
254                    list: list_expr,
255                    negated,
256                })
257            }
258            Expr::Literal(value) => match value {
259                ScalarValue::TimestampSecond(Some(i), _) => {
260                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Second)
261                }
262                ScalarValue::TimestampMillisecond(Some(i), _) => {
263                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Millisecond)
264                }
265                ScalarValue::TimestampMicrosecond(Some(i), _) => {
266                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Microsecond)
267                }
268                ScalarValue::TimestampNanosecond(Some(i), _) => {
269                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Nanosecond)
270                }
271                _ => Expr::Literal(value),
272            },
273            expr => expr,
274        };
275        Ok(Transformed::yes(new_expr))
276    }
277}
278
279fn timestamp_to_timestamp_ms_expr(val: i64, unit: TimeUnit) -> Expr {
280    let timestamp = match unit {
281        TimeUnit::Second => val * 1_000,
282        TimeUnit::Millisecond => val,
283        TimeUnit::Microsecond => val / 1_000,
284        TimeUnit::Nanosecond => val / 1_000 / 1_000,
285    };
286
287    Expr::Literal(ScalarValue::TimestampMillisecond(Some(timestamp), None))
288}
289
290fn string_to_timestamp_ms(string: &str, timezone: Option<&Timezone>) -> Result<ScalarValue> {
291    let ts = Timestamp::from_str(string, timezone)
292        .map_err(|e| DataFusionError::External(Box::new(e)))?;
293
294    let value = Some(ts.value());
295    let scalar = match ts.unit() {
296        TimeUnit::Second => ScalarValue::TimestampSecond(value, None),
297        TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, None),
298        TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, None),
299        TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, None),
300    };
301    Ok(scalar)
302}
303
/// Unit tests for `TypeConversionRule`, `TypeConverter` and the timestamp helpers.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion_common::arrow::datatypes::Field;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::LogicalPlanBuilder;
    use datafusion_sql::TableReference;
    use session::context::QueryContext;

    use super::*;

    /// Strings with an explicit offset (`+08:00`, `Z`) parse to the same instant with no
    /// timezone argument; naive strings are interpreted in the supplied timezone.
    /// All inputs have whole-second precision, so the result is `TimestampSecond`.
    #[test]
    fn test_string_to_timestamp_ms() {
        assert_eq!(
            string_to_timestamp_ms("2022-02-02 19:00:00+08:00", None).unwrap(),
            ScalarValue::TimestampSecond(Some(1643799600), None)
        );
        assert_eq!(
            string_to_timestamp_ms("2009-02-13 23:31:30Z", None).unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890), None)
        );

        // Asia/Shanghai is UTC+8, so the same wall-clock time is 8 hours earlier in UTC.
        assert_eq!(
            string_to_timestamp_ms(
                "2009-02-13 23:31:30",
                Some(&Timezone::from_tz_string("Asia/Shanghai").unwrap())
            )
            .unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890 - 8 * 3600), None)
        );

        // A fixed -8:00 offset shifts the instant 8 hours later in UTC.
        assert_eq!(
            string_to_timestamp_ms(
                "2009-02-13 23:31:30",
                Some(&Timezone::from_tz_string("-8:00").unwrap())
            )
            .unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890 + 8 * 3600), None)
        );
    }

    /// Every unit is normalized to milliseconds; finer units lose sub-millisecond precision.
    #[test]
    fn test_timestamp_to_timestamp_ms_expr() {
        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Second),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(123000), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Millisecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Microsecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(0), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(1230, TimeUnit::Microsecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(1), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123000, TimeUnit::Microsecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(1230, TimeUnit::Nanosecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(0), None))
        );
        assert_eq!(
            timestamp_to_timestamp_ms_expr(123_000_000, TimeUnit::Nanosecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
        );
    }

    /// A string compared with a millisecond timestamp column becomes a timestamp literal that
    /// keeps the parsed (second) precision rather than being forced to milliseconds.
    #[test]
    fn test_convert_timestamp_str() {
        use datatypes::arrow::datatypes::TimeUnit as ArrowTimeUnit;

        let schema = Arc::new(
            DFSchema::new_with_metadata(
                vec![(
                    None::<TableReference>,
                    Arc::new(Field::new(
                        "ts",
                        DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
                        true,
                    )),
                )],
                HashMap::new(),
            )
            .unwrap(),
        );
        let mut converter = TypeConverter {
            schema,
            query_ctx: QueryContext::arc(),
        };

        assert_eq!(
            Expr::Column(Column::from_name("ts")).gt(Expr::Literal(ScalarValue::TimestampSecond(
                Some(1599514949),
                None
            ))),
            converter
                .f_up(
                    Expr::Column(Column::from_name("ts")).gt(Expr::Literal(ScalarValue::Utf8(
                        Some("2020-09-08T05:42:29+08:00".to_string()),
                    )))
                )
                .unwrap()
                .data
        );
    }

    /// A `"true"` string compared with a boolean column becomes a boolean literal.
    #[test]
    fn test_convert_bool() {
        let col_name = "is_valid";
        let schema = Arc::new(
            DFSchema::new_with_metadata(
                vec![(
                    None::<TableReference>,
                    Arc::new(Field::new(col_name, DataType::Boolean, false)),
                )],
                HashMap::new(),
            )
            .unwrap(),
        );
        let mut converter = TypeConverter {
            schema,
            query_ctx: QueryContext::arc(),
        };

        assert_eq!(
            Expr::Column(Column::from_name(col_name))
                .eq(Expr::Literal(ScalarValue::Boolean(Some(true)))),
            converter
                .f_up(
                    Expr::Column(Column::from_name(col_name))
                        .eq(Expr::Literal(ScalarValue::Utf8(Some("true".to_string()))))
                )
                .unwrap()
                .data
        );
    }

    /// Filters below an aggregate resolve `column3` from their inputs and convert the string on
    /// either side of the comparison; `-28800` s is `1970-01-01 00:00:00+08:00`.
    #[test]
    fn test_retrieve_type_from_aggr_plan() {
        let plan =
            LogicalPlanBuilder::values(vec![vec![
                Expr::Literal(ScalarValue::Int64(Some(1))),
                Expr::Literal(ScalarValue::Float64(Some(1.0))),
                Expr::Literal(ScalarValue::TimestampMillisecond(Some(1), None)),
            ]])
            .unwrap()
            .filter(Expr::Column(Column::from_name("column3")).gt(Expr::Literal(
                ScalarValue::Utf8(Some("1970-01-01 00:00:00+08:00".to_string())),
            )))
            .unwrap()
            .filter(
                Expr::Literal(ScalarValue::Utf8(Some(
                    "1970-01-01 00:00:00+08:00".to_string(),
                )))
                .lt_eq(Expr::Column(Column::from_name("column3"))),
            )
            .unwrap()
            .aggregate(
                Vec::<Expr>::new(),
                vec![Expr::AggregateFunction(
                    datafusion_expr::expr::AggregateFunction::new_udf(
                        datafusion::functions_aggregate::count::count_udaf(),
                        vec![Expr::Column(Column::from_name("column1"))],
                        false,
                        None,
                        None,
                        None,
                    ),
                )],
            )
            .unwrap()
            .build()
            .unwrap();
        let context = QueryEngineContext::mock();

        let transformed_plan = TypeConversionRule
            .analyze(plan, &context, &ConfigOptions::default())
            .unwrap();
        let expected = String::from(
            "Aggregate: groupBy=[[]], aggr=[[count(column1)]]\
            \n  Filter: TimestampSecond(-28800, None) <= column3\
            \n    Filter: column3 > TimestampSecond(-28800, None)\
            \n      Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))",
        );
        assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
    }

    /// Columns that are neither timestamp nor boolean (Float64 here) leave string literals
    /// untouched, whichever side of the comparison they are on.
    #[test]
    fn test_reverse_non_ts_type() {
        let context = QueryEngineContext::mock();

        let plan =
            LogicalPlanBuilder::values(vec![vec![Expr::Literal(ScalarValue::Float64(Some(1.0)))]])
                .unwrap()
                .filter(
                    Expr::Column(Column::from_name("column1"))
                        .gt_eq(Expr::Literal(ScalarValue::Utf8(Some("1.2345".to_string())))),
                )
                .unwrap()
                .filter(
                    Expr::Literal(ScalarValue::Utf8(Some("1.2345".to_string())))
                        .lt(Expr::Column(Column::from_name("column1"))),
                )
                .unwrap()
                .build()
                .unwrap();
        let transformed_plan = TypeConversionRule
            .analyze(plan, &context, &ConfigOptions::default())
            .unwrap();
        let expected = String::from(
            "Filter: Utf8(\"1.2345\") < column1\
            \n  Filter: column1 >= Utf8(\"1.2345\")\
            \n    Values: (Float64(1))",
        );
        assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
    }
}