query/optimizer/
type_conversion.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use common_time::timestamp::{TimeUnit, Timestamp};
16use common_time::Timezone;
17use datafusion::config::ConfigOptions;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
20use datafusion_expr::expr::InList;
21use datafusion_expr::{
22    Between, BinaryExpr, Expr, ExprSchemable, Filter, LogicalPlan, Operator, TableScan,
23};
24use datatypes::arrow::compute;
25use datatypes::arrow::datatypes::DataType;
26use session::context::QueryContextRef;
27
28use crate::optimizer::ExtensionAnalyzerRule;
29use crate::plan::ExtractExpr;
30use crate::QueryEngineContext;
31
/// TypeConversionRule converts some literal values in logical plan to other types according
/// to data type of corresponding columns.
///
/// The rewrite is applied to `Filter` predicates, `TableScan` filters and the expressions of
/// a fixed set of other plan nodes (projection, window, aggregate, sort, ...).
///
/// Specifically, when a literal is compared (`=`, `!=`, `<`, `<=`, `>`, `>=`, `BETWEEN`,
/// `IN`) with a timestamp or boolean column:
/// - a string literal compared with a timestamp column is parsed into a timestamp literal
///   (`Expr::Literal(ScalarValue::Timestamp*)`); the unit is the one reported by the parsed
///   timestamp, not necessarily milliseconds
/// - the string literals `"true"` / `"false"` (case-insensitive) compared with a boolean
///   column are converted to `Expr::Literal(ScalarValue::Boolean)`
/// - a literal whose conversion yields a null value is rejected with a plan error
///
/// Independently of the above, non-null timestamp literals already present in a rewritten
/// expression are converted to `ScalarValue::TimestampMillisecond`.
pub struct TypeConversionRule;
38
39impl ExtensionAnalyzerRule for TypeConversionRule {
40    fn analyze(
41        &self,
42        plan: LogicalPlan,
43        ctx: &QueryEngineContext,
44        _config: &ConfigOptions,
45    ) -> Result<LogicalPlan> {
46        plan.transform(&|plan| match plan {
47            LogicalPlan::Filter(filter) => {
48                let mut converter = TypeConverter {
49                    schema: filter.input.schema().clone(),
50                    query_ctx: ctx.query_ctx(),
51                };
52                let rewritten = filter.predicate.clone().rewrite(&mut converter)?.data;
53                Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
54                    rewritten,
55                    filter.input,
56                )?)))
57            }
58            LogicalPlan::TableScan(TableScan {
59                table_name,
60                source,
61                projection,
62                projected_schema,
63                filters,
64                fetch,
65            }) => {
66                let mut converter = TypeConverter {
67                    schema: projected_schema.clone(),
68                    query_ctx: ctx.query_ctx(),
69                };
70                let rewrite_filters = filters
71                    .into_iter()
72                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
73                    .collect::<Result<Vec<_>>>()?;
74                Ok(Transformed::yes(LogicalPlan::TableScan(TableScan {
75                    table_name: table_name.clone(),
76                    source: source.clone(),
77                    projection,
78                    projected_schema,
79                    filters: rewrite_filters,
80                    fetch,
81                })))
82            }
83            LogicalPlan::Projection { .. }
84            | LogicalPlan::Window { .. }
85            | LogicalPlan::Aggregate { .. }
86            | LogicalPlan::Repartition { .. }
87            | LogicalPlan::Extension { .. }
88            | LogicalPlan::Sort { .. }
89            | LogicalPlan::Union { .. }
90            | LogicalPlan::Values { .. }
91            | LogicalPlan::Analyze { .. } => {
92                let mut converter = TypeConverter {
93                    schema: plan.schema().clone(),
94                    query_ctx: ctx.query_ctx(),
95                };
96                let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
97                let expr = plan
98                    .expressions_consider_join()
99                    .into_iter()
100                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
101                    .collect::<Result<Vec<_>>>()?;
102
103                plan.with_new_exprs(expr, inputs).map(Transformed::yes)
104            }
105
106            LogicalPlan::Distinct { .. }
107            | LogicalPlan::Limit { .. }
108            | LogicalPlan::Subquery { .. }
109            | LogicalPlan::Explain { .. }
110            | LogicalPlan::SubqueryAlias { .. }
111            | LogicalPlan::EmptyRelation(_)
112            | LogicalPlan::Dml(_)
113            | LogicalPlan::DescribeTable(_)
114            | LogicalPlan::Unnest(_)
115            | LogicalPlan::Statement(_)
116            | LogicalPlan::Ddl(_)
117            | LogicalPlan::Copy(_)
118            | LogicalPlan::RecursiveQuery(_)
119            | LogicalPlan::Join { .. } => Ok(Transformed::no(plan)),
120        })
121        .map(|x| x.data)
122    }
123}
124
/// Expression rewriter that converts literals to the type of the column they are
/// compared with.
struct TypeConverter {
    /// Query context; its timezone is handed to the parser of timestamp strings.
    query_ctx: QueryContextRef,
    /// Schema used to resolve the data type of column references.
    schema: DFSchemaRef,
}
129
130impl TypeConverter {
131    fn column_type(&self, expr: &Expr) -> Option<DataType> {
132        if let Expr::Column(_) = expr {
133            if let Ok(v) = expr.get_type(&self.schema) {
134                return Some(v);
135            }
136        }
137        None
138    }
139
140    fn cast_scalar_value(
141        &self,
142        value: &ScalarValue,
143        target_type: &DataType,
144    ) -> Result<ScalarValue> {
145        match (target_type, value) {
146            (DataType::Timestamp(_, _), ScalarValue::Utf8(Some(v))) => {
147                string_to_timestamp_ms(v, Some(&self.query_ctx.timezone()))
148            }
149            (DataType::Boolean, ScalarValue::Utf8(Some(v))) => match v.to_lowercase().as_str() {
150                "true" => Ok(ScalarValue::Boolean(Some(true))),
151                "false" => Ok(ScalarValue::Boolean(Some(false))),
152                _ => Ok(ScalarValue::Boolean(None)),
153            },
154            (target_type, value) => {
155                let value_arr = value.to_array()?;
156                let arr = compute::cast(&value_arr, target_type)
157                    .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;
158
159                ScalarValue::try_from_array(
160                    &arr,
161                    0, // index: Converts a value in `array` at `index` into a ScalarValue
162                )
163            }
164        }
165    }
166
167    fn convert_type<'b>(&self, left: &'b Expr, right: &'b Expr) -> Result<(Expr, Expr)> {
168        let left_type = self.column_type(left);
169        let right_type = self.column_type(right);
170
171        let target_type = match (&left_type, &right_type) {
172            (Some(v), None) => v,
173            (None, Some(v)) => v,
174            _ => return Ok((left.clone(), right.clone())),
175        };
176
177        // only try to convert timestamp or boolean types
178        if !matches!(target_type, DataType::Timestamp(_, _) | DataType::Boolean) {
179            return Ok((left.clone(), right.clone()));
180        }
181
182        match (left, right) {
183            (Expr::Column(col), Expr::Literal(value, _)) => {
184                let casted_right = self.cast_scalar_value(value, target_type)?;
185                if casted_right.is_null() {
186                    return Err(DataFusionError::Plan(format!(
187                        "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
188                    )));
189                }
190                Ok((left.clone(), Expr::Literal(casted_right, None)))
191            }
192            (Expr::Literal(value, _), Expr::Column(col)) => {
193                let casted_left = self.cast_scalar_value(value, target_type)?;
194                if casted_left.is_null() {
195                    return Err(DataFusionError::Plan(format!(
196                        "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
197                    )));
198                }
199                Ok((Expr::Literal(casted_left, None), right.clone()))
200            }
201            _ => Ok((left.clone(), right.clone())),
202        }
203    }
204}
205
206impl TreeNodeRewriter for TypeConverter {
207    type Node = Expr;
208
209    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
210        let new_expr = match expr {
211            Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
212                Operator::Eq
213                | Operator::NotEq
214                | Operator::Lt
215                | Operator::LtEq
216                | Operator::Gt
217                | Operator::GtEq => {
218                    let (left, right) = self.convert_type(&left, &right)?;
219                    Expr::BinaryExpr(BinaryExpr {
220                        left: Box::new(left),
221                        op,
222                        right: Box::new(right),
223                    })
224                }
225                _ => Expr::BinaryExpr(BinaryExpr { left, op, right }),
226            },
227            Expr::Between(Between {
228                expr,
229                negated,
230                low,
231                high,
232            }) => {
233                let (expr, low) = self.convert_type(&expr, &low)?;
234                let (expr, high) = self.convert_type(&expr, &high)?;
235                Expr::Between(Between {
236                    expr: Box::new(expr),
237                    negated,
238                    low: Box::new(low),
239                    high: Box::new(high),
240                })
241            }
242            Expr::InList(InList {
243                expr,
244                list,
245                negated,
246            }) => {
247                let mut list_expr = Vec::with_capacity(list.len());
248                for e in list {
249                    let (_, expr_conversion) = self.convert_type(&expr, &e)?;
250                    list_expr.push(expr_conversion);
251                }
252                Expr::InList(InList {
253                    expr,
254                    list: list_expr,
255                    negated,
256                })
257            }
258            Expr::Literal(value, _) => match value {
259                ScalarValue::TimestampSecond(Some(i), _) => {
260                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Second)
261                }
262                ScalarValue::TimestampMillisecond(Some(i), _) => {
263                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Millisecond)
264                }
265                ScalarValue::TimestampMicrosecond(Some(i), _) => {
266                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Microsecond)
267                }
268                ScalarValue::TimestampNanosecond(Some(i), _) => {
269                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Nanosecond)
270                }
271                _ => Expr::Literal(value, None),
272            },
273            expr => expr,
274        };
275        Ok(Transformed::yes(new_expr))
276    }
277}
278
279fn timestamp_to_timestamp_ms_expr(val: i64, unit: TimeUnit) -> Expr {
280    let timestamp = match unit {
281        TimeUnit::Second => val * 1_000,
282        TimeUnit::Millisecond => val,
283        TimeUnit::Microsecond => val / 1_000,
284        TimeUnit::Nanosecond => val / 1_000 / 1_000,
285    };
286
287    Expr::Literal(
288        ScalarValue::TimestampMillisecond(Some(timestamp), None),
289        None,
290    )
291}
292
293fn string_to_timestamp_ms(string: &str, timezone: Option<&Timezone>) -> Result<ScalarValue> {
294    let ts = Timestamp::from_str(string, timezone)
295        .map_err(|e| DataFusionError::External(Box::new(e)))?;
296
297    let value = Some(ts.value());
298    let scalar = match ts.unit() {
299        TimeUnit::Second => ScalarValue::TimestampSecond(value, None),
300        TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, None),
301        TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, None),
302        TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, None),
303    };
304    Ok(scalar)
305}
306
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion_common::arrow::datatypes::Field;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::{Literal, LogicalPlanBuilder};
    use datafusion_sql::TableReference;
    use session::context::QueryContext;

    use super::*;

    /// Shorthand for an unqualified column reference.
    fn column(name: &str) -> Expr {
        Expr::Column(Column::from_name(name))
    }

    /// Builds a converter over a schema made of one unqualified column.
    fn single_column_converter(name: &str, data_type: DataType, nullable: bool) -> TypeConverter {
        let field = Arc::new(Field::new(name, data_type, nullable));
        let schema =
            DFSchema::new_with_metadata(vec![(None::<TableReference>, field)], HashMap::new())
                .unwrap();
        TypeConverter {
            schema: Arc::new(schema),
            query_ctx: QueryContext::arc(),
        }
    }

    /// Runs the rule on `plan` with a mock engine context and renders the indented plan.
    fn analyze_and_render(plan: LogicalPlan) -> String {
        let context = QueryEngineContext::mock();
        let transformed_plan = TypeConversionRule
            .analyze(plan, &context, &ConfigOptions::default())
            .unwrap();
        format!("{}", transformed_plan.display_indent())
    }

    #[test]
    fn test_string_to_timestamp_ms() {
        let shanghai = Timezone::from_tz_string("Asia/Shanghai").unwrap();
        let minus_eight = Timezone::from_tz_string("-8:00").unwrap();

        // (input, timezone applied to inputs without an offset, expected seconds)
        let cases = [
            ("2022-02-02 19:00:00+08:00", None, 1643799600),
            ("2009-02-13 23:31:30Z", None, 1234567890),
            ("2009-02-13 23:31:30", Some(&shanghai), 1234567890 - 8 * 3600),
            (
                "2009-02-13 23:31:30",
                Some(&minus_eight),
                1234567890 + 8 * 3600,
            ),
        ];

        for (input, timezone, expected_secs) in cases {
            assert_eq!(
                string_to_timestamp_ms(input, timezone).unwrap(),
                ScalarValue::TimestampSecond(Some(expected_secs), None)
            );
        }
    }

    #[test]
    fn test_timestamp_to_timestamp_ms_expr() {
        // (value, unit of the value, expected milliseconds)
        let cases = [
            (123, TimeUnit::Second, 123000),
            (123, TimeUnit::Millisecond, 123),
            (123, TimeUnit::Microsecond, 0),
            (1230, TimeUnit::Microsecond, 1),
            (123000, TimeUnit::Microsecond, 123),
            (1230, TimeUnit::Nanosecond, 0),
            (123_000_000, TimeUnit::Nanosecond, 123),
        ];

        for (val, unit, expected_ms) in cases {
            assert_eq!(
                timestamp_to_timestamp_ms_expr(val, unit),
                ScalarValue::TimestampMillisecond(Some(expected_ms), None).lit()
            );
        }
    }

    #[test]
    fn test_convert_timestamp_str() {
        use datatypes::arrow::datatypes::TimeUnit as ArrowTimeUnit;

        let mut converter = single_column_converter(
            "ts",
            DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
            true,
        );

        let input = column("ts").gt("2020-09-08T05:42:29+08:00".lit());
        let expected = column("ts").gt(ScalarValue::TimestampSecond(Some(1599514949), None).lit());

        assert_eq!(expected, converter.f_up(input).unwrap().data);
    }

    #[test]
    fn test_convert_bool() {
        let col_name = "is_valid";
        let mut converter = single_column_converter(col_name, DataType::Boolean, false);

        let input = column(col_name).eq("true".lit());
        let expected = column(col_name).eq(true.lit());

        assert_eq!(expected, converter.f_up(input).unwrap().data);
    }

    #[test]
    fn test_retrieve_type_from_aggr_plan() {
        let ts_text = "1970-01-01 00:00:00+08:00";
        let row = vec![
            ScalarValue::Int64(Some(1)).lit(),
            ScalarValue::Float64(Some(1.0)).lit(),
            ScalarValue::TimestampMillisecond(Some(1), None).lit(),
        ];
        let count_column1 =
            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
                datafusion::functions_aggregate::count::count_udaf(),
                vec![column("column1")],
                false,
                None,
                vec![],
                None,
            ));

        let plan = LogicalPlanBuilder::values(vec![row])
            .unwrap()
            .filter(column("column3").gt(ts_text.lit()))
            .unwrap()
            .filter(ts_text.lit().lt_eq(column("column3")))
            .unwrap()
            .aggregate(Vec::<Expr>::new(), vec![count_column1])
            .unwrap()
            .build()
            .unwrap();

        let expected = String::from(
            "Aggregate: groupBy=[[]], aggr=[[count(column1)]]\
            \n  Filter: TimestampSecond(-28800, None) <= column3\
            \n    Filter: column3 > TimestampSecond(-28800, None)\
            \n      Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))",
        );
        assert_eq!(analyze_and_render(plan), expected);
    }

    #[test]
    fn test_reverse_non_ts_type() {
        let number_text = "1.2345";

        let plan = LogicalPlanBuilder::values(vec![vec![1.0f64.lit()]])
            .unwrap()
            .filter(column("column1").gt_eq(number_text.lit()))
            .unwrap()
            .filter(number_text.lit().lt(column("column1")))
            .unwrap()
            .build()
            .unwrap();

        let expected = String::from(
            "Filter: Utf8(\"1.2345\") < column1\
            \n  Filter: column1 >= Utf8(\"1.2345\")\
            \n    Values: (Float64(1))",
        );
        assert_eq!(analyze_and_render(plan), expected);
    }
}