use common_time::timestamp::{TimeUnit, Timestamp};
use common_time::Timezone;
use datafusion::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::InList;
use datafusion_expr::{
    Between, BinaryExpr, Expr, ExprSchemable, Filter, LogicalPlan, Operator, TableScan,
};
use datatypes::arrow::compute;
use datatypes::arrow::datatypes::DataType;
use session::context::QueryContextRef;

use crate::optimizer::ExtensionAnalyzerRule;
use crate::plan::ExtractExpr;
use crate::QueryEngineContext;

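/// Analyzer rule that rewrites certain literals in a logical plan according to the data type
/// of the column they are compared against. Concretely, a string literal compared with a
/// timestamp column is parsed into a timestamp literal (honoring the session timezone), a
/// string literal compared with a boolean column becomes a boolean literal, and timestamp
/// literals of other precisions are normalized to millisecond precision during the rewrite.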
pub struct TypeConversionRule;

impl ExtensionAnalyzerRule for TypeConversionRule {
    fn analyze(
        &self,
        plan: LogicalPlan,
        ctx: &QueryEngineContext,
        _config: &ConfigOptions,
    ) -> Result<LogicalPlan> {
        plan.transform(&|plan| match plan {
            LogicalPlan::Filter(filter) => {
                let mut converter = TypeConverter {
                    schema: filter.input.schema().clone(),
                    query_ctx: ctx.query_ctx(),
                };
                let rewritten = filter.predicate.clone().rewrite(&mut converter)?.data;
                Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
                    rewritten,
                    filter.input,
                )?)))
            }
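            // Rewrite the filters that were pushed down to the scan, resolving column types
            // against the scan's projected schema.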
            LogicalPlan::TableScan(TableScan {
                table_name,
                source,
                projection,
                projected_schema,
                filters,
                fetch,
            }) => {
                let mut converter = TypeConverter {
                    schema: projected_schema.clone(),
                    query_ctx: ctx.query_ctx(),
                };
                let rewrite_filters = filters
                    .into_iter()
                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
                    .collect::<Result<Vec<_>>>()?;
                Ok(Transformed::yes(LogicalPlan::TableScan(TableScan {
                    table_name: table_name.clone(),
                    source: source.clone(),
                    projection,
                    projected_schema,
                    filters: rewrite_filters,
                    fetch,
                })))
            }
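            // Plans that carry their own expressions: rewrite those expressions against the
            // plan's schema and rebuild the node with the converted expressions.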
            LogicalPlan::Projection { .. }
            | LogicalPlan::Window { .. }
            | LogicalPlan::Aggregate { .. }
            | LogicalPlan::Repartition { .. }
            | LogicalPlan::Extension { .. }
            | LogicalPlan::Sort { .. }
            | LogicalPlan::Union { .. }
            | LogicalPlan::Values { .. }
            | LogicalPlan::Analyze { .. } => {
                let mut converter = TypeConverter {
                    schema: plan.schema().clone(),
                    query_ctx: ctx.query_ctx(),
                };
                let inputs = plan.inputs().into_iter().cloned().collect::<Vec<_>>();
                let expr = plan
                    .expressions_consider_join()
                    .into_iter()
                    .map(|e| e.rewrite(&mut converter).map(|x| x.data))
                    .collect::<Result<Vec<_>>>()?;

                plan.with_new_exprs(expr, inputs).map(Transformed::yes)
            }

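            // Remaining plan kinds either have no expressions to convert or are intentionally
            // left untouched.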
            LogicalPlan::Distinct { .. }
            | LogicalPlan::Limit { .. }
            | LogicalPlan::Subquery { .. }
            | LogicalPlan::Explain { .. }
            | LogicalPlan::SubqueryAlias { .. }
            | LogicalPlan::EmptyRelation(_)
            | LogicalPlan::Dml(_)
            | LogicalPlan::DescribeTable(_)
            | LogicalPlan::Unnest(_)
            | LogicalPlan::Statement(_)
            | LogicalPlan::Ddl(_)
            | LogicalPlan::Copy(_)
            | LogicalPlan::RecursiveQuery(_)
            | LogicalPlan::Join { .. } => Ok(Transformed::no(plan)),
        })
        .map(|x| x.data)
    }
}

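/// Expression rewriter that performs the actual literal conversion. Column types are resolved
/// from `schema`, and the session timezone from `query_ctx` is used when parsing timestamp
/// strings.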
struct TypeConverter {
    query_ctx: QueryContextRef,
    schema: DFSchemaRef,
}

impl TypeConverter {
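    /// Returns the data type of `expr` if it is a column reference that resolves in `schema`.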
    fn column_type(&self, expr: &Expr) -> Option<DataType> {
        if let Expr::Column(_) = expr {
            if let Ok(v) = expr.get_type(&self.schema) {
                return Some(v);
            }
        }
        None
    }

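    /// Casts a scalar literal to `target_type`. String-to-timestamp and string-to-boolean
    /// conversions are handled specially; everything else falls back to an Arrow cast.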
    fn cast_scalar_value(
        &self,
        value: &ScalarValue,
        target_type: &DataType,
    ) -> Result<ScalarValue> {
        match (target_type, value) {
            (DataType::Timestamp(_, _), ScalarValue::Utf8(Some(v))) => {
                string_to_timestamp_ms(v, Some(&self.query_ctx.timezone()))
            }
            (DataType::Boolean, ScalarValue::Utf8(Some(v))) => match v.to_lowercase().as_str() {
                "true" => Ok(ScalarValue::Boolean(Some(true))),
                "false" => Ok(ScalarValue::Boolean(Some(false))),
                _ => Ok(ScalarValue::Boolean(None)),
            },
            (target_type, value) => {
                let value_arr = value.to_array()?;
                let arr = compute::cast(&value_arr, target_type)
                    .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?;

                ScalarValue::try_from_array(&arr, 0)
            }
        }
    }

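    /// Given the two sides of a comparison, casts the literal side to the column side's type
    /// when exactly one side is a column of timestamp or boolean type. Returns an error if the
    /// cast yields a null, i.e. the literal cannot be represented in the target type.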
    fn convert_type<'b>(&self, left: &'b Expr, right: &'b Expr) -> Result<(Expr, Expr)> {
        let left_type = self.column_type(left);
        let right_type = self.column_type(right);

        let target_type = match (&left_type, &right_type) {
            (Some(v), None) => v,
            (None, Some(v)) => v,
            _ => return Ok((left.clone(), right.clone())),
        };

        if !matches!(target_type, DataType::Timestamp(_, _) | DataType::Boolean) {
            return Ok((left.clone(), right.clone()));
        }

        match (left, right) {
            (Expr::Column(col), Expr::Literal(value, _)) => {
                let casted_right = self.cast_scalar_value(value, target_type)?;
                if casted_right.is_null() {
                    return Err(DataFusionError::Plan(format!(
                        "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
                    )));
                }
                Ok((left.clone(), Expr::Literal(casted_right, None)))
            }
            (Expr::Literal(value, _), Expr::Column(col)) => {
                let casted_left = self.cast_scalar_value(value, target_type)?;
                if casted_left.is_null() {
                    return Err(DataFusionError::Plan(format!(
                        "column:{col:?}. Casting value:{value:?} to {target_type:?} is invalid",
                    )));
                }
                Ok((Expr::Literal(casted_left, None), right.clone()))
            }
            _ => Ok((left.clone(), right.clone())),
        }
    }
}

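// Bottom-up expression rewrite: comparisons, BETWEEN and IN lists get their literal side cast
// to the column's type, and timestamp literals are normalized to millisecond precision.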
impl TreeNodeRewriter for TypeConverter {
    type Node = Expr;

    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        let new_expr = match expr {
            Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
                Operator::Eq
                | Operator::NotEq
                | Operator::Lt
                | Operator::LtEq
                | Operator::Gt
                | Operator::GtEq => {
                    let (left, right) = self.convert_type(&left, &right)?;
                    Expr::BinaryExpr(BinaryExpr {
                        left: Box::new(left),
                        op,
                        right: Box::new(right),
                    })
                }
                _ => Expr::BinaryExpr(BinaryExpr { left, op, right }),
            },
            Expr::Between(Between {
                expr,
                negated,
                low,
                high,
            }) => {
                let (expr, low) = self.convert_type(&expr, &low)?;
                let (expr, high) = self.convert_type(&expr, &high)?;
                Expr::Between(Between {
                    expr: Box::new(expr),
                    negated,
                    low: Box::new(low),
                    high: Box::new(high),
                })
            }
            Expr::InList(InList {
                expr,
                list,
                negated,
            }) => {
                let mut list_expr = Vec::with_capacity(list.len());
                for e in list {
                    let (_, expr_conversion) = self.convert_type(&expr, &e)?;
                    list_expr.push(expr_conversion);
                }
                Expr::InList(InList {
                    expr,
                    list: list_expr,
                    negated,
                })
            }
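            // Normalize non-null timestamp literals of any precision to millisecond precision.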
            Expr::Literal(value, _) => match value {
                ScalarValue::TimestampSecond(Some(i), _) => {
                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Second)
                }
                ScalarValue::TimestampMillisecond(Some(i), _) => {
                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Millisecond)
                }
                ScalarValue::TimestampMicrosecond(Some(i), _) => {
                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Microsecond)
                }
                ScalarValue::TimestampNanosecond(Some(i), _) => {
                    timestamp_to_timestamp_ms_expr(i, TimeUnit::Nanosecond)
                }
                _ => Expr::Literal(value, None),
            },
            expr => expr,
        };
        Ok(Transformed::yes(new_expr))
    }
}

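/// Converts a timestamp value of the given unit into a millisecond-precision timestamp literal.
/// Sub-millisecond precision is truncated (see the unit tests below).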
fn timestamp_to_timestamp_ms_expr(val: i64, unit: TimeUnit) -> Expr {
    let timestamp = match unit {
        TimeUnit::Second => val * 1_000,
        TimeUnit::Millisecond => val,
        TimeUnit::Microsecond => val / 1_000,
        TimeUnit::Nanosecond => val / 1_000 / 1_000,
    };

    Expr::Literal(
        ScalarValue::TimestampMillisecond(Some(timestamp), None),
        None,
    )
}

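/// Parses a timestamp string into a timestamp scalar whose unit matches the parsed precision.
/// Strings without an explicit offset are interpreted in `timezone` when one is provided.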
fn string_to_timestamp_ms(string: &str, timezone: Option<&Timezone>) -> Result<ScalarValue> {
    let ts = Timestamp::from_str(string, timezone)
        .map_err(|e| DataFusionError::External(Box::new(e)))?;

    let value = Some(ts.value());
    let scalar = match ts.unit() {
        TimeUnit::Second => ScalarValue::TimestampSecond(value, None),
        TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, None),
        TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, None),
        TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, None),
    };
    Ok(scalar)
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion_common::arrow::datatypes::Field;
    use datafusion_common::{Column, DFSchema};
    use datafusion_expr::{Literal, LogicalPlanBuilder};
    use datafusion_sql::TableReference;
    use session::context::QueryContext;

    use super::*;

    #[test]
    fn test_string_to_timestamp_ms() {
        assert_eq!(
            string_to_timestamp_ms("2022-02-02 19:00:00+08:00", None).unwrap(),
            ScalarValue::TimestampSecond(Some(1643799600), None)
        );
        assert_eq!(
            string_to_timestamp_ms("2009-02-13 23:31:30Z", None).unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890), None)
        );

        assert_eq!(
            string_to_timestamp_ms(
                "2009-02-13 23:31:30",
                Some(&Timezone::from_tz_string("Asia/Shanghai").unwrap())
            )
            .unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890 - 8 * 3600), None)
        );

        assert_eq!(
            string_to_timestamp_ms(
                "2009-02-13 23:31:30",
                Some(&Timezone::from_tz_string("-8:00").unwrap())
            )
            .unwrap(),
            ScalarValue::TimestampSecond(Some(1234567890 + 8 * 3600), None)
        );
    }

    #[test]
    fn test_timestamp_to_timestamp_ms_expr() {
        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Second),
            ScalarValue::TimestampMillisecond(Some(123000), None).lit()
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Millisecond),
            ScalarValue::TimestampMillisecond(Some(123), None).lit()
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123, TimeUnit::Microsecond),
            ScalarValue::TimestampMillisecond(Some(0), None).lit()
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(1230, TimeUnit::Microsecond),
            ScalarValue::TimestampMillisecond(Some(1), None).lit()
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(123000, TimeUnit::Microsecond),
            ScalarValue::TimestampMillisecond(Some(123), None).lit()
        );

        assert_eq!(
            timestamp_to_timestamp_ms_expr(1230, TimeUnit::Nanosecond),
            ScalarValue::TimestampMillisecond(Some(0), None).lit()
        );
        assert_eq!(
            timestamp_to_timestamp_ms_expr(123_000_000, TimeUnit::Nanosecond),
            ScalarValue::TimestampMillisecond(Some(123), None).lit()
        );
    }

    #[test]
    fn test_convert_timestamp_str() {
        use datatypes::arrow::datatypes::TimeUnit as ArrowTimeUnit;

        let schema = Arc::new(
            DFSchema::new_with_metadata(
                vec![(
                    None::<TableReference>,
                    Arc::new(Field::new(
                        "ts",
                        DataType::Timestamp(ArrowTimeUnit::Millisecond, None),
                        true,
                    )),
                )],
                HashMap::new(),
            )
            .unwrap(),
        );
        let mut converter = TypeConverter {
            schema,
            query_ctx: QueryContext::arc(),
        };

        assert_eq!(
            Expr::Column(Column::from_name("ts")).gt(ScalarValue::TimestampSecond(
                Some(1599514949),
                None
            )
            .lit()),
            converter
                .f_up(Expr::Column(Column::from_name("ts")).gt("2020-09-08T05:42:29+08:00".lit()))
                .unwrap()
                .data
        );
    }

    #[test]
    fn test_convert_bool() {
        let col_name = "is_valid";
        let schema = Arc::new(
            DFSchema::new_with_metadata(
                vec![(
                    None::<TableReference>,
                    Arc::new(Field::new(col_name, DataType::Boolean, false)),
                )],
                HashMap::new(),
            )
            .unwrap(),
        );
        let mut converter = TypeConverter {
            schema,
            query_ctx: QueryContext::arc(),
        };

        assert_eq!(
            Expr::Column(Column::from_name(col_name)).eq(true.lit()),
            converter
                .f_up(Expr::Column(Column::from_name(col_name)).eq("true".lit()))
                .unwrap()
                .data
        );
    }

    #[test]
    fn test_retrieve_type_from_aggr_plan() {
        let plan = LogicalPlanBuilder::values(vec![vec![
            ScalarValue::Int64(Some(1)).lit(),
            ScalarValue::Float64(Some(1.0)).lit(),
            ScalarValue::TimestampMillisecond(Some(1), None).lit(),
        ]])
        .unwrap()
        .filter(Expr::Column(Column::from_name("column3")).gt("1970-01-01 00:00:00+08:00".lit()))
        .unwrap()
        .filter(
            "1970-01-01 00:00:00+08:00"
                .lit()
                .lt_eq(Expr::Column(Column::from_name("column3"))),
        )
        .unwrap()
        .aggregate(
            Vec::<Expr>::new(),
            vec![Expr::AggregateFunction(
                datafusion_expr::expr::AggregateFunction::new_udf(
                    datafusion::functions_aggregate::count::count_udaf(),
                    vec![Expr::Column(Column::from_name("column1"))],
                    false,
                    None,
                    vec![],
                    None,
                ),
            )],
        )
        .unwrap()
        .build()
        .unwrap();
        let context = QueryEngineContext::mock();

        let transformed_plan = TypeConversionRule
            .analyze(plan, &context, &ConfigOptions::default())
            .unwrap();
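        // "1970-01-01 00:00:00+08:00" is 8 hours before the Unix epoch, i.e. -28800 seconds.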
        let expected = String::from(
            "Aggregate: groupBy=[[]], aggr=[[count(column1)]]\
            \n  Filter: TimestampSecond(-28800, None) <= column3\
            \n    Filter: column3 > TimestampSecond(-28800, None)\
            \n      Values: (Int64(1), Float64(1), TimestampMillisecond(1, None))",
        );
        assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
    }

    #[test]
    fn test_reverse_non_ts_type() {
        let context = QueryEngineContext::mock();

        let plan = LogicalPlanBuilder::values(vec![vec![1.0f64.lit()]])
            .unwrap()
            .filter(Expr::Column(Column::from_name("column1")).gt_eq("1.2345".lit()))
            .unwrap()
            .filter(
                "1.2345"
                    .lit()
                    .lt(Expr::Column(Column::from_name("column1"))),
            )
            .unwrap()
            .build()
            .unwrap();
        let transformed_plan = TypeConversionRule
            .analyze(plan, &context, &ConfigOptions::default())
            .unwrap();
        let expected = String::from(
            "Filter: Utf8(\"1.2345\") < column1\
            \n  Filter: column1 >= Utf8(\"1.2345\")\
            \n    Values: (Float64(1))",
        );
        assert_eq!(format!("{}", transformed_plan.display_indent()), expected);
    }
}