1use common_time::timestamp::{TimeUnit, Timestamp};
16use common_time::Timezone;
17use datafusion::config::ConfigOptions;
18use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
19use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
20use datafusion_expr::expr::InList;
21use datafusion_expr::{
22 Between, BinaryExpr, Expr, ExprSchemable, Filter, LogicalPlan, Operator, TableScan,
23};
24use datatypes::arrow::compute;
25use datatypes::arrow::datatypes::DataType;
26use session::context::QueryContextRef;
27
28use crate::optimizer::ExtensionAnalyzerRule;
29use crate::plan::ExtractExpr;
30use crate::QueryEngineContext;
31
/// Analyzer rule that rewrites comparisons between a column and a literal so
/// that string literals compared against timestamp or boolean columns become
/// literals of the column's type, and timestamp literals are normalized to
/// millisecond precision.
pub struct TypeConversionRule;
38
39impl ExtensionAnalyzerRule for TypeConversionRule {
40 fn analyze(
41 &self,
42 plan: LogicalPlan,
43 ctx: &QueryEngineContext,
44 _config: &ConfigOptions,
45 ) -> Result<LogicalPlan> {
46 plan.transform(&|plan| match plan {
47 LogicalPlan::Filter(filter) => {
48 let mut converter = TypeConverter {
49 schema: filter.input.schema().clone(),
50 query_ctx: ctx.query_ctx(),
51 };
52 let rewritten = filter.predicate.clone().rewrite(&mut converter)?.data;
53 Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
54 rewritten,
55 filter.input,
56 )?)))
57 }
58 LogicalPlan::TableScan(TableScan {
59 table_name,
60 source,
61 projection,
62 projected_schema,
63 filters,
64 fetch,
65 }) => {
66 let mut converter = TypeConverter {
67 schema: projected_schema.clone(),
68 query_ctx: ctx.query_ctx(),
69 };
70 let rewrite_filters = filters
71 .into_iter()
72 .map(|e| e.rewrite(&mut converter).map(|x| x.data))
73 .collect::<Result<Vec<_>>>()?;
74 Ok(Transformed::yes(LogicalPlan::TableScan(TableScan {
75 table_name: table_name.clone(),
76 source: source.clone(),
77 projection,
78 projected_schema,
79 filters: rewrite_filters,
80 fetch,
81 })))
82 }
83 LogicalPlan::Projection { .. }
84 | LogicalPlan::Window { .. }
85 | LogicalPlan::Aggregate { .. }
86 | LogicalPlan::Repartition { .. }
87 | LogicalPlan::Extension { .. }
88 | LogicalPlan::Sort { .. }
89 | LogicalPlan::Union { .. }
90 | LogicalPlan::Join { .. }
91 | LogicalPlan::Values { .. }
92 | LogicalPlan::Analyze { .. } => {
93 let mut converter = TypeConverter {
94 schema: plan.schema().clone(),
95 query_ctx: ctx.query_ctx(),
96 };
97 let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
98 let expr = plan
99 .expressions_consider_join()
100 .into_iter()
101 .map(|e| e.rewrite(&mut converter).map(|x| x.data))
102 .collect::<Result<Vec<_>>>()?;
103
104 plan.with_new_exprs(expr, inputs).map(Transformed::yes)
105 }
106
107 LogicalPlan::Distinct { .. }
108 | LogicalPlan::Limit { .. }
109 | LogicalPlan::Subquery { .. }
110 | LogicalPlan::Explain { .. }
111 | LogicalPlan::SubqueryAlias { .. }
112 | LogicalPlan::EmptyRelation(_)
113 | LogicalPlan::Dml(_)
114 | LogicalPlan::DescribeTable(_)
115 | LogicalPlan::Unnest(_)
116 | LogicalPlan::Statement(_)
117 | LogicalPlan::Ddl(_)
118 | LogicalPlan::Copy(_)
119 | LogicalPlan::RecursiveQuery(_) => Ok(Transformed::no(plan)),
120 })
121 .map(|x| x.data)
122 }
123}
124
125struct TypeConverter {
126 query_ctx: QueryContextRef,
127 schema: DFSchemaRef,
128}
129
130impl TypeConverter {
131 fn column_type(&self, expr: &Expr) -> Option<DataType> {
132 if let Expr::Column(_) = expr {
133 if let Ok(v) = expr.get_type(&self.schema) {
134 return Some(v);
135 }
136 }
137 None
138 }
139
140 fn cast_scalar_value(
141 &self,
142 value: &ScalarValue,
143 target_type: &DataType,
144 ) -> Result<ScalarValue> {
145 match (target_type, value) {
146 (DataType::Timestamp(_, _), ScalarValue::Utf8(Some(v))) => {
147 string_to_timestamp_ms(v, Some(&self.query_ctx.timezone()))
148 }
149 (DataType::Boolean, ScalarValue::Utf8(Some(v))) => match v.to_lowercase().as_str() {
150 "true" => Ok(ScalarValue::Boolean(Some(true))),
151 "false" => Ok(ScalarValue::Boolean(Some(false))),
152 _ => Ok(ScalarValue::Boolean(None)),
153 },
154 (target_type, value) => {
155 let value_arr = value.to_array()?;
156 let arr = compute::cast(&value_arr, target_type)
157 .map_err(|e| DataFusionError::ArrowError(e, None))?;
158
159 ScalarValue::try_from_array(
160 &arr,
161 0, )
163 }
164 }
165 }
166
167 fn convert_type<'b>(&self, left: &'b Expr, right: &'b Expr) -> Result<(Expr, Expr)> {
168 let left_type = self.column_type(left);
169 let right_type = self.column_type(right);
170
171 let target_type = match (&left_type, &right_type) {
172 (Some(v), None) => v,
173 (None, Some(v)) => v,
174 _ => return Ok((left.clone(), right.clone())),
175 };
176
177 if !matches!(target_type, DataType::Timestamp(_, _) | DataType::Boolean) {
179 return Ok((left.clone(), right.clone()));
180 }
181
182 match (left, right) {
183 (Expr::Column(col), Expr::Literal(value)) => {
184 let casted_right = self.cast_scalar_value(value, target_type)?;
185 if casted_right.is_null() {
186 return Err(DataFusionError::Plan(format!(
187 "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
188 )));
189 }
190 Ok((left.clone(), Expr::Literal(casted_right)))
191 }
192 (Expr::Literal(value), Expr::Column(col)) => {
193 let casted_left = self.cast_scalar_value(value, target_type)?;
194 if casted_left.is_null() {
195 return Err(DataFusionError::Plan(format!(
196 "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
197 )));
198 }
199 Ok((Expr::Literal(casted_left), right.clone()))
200 }
201 _ => Ok((left.clone(), right.clone())),
202 }
203 }
204}
205
206impl TreeNodeRewriter for TypeConverter {
207 type Node = Expr;
208
209 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
210 let new_expr = match expr {
211 Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
212 Operator::Eq
213 | Operator::NotEq
214 | Operator::Lt
215 | Operator::LtEq
216 | Operator::Gt
217 | Operator::GtEq => {
218 let (left, right) = self.convert_type(&left, &right)?;
219 Expr::BinaryExpr(BinaryExpr {
220 left: Box::new(left),
221 op,
222 right: Box::new(right),
223 })
224 }
225 _ => Expr::BinaryExpr(BinaryExpr { left, op, right }),
226 },
227 Expr::Between(Between {
228 expr,
229 negated,
230 low,
231 high,
232 }) => {
233 let (expr, low) = self.convert_type(&expr, &low)?;
234 let (expr, high) = self.convert_type(&expr, &high)?;
235 Expr::Between(Between {
236 expr: Box::new(expr),
237 negated,
238 low: Box::new(low),
239 high: Box::new(high),
240 })
241 }
242 Expr::InList(InList {
243 expr,
244 list,
245 negated,
246 }) => {
247 let mut list_expr = Vec::with_capacity(list.len());
248 for e in list {
249 let (_, expr_conversion) = self.convert_type(&expr, &e)?;
250 list_expr.push(expr_conversion);
251 }
252 Expr::InList(InList {
253 expr,
254 list: list_expr,
255 negated,
256 })
257 }
258 Expr::Literal(value) => match value {
259 ScalarValue::TimestampSecond(Some(i), _) => {
260 timestamp_to_timestamp_ms_expr(i, TimeUnit::Second)
261 }
262 ScalarValue::TimestampMillisecond(Some(i), _) => {
263 timestamp_to_timestamp_ms_expr(i, TimeUnit::Millisecond)
264 }
265 ScalarValue::TimestampMicrosecond(Some(i), _) => {
266 timestamp_to_timestamp_ms_expr(i, TimeUnit::Microsecond)
267 }
268 ScalarValue::TimestampNanosecond(Some(i), _) => {
269 timestamp_to_timestamp_ms_expr(i, TimeUnit::Nanosecond)
270 }
271 _ => Expr::Literal(value),
272 },
273 expr => expr,
274 };
275 Ok(Transformed::yes(new_expr))
276 }
277}
278
279fn timestamp_to_timestamp_ms_expr(val: i64, unit: TimeUnit) -> Expr {
280 let timestamp = match unit {
281 TimeUnit::Second => val * 1_000,
282 TimeUnit::Millisecond => val,
283 TimeUnit::Microsecond => val / 1_000,
284 TimeUnit::Nanosecond => val / 1_000 / 1_000,
285 };
286
287 Expr::Literal(ScalarValue::TimestampMillisecond(Some(timestamp), None))
288}
289
290fn string_to_timestamp_ms(string: &str, timezone: Option<&Timezone>) -> Result<ScalarValue> {
291 let ts = Timestamp::from_str(string, timezone)
292 .map_err(|e| DataFusionError::External(Box::new(e)))?;
293
294 let value = Some(ts.value());
295 let scalar = match ts.unit() {
296 TimeUnit::Second => ScalarValue::TimestampSecond(value, None),
297 TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, None),
298 TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, None),
299 TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, None),
300 };
301 Ok(scalar)
302}
303
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion_common::arrow::datatypes::Field;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::LogicalPlanBuilder;
    use datafusion_sql::TableReference;
    use session::context::QueryContext;

    use super::*;

    /// Parsing timestamp strings: the returned scalar keeps the parsed
    /// precision (seconds here), and naive strings are shifted by the
    /// supplied timezone.
    #[test]
    fn test_string_to_timestamp_ms() {
        // Strings carrying an explicit offset, with no timezone supplied.
        assert_eq!(
            string_to_timestamp_ms("2022-02-02 19:00:00+08:00", None).unwrap(),
            ScalarValue::TimestampSecond(Some(1643799600), None)
        );
        assert_eq!(
            string_to_timestamp_ms("2009-02-13 23:31:30Z", None).unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890), None)
        );

        // A naive string interpreted in UTC+8 lands 8 hours earlier in UTC.
        assert_eq!(
            string_to_timestamp_ms(
                "2009-02-13 23:31:30",
                Some(&Timezone::from_tz_string("Asia/Shanghai").unwrap())
            )
            .unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890 - 8 * 3600), None)
        );

        // A naive string interpreted in UTC-8 lands 8 hours later in UTC.
        assert_eq!(
            string_to_timestamp_ms(
                "2009-02-13 23:31:30",
                Some(&Timezone::from_tz_string("-8:00").unwrap())
            )
            .unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890 + 8 * 3600), None)
        );
    }

    /// Normalizing timestamp literals of every unit to milliseconds;
    /// sub-millisecond precision is truncated.
    #[test]
    fn test_timestamp_to_timestamp_ms_expr() {
        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Second),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(123000), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Millisecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Microsecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(0), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(1230, TimeUnit::Microsecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(1), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123000, TimeUnit::Microsecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(1230, TimeUnit::Nanosecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(0), None))
        );
        assert_eq!(
            timestamp_to_timestamp_ms_expr(123_000_000, TimeUnit::Nanosecond),
            Expr::Literal(ScalarValue::TimestampMillisecond(Some(123), None))
        );
    }

    /// A string literal compared with a timestamp column is parsed into a
    /// timestamp literal. `f_up` is called on the comparison node only, so
    /// the resulting `TimestampSecond` literal is not re-normalized to ms.
    #[test]
    fn test_convert_timestamp_str() {
        use datatypes::arrow::datatypes::TimeUnit as ArrowTimeUnit;

        let schema = Arc::new(
            DFSchema::new_with_metadata(
                vec![(
                    None::<TableReference>,
                    Arc::new(Field::new(
                        "ts",
                        DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
                        true,
                    )),
                )],
                HashMap::new(),
            )
            .unwrap(),
        );
        let mut converter = TypeConverter {
            schema,
            query_ctx: QueryContext::arc(),
        };

        assert_eq!(
            Expr::Column(Column::from_name("ts")).gt(Expr::Literal(ScalarValue::TimestampSecond(
                Some(1599514949),
                None
            ))),
            converter
                .f_up(
                    Expr::Column(Column::from_name("ts")).gt(Expr::Literal(ScalarValue::Utf8(
                        Some("2020-09-08T05:42:29+08:00".to_string()),
                    )))
                )
                .unwrap()
                .data
        );
    }

    /// A string literal compared with a boolean column becomes a boolean literal.
    #[test]
    fn test_convert_bool() {
        let col_name = "is_valid";
        let schema = Arc::new(
            DFSchema::new_with_metadata(
                vec![(
                    None::<TableReference>,
                    Arc::new(Field::new(col_name, DataType::Boolean, false)),
                )],
                HashMap::new(),
            )
            .unwrap(),
        );
        let mut converter = TypeConverter {
            schema,
            query_ctx: QueryContext::arc(),
        };

        assert_eq!(
            Expr::Column(Column::from_name(col_name))
                .eq(Expr::Literal(ScalarValue::Boolean(Some(true)))),
            converter
                .f_up(
                    Expr::Column(Column::from_name(col_name))
                        .eq(Expr::Literal(ScalarValue::Utf8(Some("true".to_string()))))
                )
                .unwrap()
                .data
        );
    }

    /// End-to-end through the analyzer: the column type comes from the
    /// `Values` node beneath two `Filter`s and an `Aggregate`, and the string
    /// literal is cast whether it sits on the left or the right of the
    /// comparison.
    #[test]
    fn test_retrieve_type_from_aggr_plan() {
        let plan =
            LogicalPlanBuilder::values(vec![vec![
                Expr::Literal(ScalarValue::Int64(Some(1))),
                Expr::Literal(ScalarValue::Float64(Some(1.0))),
                Expr::Literal(ScalarValue::TimestampMillisecond(Some(1), None)),
            ]])
            .unwrap()
            .filter(Expr::Column(Column::from_name("column3")).gt(Expr::Literal(
                ScalarValue::Utf8(Some("1970-01-01 00:00:00+08:00".to_string())),
            )))
            .unwrap()
            .filter(
                Expr::Literal(ScalarValue::Utf8(Some(
                    "1970-01-01 00:00:00+08:00".to_string(),
                )))
                .lt_eq(Expr::Column(Column::from_name("column3"))),
            )
            .unwrap()
            .aggregate(
                Vec::<Expr>::new(),
                vec![Expr::AggregateFunction(
                    datafusion_expr::expr::AggregateFunction::new_udf(
                        datafusion::functions_aggregate::count::count_udaf(),
                        vec![Expr::Column(Column::from_name("column1"))],
                        false,
                        None,
                        None,
                        None,
                    ),
                )],
            )
            .unwrap()
            .build()
            .unwrap();
        let context = QueryEngineContext::mock();

        let transformed_plan = TypeConversionRule
            .analyze(plan, &context, &ConfigOptions::default())
            .unwrap();
        let expected = String::from(
            "Aggregate: groupBy=[[]], aggr=[[count(column1)]]\
            \n  Filter: TimestampSecond(-28800, None) <= column3\
            \n    Filter: column3 > TimestampSecond(-28800, None)\
            \n      Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))",
        );
        assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
    }

    /// Columns that are neither timestamp nor boolean (Float64 here) leave
    /// string literals untouched on both sides of the comparison.
    #[test]
    fn test_reverse_non_ts_type() {
        let context = QueryEngineContext::mock();

        let plan =
            LogicalPlanBuilder::values(vec![vec![Expr::Literal(ScalarValue::Float64(Some(1.0)))]])
                .unwrap()
                .filter(
                    Expr::Column(Column::from_name("column1"))
                        .gt_eq(Expr::Literal(ScalarValue::Utf8(Some("1.2345".to_string())))),
                )
                .unwrap()
                .filter(
                    Expr::Literal(ScalarValue::Utf8(Some("1.2345".to_string())))
                        .lt(Expr::Column(Column::from_name("column1"))),
                )
                .unwrap()
                .build()
                .unwrap();
        let transformed_plan = TypeConversionRule
            .analyze(plan, &context, &ConfigOptions::default())
            .unwrap();
        let expected = String::from(
            "Filter: Utf8(\"1.2345\") < column1\
            \n  Filter: column1 >= Utf8(\"1.2345\")\
            \n    Values: (Float64(1))",
        );
        assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
    }
}