sql/statements/transform/
type_alias.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::ControlFlow;
16
17use datatypes::data_type::DataType as GreptimeDataType;
18use sqlparser::ast::{
19    DataType, ExactNumberInfo, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList,
20    Ident, ObjectName, Value,
21};
22
23use crate::ast::ObjectNamePartExt;
24use crate::error::Result;
25use crate::statements::alter::AlterTableOperation;
26use crate::statements::create::{CreateExternalTable, CreateTable};
27use crate::statements::statement::Statement;
28use crate::statements::transform::TransformRule;
29use crate::statements::{sql_data_type_to_concrete_data_type, TimezoneInfo};
30
/// SQL data type alias transformer.
///
/// Replaces alias type names with the data types they stand for:
///  - `TimestampSecond`, `Timestamp_s`, `Timestamp_sec` for `Timestamp(0)`.
///  - `TimestampMillisecond`, `Timestamp_ms` for `Timestamp(3)`.
///  - `TimestampMicrosecond`, `Timestamp_us`, `Datetime` for `Timestamp(6)`.
///  - `TimestampNanosecond`, `Timestamp_ns` for `Timestamp(9)`.
///  - `INT8` for `tinyint`
///  - `INT16` for `smallint`
///  - `INT32` for `int`
///  - `INT64` for `bigint`
///  - `UINT8`, `UINT16` etc. for `TinyIntUnsigned` etc.
///  - `FLOAT32` for `float`, `FLOAT64` for `double`.
///  - `BOOL` for `boolean`.
///  - `TinyText`, `MediumText`, `LongText` for `Text`.
///
/// Column types of `CREATE TABLE`, `CREATE EXTERNAL TABLE` and `ALTER TABLE`
/// (added columns and modified column types) are replaced in place, while `CAST`
/// expressions targeting an alias or a `Timestamp(precision)` type are rewritten
/// into `arrow_cast` function calls.
pub(crate) struct TypeAliasTransformRule;
43
44impl TransformRule for TypeAliasTransformRule {
45    fn visit_statement(&self, stmt: &mut Statement) -> Result<()> {
46        match stmt {
47            Statement::CreateTable(CreateTable { columns, .. }) => {
48                columns
49                    .iter_mut()
50                    .for_each(|column| replace_type_alias(column.mut_data_type()));
51            }
52            Statement::CreateExternalTable(CreateExternalTable { columns, .. }) => {
53                columns
54                    .iter_mut()
55                    .for_each(|column| replace_type_alias(column.mut_data_type()));
56            }
57            Statement::AlterTable(alter_table) => {
58                if let AlterTableOperation::ModifyColumnType { target_type, .. } =
59                    alter_table.alter_operation_mut()
60                {
61                    replace_type_alias(target_type)
62                } else if let AlterTableOperation::AddColumns { add_columns, .. } =
63                    alter_table.alter_operation_mut()
64                {
65                    for add_column in add_columns {
66                        replace_type_alias(&mut add_column.column_def.data_type);
67                    }
68                }
69            }
70            _ => {}
71        }
72
73        Ok(())
74    }
75
76    fn visit_expr(&self, expr: &mut Expr) -> ControlFlow<()> {
77        fn cast_expr_to_arrow_cast_func(expr: Expr, cast_type: String) -> Function {
78            Function {
79                name: ObjectName::from(vec![Ident::new("arrow_cast")]),
80                args: sqlparser::ast::FunctionArguments::List(FunctionArgumentList {
81                    args: vec![
82                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
83                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
84                            Value::SingleQuotedString(cast_type).into(),
85                        ))),
86                    ],
87                    duplicate_treatment: None,
88                    clauses: vec![],
89                }),
90                filter: None,
91                null_treatment: None,
92                over: None,
93                parameters: sqlparser::ast::FunctionArguments::None,
94                within_group: vec![],
95                uses_odbc_syntax: false,
96            }
97        }
98
99        match expr {
100            // In new sqlparser, the "INT64" is no longer parsed to custom datatype.
101            // The new "Int64" is not recognizable by Datafusion, cannot directly "CAST" to it.
102            // We have to replace the expr to "arrow_cast" function call here.
103            // Same for "FLOAT64".
104            Expr::Cast {
105                expr: cast_expr,
106                data_type,
107                ..
108            } if get_type_by_alias(data_type).is_some() => {
109                // Safety: checked in the match arm.
110                let new_type = get_type_by_alias(data_type).unwrap();
111                if let Ok(new_type) = sql_data_type_to_concrete_data_type(&new_type) {
112                    *expr = Expr::Function(cast_expr_to_arrow_cast_func(
113                        (**cast_expr).clone(),
114                        new_type.as_arrow_type().to_string(),
115                    ));
116                }
117            }
118
119            // Timestamp(precision) in cast, datafusion doesn't support Timestamp(9) etc.
120            // We have to transform it into arrow_cast(expr, type).
121            Expr::Cast {
122                data_type: DataType::Timestamp(precision, zone),
123                expr: cast_expr,
124                ..
125            } => {
126                if let Ok(concrete_type) =
127                    sql_data_type_to_concrete_data_type(&DataType::Timestamp(*precision, *zone))
128                {
129                    let new_type = concrete_type.as_arrow_type();
130                    *expr = Expr::Function(cast_expr_to_arrow_cast_func(
131                        (**cast_expr).clone(),
132                        new_type.to_string(),
133                    ));
134                }
135            }
136
137            // TODO(dennis): supports try_cast
138            _ => {}
139        }
140
141        ControlFlow::<()>::Continue(())
142    }
143}
144
145fn replace_type_alias(data_type: &mut DataType) {
146    if let Some(new_type) = get_type_by_alias(data_type) {
147        *data_type = new_type;
148    }
149}
150
151/// Get data type from alias type.
152/// Returns the mapped data type if the input data type is an alias that we need to replace.
153// Remember to update `get_data_type_by_alias_name()` if you modify this method.
154pub(crate) fn get_type_by_alias(data_type: &DataType) -> Option<DataType> {
155    match data_type {
156        // The sqlparser latest version contains the Int8 alias for Postgres Bigint.
157        // Which means 8 bytes in postgres (not 8 bits).
158        // See https://docs.rs/sqlparser/latest/sqlparser/ast/enum.DataType.html#variant.Int8
159        DataType::Custom(name, tokens) if name.0.len() == 1 && tokens.is_empty() => {
160            get_data_type_by_alias_name(name.0[0].to_string_unquoted().as_str())
161        }
162        DataType::Int8(None) => Some(DataType::TinyInt(None)),
163        DataType::Int16 => Some(DataType::SmallInt(None)),
164        DataType::Int32 => Some(DataType::Int(None)),
165        DataType::Int64 => Some(DataType::BigInt(None)),
166        DataType::UInt8 => Some(DataType::TinyIntUnsigned(None)),
167        DataType::UInt16 => Some(DataType::SmallIntUnsigned(None)),
168        DataType::UInt32 => Some(DataType::IntUnsigned(None)),
169        DataType::UInt64 => Some(DataType::BigIntUnsigned(None)),
170        DataType::Float32 => Some(DataType::Float(None)),
171        DataType::Float64 => Some(DataType::Double(ExactNumberInfo::None)),
172        DataType::Bool => Some(DataType::Boolean),
173        DataType::Datetime(_) => Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
174        _ => None,
175    }
176}
177
178/// Get the mapped data type from alias name.
179/// It only supports the following types of alias:
180/// - timestamps
181/// - ints
182/// - floats
183/// - texts
184// Remember to update `get_type_alias()` if you modify this method.
185pub(crate) fn get_data_type_by_alias_name(name: &str) -> Option<DataType> {
186    match name.to_uppercase().as_ref() {
187        // Timestamp type alias
188        "TIMESTAMP_S" | "TIMESTAMP_SEC" | "TIMESTAMPSECOND" => {
189            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
190        }
191
192        "TIMESTAMP_MS" | "TIMESTAMPMILLISECOND" => {
193            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
194        }
195        "TIMESTAMP_US" | "TIMESTAMPMICROSECOND" | "DATETIME" => {
196            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
197        }
198        "TIMESTAMP_NS" | "TIMESTAMPNANOSECOND" => {
199            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
200        }
201        // Number type alias
202        // We keep them for backward compatibility.
203        "INT8" => Some(DataType::TinyInt(None)),
204        "INT16" => Some(DataType::SmallInt(None)),
205        "INT32" => Some(DataType::Int(None)),
206        "INT64" => Some(DataType::BigInt(None)),
207        "UINT8" => Some(DataType::TinyIntUnsigned(None)),
208        "UINT16" => Some(DataType::SmallIntUnsigned(None)),
209        "UINT32" => Some(DataType::IntUnsigned(None)),
210        "UINT64" => Some(DataType::BigIntUnsigned(None)),
211        "FLOAT32" => Some(DataType::Float(None)),
212        "FLOAT64" => Some(DataType::Double(ExactNumberInfo::None)),
213        // String type alias
214        "TINYTEXT" | "MEDIUMTEXT" | "LONGTEXT" => Some(DataType::Text),
215        _ => None,
216    }
217}
218
#[cfg(test)]
mod tests {
    use sqlparser::dialect::GenericDialect;

    use super::*;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::transform_statements;

    /// Parses `sql` with the generic dialect and runs the statement transformers on it.
    fn parse_and_transform(sql: &str) -> Vec<Statement> {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();
        stmts
    }

    #[test]
    fn test_get_data_type_by_alias_name() {
        let timestamp = |precision| DataType::Timestamp(Some(precision), TimezoneInfo::None);

        // (alias name, expected data type); the names deliberately mix letter cases.
        let cases = vec![
            ("float64", DataType::Double(ExactNumberInfo::None)),
            ("Float64", DataType::Double(ExactNumberInfo::None)),
            ("FLOAT64", DataType::Double(ExactNumberInfo::None)),
            ("float32", DataType::Float(None)),
            ("int8", DataType::TinyInt(None)),
            ("INT16", DataType::SmallInt(None)),
            ("INT32", DataType::Int(None)),
            ("INT64", DataType::BigInt(None)),
            ("Uint8", DataType::TinyIntUnsigned(None)),
            ("UINT16", DataType::SmallIntUnsigned(None)),
            ("UINT32", DataType::IntUnsigned(None)),
            ("uint64", DataType::BigIntUnsigned(None)),
            ("TimestampSecond", timestamp(0)),
            ("Timestamp_s", timestamp(0)),
            ("Timestamp_sec", timestamp(0)),
            ("TimestampMilliSecond", timestamp(3)),
            ("Timestamp_ms", timestamp(3)),
            ("TimestampMicroSecond", timestamp(6)),
            ("Timestamp_us", timestamp(6)),
            ("TimestampNanoSecond", timestamp(9)),
            ("Timestamp_ns", timestamp(9)),
            ("TinyText", DataType::Text),
            ("MediumText", DataType::Text),
            ("LongText", DataType::Text),
        ];

        for (alias, expected) in cases {
            assert_eq!(
                get_data_type_by_alias_name(alias),
                Some(expected),
                "alias: {alias}"
            );
        }
    }

    /// Casts a timestamp literal to `alias` and checks that the cast is rewritten into
    /// an `arrow_cast` call with the `expected` arrow time unit.
    fn test_timestamp_alias(alias: &str, expected: &str) {
        let sql = format!("SELECT TIMESTAMP '2020-01-01 01:23:45.12345678'::{alias}");
        let stmts = parse_and_transform(&sql);

        let Statement::Query(query) = &stmts[0] else {
            unreachable!()
        };
        let expected_sql = format!("SELECT arrow_cast(TIMESTAMP '2020-01-01 01:23:45.12345678', 'Timestamp({expected}, None)')");
        assert_eq!(expected_sql, query.to_string());
    }

    /// Same as `test_timestamp_alias`, with `Timestamp(precision)` as the cast target.
    fn test_timestamp_precision_type(precision: i32, expected: &str) {
        let target_type = format!("Timestamp({precision})");
        test_timestamp_alias(&target_type, expected);
    }

    #[test]
    fn test_boolean_alias() {
        let stmts = parse_and_transform("CREATE TABLE test(b bool, ts TIMESTAMP TIME INDEX)");

        let Statement::CreateTable(create_table) = &stmts[0] else {
            unreachable!()
        };
        let expected =
            "CREATE TABLE test (\n  b BOOLEAN,\n  ts TIMESTAMP NOT NULL,\n  TIME INDEX (ts)\n)\nENGINE=mito\n";
        assert_eq!(expected, create_table.to_string());
    }

    #[test]
    fn test_transform_timestamp_alias() {
        // Timestamp[Second | Millisecond | Microsecond | Nanosecond]
        let alias_cases = vec![
            ("TimestampSecond", "Second"),
            ("Timestamp_s", "Second"),
            ("TimestampMillisecond", "Millisecond"),
            ("Timestamp_ms", "Millisecond"),
            ("TimestampMicrosecond", "Microsecond"),
            ("Timestamp_us", "Microsecond"),
            ("TimestampNanosecond", "Nanosecond"),
            ("Timestamp_ns", "Nanosecond"),
        ];
        for (alias, unit) in alias_cases {
            test_timestamp_alias(alias, unit);
        }

        // Timestamp(precision)
        let precision_cases = vec![
            (0, "Second"),
            (3, "Millisecond"),
            (6, "Microsecond"),
            (9, "Nanosecond"),
        ];
        for (precision, unit) in precision_cases {
            test_timestamp_precision_type(precision, unit);
        }
    }

    #[test]
    fn test_create_sql_with_type_alias() {
        let sql = r#"
CREATE TABLE data_types (
  s string,
  tt tinytext,
  mt mediumtext,
  lt longtext,
  tint int8,
  sint int16,
  i int32,
  bint int64,
  v varchar,
  f float32,
  d float64,
  b boolean,
  vb varbinary,
  dt date,
  dtt datetime,
  ts0 TimestampSecond,
  ts3 TimestampMillisecond,
  ts6 TimestampMicrosecond,
  ts9 TimestampNanosecond DEFAULT CURRENT_TIMESTAMP TIME INDEX,
  PRIMARY KEY(s));"#;

        let expected = r#"CREATE TABLE data_types (
  s STRING,
  tt TINYTEXT,
  mt MEDIUMTEXT,
  lt LONGTEXT,
  tint TINYINT,
  sint SMALLINT,
  i INT,
  bint BIGINT,
  v VARCHAR,
  f FLOAT,
  d DOUBLE,
  b BOOLEAN,
  vb VARBINARY,
  dt DATE,
  dtt TIMESTAMP(6),
  ts0 TIMESTAMP(0),
  ts3 TIMESTAMP(3),
  ts6 TIMESTAMP(6),
  ts9 TIMESTAMP(9) DEFAULT CURRENT_TIMESTAMP NOT NULL,
  TIME INDEX (ts9),
  PRIMARY KEY (s)
)
ENGINE=mito
"#;

        let stmts = parse_and_transform(sql);

        let Statement::CreateTable(create_table) = &stmts[0] else {
            unreachable!()
        };
        assert_eq!(expected, create_table.to_string());
    }
}