sql/statements/transform/
type_alias.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::ControlFlow;
16
17use datatypes::data_type::DataType as GreptimeDataType;
18use sqlparser::ast::{
19    DataType, ExactNumberInfo, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList,
20    Ident, ObjectName, Value,
21};
22
23use crate::error::Result;
24use crate::statements::alter::AlterTableOperation;
25use crate::statements::create::{CreateExternalTable, CreateTable};
26use crate::statements::statement::Statement;
27use crate::statements::transform::TransformRule;
28use crate::statements::{sql_data_type_to_concrete_data_type, TimezoneInfo};
29
/// SQL data type alias transformer.
///
/// Replaces the following type aliases with the SQL types they stand for:
///  - `TimestampSecond`, `Timestamp_s`, `Timestamp_sec` for `Timestamp(0)`.
///  - `TimestampMillisecond`, `Timestamp_ms` for `Timestamp(3)`.
///  - `TimestampMicrosecond`, `Timestamp_us`, `DateTime` for `Timestamp(6)`.
///  - `TimestampNanosecond`, `Timestamp_ns` for `Timestamp(9)`.
///  - `INT8` for `tinyint`
///  - `INT16` for `smallint`
///  - `INT32` for `int`
///  - `INT64` for `bigint`
///  -  And `UINT8`, `UINT16` etc. for `UnsignedTinyint` etc.
///  - `FLOAT32` for `float` and `FLOAT64` for `double`.
///  - `BOOL` for `boolean`.
///  -  TinyText, MediumText, LongText for `Text`.
///
/// `CAST` expressions whose target is one of these aliases, or an explicit
/// `Timestamp(precision)`, are rewritten into `arrow_cast` function calls.
pub(crate) struct TypeAliasTransformRule;
42
43impl TransformRule for TypeAliasTransformRule {
44    fn visit_statement(&self, stmt: &mut Statement) -> Result<()> {
45        match stmt {
46            Statement::CreateTable(CreateTable { columns, .. }) => {
47                columns
48                    .iter_mut()
49                    .for_each(|column| replace_type_alias(column.mut_data_type()));
50            }
51            Statement::CreateExternalTable(CreateExternalTable { columns, .. }) => {
52                columns
53                    .iter_mut()
54                    .for_each(|column| replace_type_alias(column.mut_data_type()));
55            }
56            Statement::AlterTable(alter_table) => {
57                if let AlterTableOperation::ModifyColumnType { target_type, .. } =
58                    alter_table.alter_operation_mut()
59                {
60                    replace_type_alias(target_type)
61                } else if let AlterTableOperation::AddColumns { add_columns, .. } =
62                    alter_table.alter_operation_mut()
63                {
64                    for add_column in add_columns {
65                        replace_type_alias(&mut add_column.column_def.data_type);
66                    }
67                }
68            }
69            _ => {}
70        }
71
72        Ok(())
73    }
74
75    fn visit_expr(&self, expr: &mut Expr) -> ControlFlow<()> {
76        fn cast_expr_to_arrow_cast_func(expr: Expr, cast_type: String) -> Function {
77            Function {
78                name: ObjectName(vec![Ident::new("arrow_cast")]),
79                args: sqlparser::ast::FunctionArguments::List(FunctionArgumentList {
80                    args: vec![
81                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
82                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
83                            Value::SingleQuotedString(cast_type),
84                        ))),
85                    ],
86                    duplicate_treatment: None,
87                    clauses: vec![],
88                }),
89                filter: None,
90                null_treatment: None,
91                over: None,
92                parameters: sqlparser::ast::FunctionArguments::None,
93                within_group: vec![],
94                uses_odbc_syntax: false,
95            }
96        }
97
98        match expr {
99            // In new sqlparser, the "INT64" is no longer parsed to custom datatype.
100            // The new "Int64" is not recognizable by Datafusion, cannot directly "CAST" to it.
101            // We have to replace the expr to "arrow_cast" function call here.
102            // Same for "FLOAT64".
103            Expr::Cast {
104                expr: cast_expr,
105                data_type,
106                ..
107            } if get_type_by_alias(data_type).is_some() => {
108                // Safety: checked in the match arm.
109                let new_type = get_type_by_alias(data_type).unwrap();
110                if let Ok(new_type) = sql_data_type_to_concrete_data_type(&new_type) {
111                    *expr = Expr::Function(cast_expr_to_arrow_cast_func(
112                        (**cast_expr).clone(),
113                        new_type.as_arrow_type().to_string(),
114                    ));
115                }
116            }
117
118            // Timestamp(precision) in cast, datafusion doesn't support Timestamp(9) etc.
119            // We have to transform it into arrow_cast(expr, type).
120            Expr::Cast {
121                data_type: DataType::Timestamp(precision, zone),
122                expr: cast_expr,
123                ..
124            } => {
125                if let Ok(concrete_type) =
126                    sql_data_type_to_concrete_data_type(&DataType::Timestamp(*precision, *zone))
127                {
128                    let new_type = concrete_type.as_arrow_type();
129                    *expr = Expr::Function(cast_expr_to_arrow_cast_func(
130                        (**cast_expr).clone(),
131                        new_type.to_string(),
132                    ));
133                }
134            }
135
136            // TODO(dennis): supports try_cast
137            _ => {}
138        }
139
140        ControlFlow::<()>::Continue(())
141    }
142}
143
144fn replace_type_alias(data_type: &mut DataType) {
145    if let Some(new_type) = get_type_by_alias(data_type) {
146        *data_type = new_type;
147    }
148}
149
150/// Get data type from alias type.
151/// Returns the mapped data type if the input data type is an alias that we need to replace.
152// Remember to update `get_data_type_by_alias_name()` if you modify this method.
153pub(crate) fn get_type_by_alias(data_type: &DataType) -> Option<DataType> {
154    match data_type {
155        // The sqlparser latest version contains the Int8 alias for Postgres Bigint.
156        // Which means 8 bytes in postgres (not 8 bits).
157        // See https://docs.rs/sqlparser/latest/sqlparser/ast/enum.DataType.html#variant.Int8
158        DataType::Custom(name, tokens) if name.0.len() == 1 && tokens.is_empty() => {
159            get_data_type_by_alias_name(name.0[0].value.as_str())
160        }
161        DataType::Int8(None) => Some(DataType::TinyInt(None)),
162        DataType::Int16 => Some(DataType::SmallInt(None)),
163        DataType::Int32 => Some(DataType::Int(None)),
164        DataType::Int64 => Some(DataType::BigInt(None)),
165        DataType::UInt8 => Some(DataType::UnsignedTinyInt(None)),
166        DataType::UInt16 => Some(DataType::UnsignedSmallInt(None)),
167        DataType::UInt32 => Some(DataType::UnsignedInt(None)),
168        DataType::UInt64 => Some(DataType::UnsignedBigInt(None)),
169        DataType::Float32 => Some(DataType::Float(None)),
170        DataType::Float64 => Some(DataType::Double(ExactNumberInfo::None)),
171        DataType::Bool => Some(DataType::Boolean),
172        DataType::Datetime(_) => Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
173        _ => None,
174    }
175}
176
177/// Get the mapped data type from alias name.
178/// It only supports the following types of alias:
179/// - timestamps
180/// - ints
181/// - floats
182/// - texts
183// Remember to update `get_type_alias()` if you modify this method.
184pub(crate) fn get_data_type_by_alias_name(name: &str) -> Option<DataType> {
185    match name.to_uppercase().as_ref() {
186        // Timestamp type alias
187        "TIMESTAMP_S" | "TIMESTAMP_SEC" | "TIMESTAMPSECOND" => {
188            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
189        }
190
191        "TIMESTAMP_MS" | "TIMESTAMPMILLISECOND" => {
192            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
193        }
194        "TIMESTAMP_US" | "TIMESTAMPMICROSECOND" | "DATETIME" => {
195            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
196        }
197        "TIMESTAMP_NS" | "TIMESTAMPNANOSECOND" => {
198            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
199        }
200        // Number type alias
201        // We keep them for backward compatibility.
202        "INT8" => Some(DataType::TinyInt(None)),
203        "INT16" => Some(DataType::SmallInt(None)),
204        "INT32" => Some(DataType::Int(None)),
205        "INT64" => Some(DataType::BigInt(None)),
206        "UINT8" => Some(DataType::UnsignedTinyInt(None)),
207        "UINT16" => Some(DataType::UnsignedSmallInt(None)),
208        "UINT32" => Some(DataType::UnsignedInt(None)),
209        "UINT64" => Some(DataType::UnsignedBigInt(None)),
210        "FLOAT32" => Some(DataType::Float(None)),
211        "FLOAT64" => Some(DataType::Double(ExactNumberInfo::None)),
212        // String type alias
213        "TINYTEXT" | "MEDIUMTEXT" | "LONGTEXT" => Some(DataType::Text),
214        _ => None,
215    }
216}
217
#[cfg(test)]
mod tests {
    use sqlparser::dialect::GenericDialect;

    use super::*;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::transform_statements;

    /// Every alias name, in mixed casing, resolves to the expected data type.
    #[test]
    fn test_get_data_type_by_alias_name() {
        let timestamp = |precision| DataType::Timestamp(Some(precision), TimezoneInfo::None);

        let cases = [
            // Floats
            ("float64", DataType::Double(ExactNumberInfo::None)),
            ("Float64", DataType::Double(ExactNumberInfo::None)),
            ("FLOAT64", DataType::Double(ExactNumberInfo::None)),
            ("float32", DataType::Float(None)),
            // Signed ints
            ("int8", DataType::TinyInt(None)),
            ("INT16", DataType::SmallInt(None)),
            ("INT32", DataType::Int(None)),
            ("INT64", DataType::BigInt(None)),
            // Unsigned ints
            ("Uint8", DataType::UnsignedTinyInt(None)),
            ("UINT16", DataType::UnsignedSmallInt(None)),
            ("UINT32", DataType::UnsignedInt(None)),
            ("uint64", DataType::UnsignedBigInt(None)),
            // Timestamps
            ("TimestampSecond", timestamp(0)),
            ("Timestamp_s", timestamp(0)),
            ("Timestamp_sec", timestamp(0)),
            ("TimestampMilliSecond", timestamp(3)),
            ("Timestamp_ms", timestamp(3)),
            ("TimestampMicroSecond", timestamp(6)),
            ("Timestamp_us", timestamp(6)),
            ("TimestampNanoSecond", timestamp(9)),
            ("Timestamp_ns", timestamp(9)),
            // Texts
            ("TinyText", DataType::Text),
            ("MediumText", DataType::Text),
            ("LongText", DataType::Text),
        ];

        for (alias, expected) in cases {
            assert_eq!(get_data_type_by_alias_name(alias), Some(expected));
        }
    }

    /// Casts a timestamp literal to `alias` and checks that the cast is rewritten into an
    /// `arrow_cast` to a timestamp of time unit `expected`.
    fn test_timestamp_alias(alias: &str, expected: &str) {
        let sql = format!("SELECT TIMESTAMP '2020-01-01 01:23:45.12345678'::{alias}");
        let mut stmts =
            ParserContext::create_with_dialect(&sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        let expected_sql = format!("SELECT arrow_cast(TIMESTAMP '2020-01-01 01:23:45.12345678', 'Timestamp({expected}, None)')");
        if let Statement::Query(query) = &stmts[0] {
            assert_eq!(expected_sql, query.to_string());
        } else {
            unreachable!();
        }
    }

    /// Same as `test_timestamp_alias`, with `Timestamp(precision)` as the cast target.
    fn test_timestamp_precision_type(precision: i32, expected: &str) {
        let target = format!("Timestamp({precision})");
        test_timestamp_alias(&target, expected);
    }

    /// `bool` columns are printed as `BOOLEAN` after the transform.
    #[test]
    fn test_boolean_alias() {
        let sql = "CREATE TABLE test(b bool, ts TIMESTAMP TIME INDEX)";
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        if let Statement::CreateTable(create) = &stmts[0] {
            assert_eq!("CREATE TABLE test (\n  b BOOLEAN,\n  ts TIMESTAMP NOT NULL,\n  TIME INDEX (ts)\n)\nENGINE=mito\n", create.to_string());
        } else {
            unreachable!();
        }
    }

    /// Casts to timestamp aliases and to `Timestamp(precision)` both become `arrow_cast` calls.
    #[test]
    fn test_transform_timestamp_alias() {
        // Timestamp[Second | Millisecond | Microsecond | Nanosecond]
        let alias_cases = [
            ("TimestampSecond", "Second"),
            ("Timestamp_s", "Second"),
            ("TimestampMillisecond", "Millisecond"),
            ("Timestamp_ms", "Millisecond"),
            ("TimestampMicrosecond", "Microsecond"),
            ("Timestamp_us", "Microsecond"),
            ("TimestampNanosecond", "Nanosecond"),
            ("Timestamp_ns", "Nanosecond"),
        ];
        for (alias, unit) in alias_cases {
            test_timestamp_alias(alias, unit);
        }

        // Timestamp(precision)
        let precision_cases = [
            (0, "Second"),
            (3, "Millisecond"),
            (6, "Microsecond"),
            (9, "Nanosecond"),
        ];
        for (precision, unit) in precision_cases {
            test_timestamp_precision_type(precision, unit);
        }
    }

    /// A `CREATE TABLE` using type aliases is printed with the replaced types.
    #[test]
    fn test_create_sql_with_type_alias() {
        let sql = r#"
CREATE TABLE data_types (
  s string,
  tt tinytext,
  mt mediumtext,
  lt longtext,
  tint int8,
  sint int16,
  i int32,
  bint int64,
  v varchar,
  f float32,
  d float64,
  b boolean,
  vb varbinary,
  dt date,
  dtt datetime,
  ts0 TimestampSecond,
  ts3 TimestampMillisecond,
  ts6 TimestampMicrosecond,
  ts9 TimestampNanosecond DEFAULT CURRENT_TIMESTAMP TIME INDEX,
  PRIMARY KEY(s));"#;

        let expected = r#"CREATE TABLE data_types (
  s STRING,
  tt TINYTEXT,
  mt MEDIUMTEXT,
  lt LONGTEXT,
  tint TINYINT,
  sint SMALLINT,
  i INT,
  bint BIGINT,
  v VARCHAR,
  f FLOAT,
  d DOUBLE,
  b BOOLEAN,
  vb VARBINARY,
  dt DATE,
  dtt TIMESTAMP(6),
  ts0 TIMESTAMP(0),
  ts3 TIMESTAMP(3),
  ts6 TIMESTAMP(6),
  ts9 TIMESTAMP(9) DEFAULT CURRENT_TIMESTAMP NOT NULL,
  TIME INDEX (ts9),
  PRIMARY KEY (s)
)
ENGINE=mito
"#;

        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        if let Statement::CreateTable(create) = &stmts[0] {
            assert_eq!(expected, create.to_string());
        } else {
            unreachable!();
        }
    }
}