sql/statements/transform/type_alias.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::ControlFlow;
16
17use datatypes::data_type::DataType as GreptimeDataType;
18use sqlparser::ast::{
19    DataType, ExactNumberInfo, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList,
20    Ident, ObjectName, Value,
21};
22
23use crate::ast::ObjectNamePartExt;
24use crate::error::Result;
25use crate::statements::alter::AlterTableOperation;
26use crate::statements::create::{CreateExternalTable, CreateTable};
27use crate::statements::statement::Statement;
28use crate::statements::transform::TransformRule;
29use crate::statements::{TimezoneInfo, sql_data_type_to_concrete_data_type};
30
/// SQL data type alias transformer.
///
/// Replaces type aliases in the column definitions of `CREATE TABLE`,
/// `CREATE EXTERNAL TABLE` and `ALTER TABLE` (modify column type / add columns)
/// statements, and rewrites `CAST`s to an aliased type or to `Timestamp(precision)`
/// into `arrow_cast(expr, '<arrow type>')` function calls.
///
/// Timestamp and text aliases:
///  - `TimestampSecond`, `Timestamp_s`, `Timestamp_sec` for `Timestamp(0)`.
///  - `TimestampMillisecond`, `Timestamp_ms` for `Timestamp(3)`.
///  - `TimestampMicrosecond`, `Timestamp_us`, `DATETIME` (any precision) for `Timestamp(6)`.
///  - `TimestampNanosecond`, `Timestamp_ns` for `Timestamp(9)`.
///  - `TinyText`, `MediumText`, `LongText` for `Text`.
///
/// SQL dialect numeric and boolean type aliases (MySQL & PostgreSQL):
///  - `INT2` for `smallint`
///  - `INT4` for `int`
///  - `INT8` for `bigint`
///  - `FLOAT4` for `float`
///  - `FLOAT8` for `double`
///  - `BOOL` for `boolean`
///
/// Extended type aliases for Arrow types:
///  - `INT16` for `smallint`
///  - `INT32` for `int`
///  - `INT64` for `bigint`
///  - `UINT8`, `UINT16`, `UINT32`, `UINT64` for `TinyIntUnsigned`, `SmallIntUnsigned`,
///    `IntUnsigned`, `BigIntUnsigned`
///  - `FLOAT32` for `float`
///  - `FLOAT64` for `double`
///
/// Custom alias names such as `Timestamp_ms` are matched case-insensitively.
pub(crate) struct TypeAliasTransformRule;
52
53impl TransformRule for TypeAliasTransformRule {
54    fn visit_statement(&self, stmt: &mut Statement) -> Result<()> {
55        match stmt {
56            Statement::CreateTable(CreateTable { columns, .. }) => {
57                columns
58                    .iter_mut()
59                    .for_each(|column| replace_type_alias(column.mut_data_type()));
60            }
61            Statement::CreateExternalTable(CreateExternalTable { columns, .. }) => {
62                columns
63                    .iter_mut()
64                    .for_each(|column| replace_type_alias(column.mut_data_type()));
65            }
66            Statement::AlterTable(alter_table) => {
67                if let AlterTableOperation::ModifyColumnType { target_type, .. } =
68                    alter_table.alter_operation_mut()
69                {
70                    replace_type_alias(target_type)
71                } else if let AlterTableOperation::AddColumns { add_columns, .. } =
72                    alter_table.alter_operation_mut()
73                {
74                    for add_column in add_columns {
75                        replace_type_alias(&mut add_column.column_def.data_type);
76                    }
77                }
78            }
79            _ => {}
80        }
81
82        Ok(())
83    }
84
85    fn visit_expr(&self, expr: &mut Expr) -> ControlFlow<()> {
86        fn cast_expr_to_arrow_cast_func(expr: Expr, cast_type: String) -> Function {
87            Function {
88                name: ObjectName::from(vec![Ident::new("arrow_cast")]),
89                args: sqlparser::ast::FunctionArguments::List(FunctionArgumentList {
90                    args: vec![
91                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
92                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
93                            Value::SingleQuotedString(cast_type).into(),
94                        ))),
95                    ],
96                    duplicate_treatment: None,
97                    clauses: vec![],
98                }),
99                filter: None,
100                null_treatment: None,
101                over: None,
102                parameters: sqlparser::ast::FunctionArguments::None,
103                within_group: vec![],
104                uses_odbc_syntax: false,
105            }
106        }
107
108        match expr {
109            // In new sqlparser, the "INT64" is no longer parsed to custom datatype.
110            // The new "Int64" is not recognizable by Datafusion, cannot directly "CAST" to it.
111            // We have to replace the expr to "arrow_cast" function call here.
112            // Same for "FLOAT64".
113            Expr::Cast {
114                expr: cast_expr,
115                data_type,
116                ..
117            } if get_type_by_alias(data_type).is_some() => {
118                // Safety: checked in the match arm.
119                let new_type = get_type_by_alias(data_type).unwrap();
120                if let Ok(new_type) =
121                    sql_data_type_to_concrete_data_type(&new_type, &Default::default())
122                {
123                    *expr = Expr::Function(cast_expr_to_arrow_cast_func(
124                        (**cast_expr).clone(),
125                        new_type.as_arrow_type().to_string(),
126                    ));
127                }
128            }
129
130            // Timestamp(precision) in cast, datafusion doesn't support Timestamp(9) etc.
131            // We have to transform it into arrow_cast(expr, type).
132            Expr::Cast {
133                data_type: DataType::Timestamp(precision, zone),
134                expr: cast_expr,
135                ..
136            } => {
137                if let Ok(concrete_type) = sql_data_type_to_concrete_data_type(
138                    &DataType::Timestamp(*precision, *zone),
139                    &Default::default(),
140                ) {
141                    let new_type = concrete_type.as_arrow_type();
142                    *expr = Expr::Function(cast_expr_to_arrow_cast_func(
143                        (**cast_expr).clone(),
144                        new_type.to_string(),
145                    ));
146                }
147            }
148
149            // TODO(dennis): supports try_cast
150            _ => {}
151        }
152
153        ControlFlow::<()>::Continue(())
154    }
155}
156
157fn replace_type_alias(data_type: &mut DataType) {
158    if let Some(new_type) = get_type_by_alias(data_type) {
159        *data_type = new_type;
160    }
161}
162
163/// Get data type from alias type.
164/// Returns the mapped data type if the input data type is an alias that we need to replace.
165// Remember to update `get_data_type_by_alias_name()` if you modify this method.
166pub(crate) fn get_type_by_alias(data_type: &DataType) -> Option<DataType> {
167    match data_type {
168        DataType::Custom(name, tokens) if name.0.len() == 1 && tokens.is_empty() => {
169            get_data_type_by_alias_name(name.0[0].to_string_unquoted().as_str())
170        }
171        DataType::Int2(None) => Some(DataType::SmallInt(None)),
172        DataType::Int4(None) => Some(DataType::Int(None)),
173        DataType::Int8(None) => Some(DataType::BigInt(None)),
174        DataType::Int16 => Some(DataType::SmallInt(None)),
175        DataType::Int32 => Some(DataType::Int(None)),
176        DataType::Int64 => Some(DataType::BigInt(None)),
177        DataType::UInt8 => Some(DataType::TinyIntUnsigned(None)),
178        DataType::UInt16 => Some(DataType::SmallIntUnsigned(None)),
179        DataType::UInt32 => Some(DataType::IntUnsigned(None)),
180        DataType::UInt64 => Some(DataType::BigIntUnsigned(None)),
181        DataType::Float4 => Some(DataType::Float(None)),
182        DataType::Float8 => Some(DataType::Double(ExactNumberInfo::None)),
183        DataType::Float32 => Some(DataType::Float(None)),
184        DataType::Float64 => Some(DataType::Double(ExactNumberInfo::None)),
185        DataType::Bool => Some(DataType::Boolean),
186        DataType::Datetime(_) => Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
187        _ => None,
188    }
189}
190
191/// Get the mapped data type from alias name.
192/// It only supports the following types of alias:
193/// - timestamps
194/// - ints
195/// - floats
196/// - texts
197// Remember to update `get_type_alias()` if you modify this method.
198pub(crate) fn get_data_type_by_alias_name(name: &str) -> Option<DataType> {
199    match name.to_uppercase().as_ref() {
200        // Timestamp type alias
201        "TIMESTAMP_S" | "TIMESTAMP_SEC" | "TIMESTAMPSECOND" => {
202            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
203        }
204
205        "TIMESTAMP_MS" | "TIMESTAMPMILLISECOND" => {
206            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
207        }
208        "TIMESTAMP_US" | "TIMESTAMPMICROSECOND" | "DATETIME" => {
209            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
210        }
211        "TIMESTAMP_NS" | "TIMESTAMPNANOSECOND" => {
212            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
213        }
214        // Number type alias
215        "INT2" => Some(DataType::SmallInt(None)),
216        "INT4" => Some(DataType::Int(None)),
217        "INT8" => Some(DataType::BigInt(None)),
218        "INT16" => Some(DataType::SmallInt(None)),
219        "INT32" => Some(DataType::Int(None)),
220        "INT64" => Some(DataType::BigInt(None)),
221        "UINT8" => Some(DataType::TinyIntUnsigned(None)),
222        "UINT16" => Some(DataType::SmallIntUnsigned(None)),
223        "UINT32" => Some(DataType::IntUnsigned(None)),
224        "UINT64" => Some(DataType::BigIntUnsigned(None)),
225        "FLOAT4" => Some(DataType::Float(None)),
226        "FLOAT8" => Some(DataType::Double(ExactNumberInfo::None)),
227        "FLOAT32" => Some(DataType::Float(None)),
228        "FLOAT64" => Some(DataType::Double(ExactNumberInfo::None)),
229        // String type alias
230        "TINYTEXT" | "MEDIUMTEXT" | "LONGTEXT" => Some(DataType::Text),
231        _ => None,
232    }
233}
234
#[cfg(test)]
mod tests {
    use sqlparser::dialect::GenericDialect;

    use super::*;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::transform_statements;

    #[test]
    fn test_get_data_type_by_alias_name() {
        let cases = [
            ("float64", Some(DataType::Double(ExactNumberInfo::None))),
            ("Float64", Some(DataType::Double(ExactNumberInfo::None))),
            ("FLOAT64", Some(DataType::Double(ExactNumberInfo::None))),
            ("float32", Some(DataType::Float(None))),
            ("float8", Some(DataType::Double(ExactNumberInfo::None))),
            ("float4", Some(DataType::Float(None))),
            ("int8", Some(DataType::BigInt(None))),
            ("int4", Some(DataType::Int(None))),
            ("int2", Some(DataType::SmallInt(None))),
            ("INT16", Some(DataType::SmallInt(None))),
            ("INT32", Some(DataType::Int(None))),
            ("INT64", Some(DataType::BigInt(None))),
            ("Uint8", Some(DataType::TinyIntUnsigned(None))),
            ("UINT16", Some(DataType::SmallIntUnsigned(None))),
            ("UINT32", Some(DataType::IntUnsigned(None))),
            ("uint64", Some(DataType::BigIntUnsigned(None))),
            (
                "TimestampSecond",
                Some(DataType::Timestamp(Some(0), TimezoneInfo::None)),
            ),
            (
                "Timestamp_s",
                Some(DataType::Timestamp(Some(0), TimezoneInfo::None)),
            ),
            (
                "Timestamp_sec",
                Some(DataType::Timestamp(Some(0), TimezoneInfo::None)),
            ),
            (
                "TimestampMilliSecond",
                Some(DataType::Timestamp(Some(3), TimezoneInfo::None)),
            ),
            (
                "Timestamp_ms",
                Some(DataType::Timestamp(Some(3), TimezoneInfo::None)),
            ),
            (
                "TimestampMicroSecond",
                Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
            ),
            (
                "Timestamp_us",
                Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
            ),
            (
                "DateTime",
                Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
            ),
            (
                "TimestampNanoSecond",
                Some(DataType::Timestamp(Some(9), TimezoneInfo::None)),
            ),
            (
                "Timestamp_ns",
                Some(DataType::Timestamp(Some(9), TimezoneInfo::None)),
            ),
            ("TinyText", Some(DataType::Text)),
            ("MediumText", Some(DataType::Text)),
            ("LongText", Some(DataType::Text)),
            // Not aliases: these must be left to the regular type handling.
            ("varchar", None),
            ("Timestamp", None),
            ("", None),
        ];

        for (alias, expected) in cases {
            assert_eq!(
                get_data_type_by_alias_name(alias),
                expected,
                "alias: {alias:?}"
            );
        }
    }

    #[test]
    fn test_get_type_by_alias() {
        assert_eq!(get_type_by_alias(&DataType::Bool), Some(DataType::Boolean));
        assert_eq!(
            get_type_by_alias(&DataType::Int64),
            Some(DataType::BigInt(None))
        );
        assert_eq!(
            get_type_by_alias(&DataType::Float8),
            Some(DataType::Double(ExactNumberInfo::None))
        );
        assert_eq!(
            get_type_by_alias(&DataType::Datetime(None)),
            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
        );
        // Only the bare `INT2`/`INT4`/`INT8` spellings are aliases.
        assert_eq!(get_type_by_alias(&DataType::Int8(Some(10))), None);
        // Canonical types need no replacement.
        assert_eq!(get_type_by_alias(&DataType::BigInt(None)), None);
        assert_eq!(
            get_type_by_alias(&DataType::Timestamp(Some(3), TimezoneInfo::None)),
            None
        );
        // Single-part custom names are resolved by name.
        assert_eq!(
            get_type_by_alias(&DataType::Custom(
                ObjectName::from(vec![Ident::new("TimestampSecond")]),
                vec![]
            )),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );
        // Multi-part names and names with modifiers are not aliases.
        assert_eq!(
            get_type_by_alias(&DataType::Custom(
                ObjectName::from(vec![Ident::new("s"), Ident::new("TimestampSecond")]),
                vec![]
            )),
            None
        );
        assert_eq!(
            get_type_by_alias(&DataType::Custom(
                ObjectName::from(vec![Ident::new("TimestampSecond")]),
                vec!["3".to_string()]
            )),
            None
        );
    }

    fn test_timestamp_alias(alias: &str, expected: &str) {
        let sql = format!("SELECT TIMESTAMP '2020-01-01 01:23:45.12345678'::{alias}");
        let mut stmts =
            ParserContext::create_with_dialect(&sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::Query(q) => assert_eq!(
                format!(
                    "SELECT arrow_cast(TIMESTAMP '2020-01-01 01:23:45.12345678', 'Timestamp({expected}, None)')"
                ),
                q.to_string()
            ),
            _ => unreachable!(),
        }
    }

    fn test_timestamp_precision_type(precision: i32, expected: &str) {
        test_timestamp_alias(&format!("Timestamp({precision})"), expected);
    }

    #[test]
    fn test_boolean_alias() {
        let sql = "CREATE TABLE test(b bool, ts TIMESTAMP TIME INDEX)";
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::CreateTable(c) => assert_eq!(
                "CREATE TABLE test (\n  b BOOLEAN,\n  ts TIMESTAMP NOT NULL,\n  TIME INDEX (ts)\n)\nENGINE=mito\n",
                c.to_string()
            ),
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_transform_timestamp_alias() {
        // Timestamp[Second | Millisecond | Microsecond | Nanosecond]
        test_timestamp_alias("TimestampSecond", "Second");
        test_timestamp_alias("Timestamp_s", "Second");
        test_timestamp_alias("TimestampMillisecond", "Millisecond");
        test_timestamp_alias("Timestamp_ms", "Millisecond");
        test_timestamp_alias("TimestampMicrosecond", "Microsecond");
        test_timestamp_alias("Timestamp_us", "Microsecond");
        test_timestamp_alias("TimestampNanosecond", "Nanosecond");
        test_timestamp_alias("Timestamp_ns", "Nanosecond");
        // Timestamp(precision)
        test_timestamp_precision_type(0, "Second");
        test_timestamp_precision_type(3, "Millisecond");
        test_timestamp_precision_type(6, "Microsecond");
        test_timestamp_precision_type(9, "Nanosecond");
    }

    #[test]
    fn test_create_sql_with_type_alias() {
        let sql = r#"
CREATE TABLE data_types (
  s string,
  tt tinytext,
  mt mediumtext,
  lt longtext,
  i2 int2,
  i4 int4,
  i8 int8,
  sint int16,
  i int32,
  bint int64,
  v varchar,
  f4 float4,
  f8 float8,
  f float32,
  d float64,
  b boolean,
  vb varbinary,
  dt date,
  dtt datetime,
  ts0 TimestampSecond,
  ts3 TimestampMillisecond,
  ts6 TimestampMicrosecond,
  ts9 TimestampNanosecond DEFAULT CURRENT_TIMESTAMP TIME INDEX,
  PRIMARY KEY(s));"#;

        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::CreateTable(c) => {
                let expected = r#"CREATE TABLE data_types (
  s STRING,
  tt TINYTEXT,
  mt MEDIUMTEXT,
  lt LONGTEXT,
  i2 SMALLINT,
  i4 INT,
  i8 BIGINT,
  sint SMALLINT,
  i INT,
  bint BIGINT,
  v VARCHAR,
  f4 FLOAT,
  f8 DOUBLE,
  f FLOAT,
  d DOUBLE,
  b BOOLEAN,
  vb VARBINARY,
  dt DATE,
  dtt TIMESTAMP(6),
  ts0 TIMESTAMP(0),
  ts3 TIMESTAMP(3),
  ts6 TIMESTAMP(6),
  ts9 TIMESTAMP(9) DEFAULT CURRENT_TIMESTAMP NOT NULL,
  TIME INDEX (ts9),
  PRIMARY KEY (s)
)
ENGINE=mito
"#;

                assert_eq!(expected, c.to_string());
            }
            _ => unreachable!(),
        }
    }
}