1use std::ops::ControlFlow;
16
17use datatypes::data_type::DataType as GreptimeDataType;
18use sqlparser::ast::{
19 DataType, ExactNumberInfo, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList,
20 Ident, ObjectName, Value,
21};
22
23use crate::ast::ObjectNamePartExt;
24use crate::error::Result;
25use crate::statements::alter::AlterTableOperation;
26use crate::statements::create::{CreateExternalTable, CreateTable};
27use crate::statements::statement::Statement;
28use crate::statements::transform::TransformRule;
29use crate::statements::{TimezoneInfo, sql_data_type_to_concrete_data_type};
30
/// Transform rule that replaces SQL type aliases with the data types they stand for.
///
/// For `CREATE TABLE`, `CREATE EXTERNAL TABLE` and `ALTER TABLE` (modify column type /
/// add columns) statements the aliased column type is replaced in place. `CAST`
/// expressions whose target is a type alias or a `TIMESTAMP` type are rewritten into
/// `arrow_cast` function calls.
pub(crate) struct TypeAliasTransformRule;
52
53impl TransformRule for TypeAliasTransformRule {
54 fn visit_statement(&self, stmt: &mut Statement) -> Result<()> {
55 match stmt {
56 Statement::CreateTable(CreateTable { columns, .. }) => {
57 columns
58 .iter_mut()
59 .for_each(|column| replace_type_alias(column.mut_data_type()));
60 }
61 Statement::CreateExternalTable(CreateExternalTable { columns, .. }) => {
62 columns
63 .iter_mut()
64 .for_each(|column| replace_type_alias(column.mut_data_type()));
65 }
66 Statement::AlterTable(alter_table) => {
67 if let AlterTableOperation::ModifyColumnType { target_type, .. } =
68 alter_table.alter_operation_mut()
69 {
70 replace_type_alias(target_type)
71 } else if let AlterTableOperation::AddColumns { add_columns, .. } =
72 alter_table.alter_operation_mut()
73 {
74 for add_column in add_columns {
75 replace_type_alias(&mut add_column.column_def.data_type);
76 }
77 }
78 }
79 _ => {}
80 }
81
82 Ok(())
83 }
84
85 fn visit_expr(&self, expr: &mut Expr) -> ControlFlow<()> {
86 fn cast_expr_to_arrow_cast_func(expr: Expr, cast_type: String) -> Function {
87 Function {
88 name: ObjectName::from(vec![Ident::new("arrow_cast")]),
89 args: sqlparser::ast::FunctionArguments::List(FunctionArgumentList {
90 args: vec![
91 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
92 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
93 Value::SingleQuotedString(cast_type).into(),
94 ))),
95 ],
96 duplicate_treatment: None,
97 clauses: vec![],
98 }),
99 filter: None,
100 null_treatment: None,
101 over: None,
102 parameters: sqlparser::ast::FunctionArguments::None,
103 within_group: vec![],
104 uses_odbc_syntax: false,
105 }
106 }
107
108 match expr {
109 Expr::Cast {
114 expr: cast_expr,
115 data_type,
116 ..
117 } if get_type_by_alias(data_type).is_some() => {
118 let new_type = get_type_by_alias(data_type).unwrap();
120 if let Ok(new_type) =
121 sql_data_type_to_concrete_data_type(&new_type, &Default::default())
122 {
123 *expr = Expr::Function(cast_expr_to_arrow_cast_func(
124 (**cast_expr).clone(),
125 new_type.as_arrow_type().to_string(),
126 ));
127 }
128 }
129
130 Expr::Cast {
133 data_type: DataType::Timestamp(precision, zone),
134 expr: cast_expr,
135 ..
136 } => {
137 if let Ok(concrete_type) = sql_data_type_to_concrete_data_type(
138 &DataType::Timestamp(*precision, *zone),
139 &Default::default(),
140 ) {
141 let new_type = concrete_type.as_arrow_type();
142 *expr = Expr::Function(cast_expr_to_arrow_cast_func(
143 (**cast_expr).clone(),
144 new_type.to_string(),
145 ));
146 }
147 }
148
149 _ => {}
151 }
152
153 ControlFlow::<()>::Continue(())
154 }
155}
156
157fn replace_type_alias(data_type: &mut DataType) {
158 if let Some(new_type) = get_type_by_alias(data_type) {
159 *data_type = new_type;
160 }
161}
162
163pub(crate) fn get_type_by_alias(data_type: &DataType) -> Option<DataType> {
167 match data_type {
168 DataType::Custom(name, tokens) if name.0.len() == 1 && tokens.is_empty() => {
169 get_data_type_by_alias_name(name.0[0].to_string_unquoted().as_str())
170 }
171 DataType::Int2(None) => Some(DataType::SmallInt(None)),
172 DataType::Int4(None) => Some(DataType::Int(None)),
173 DataType::Int8(None) => Some(DataType::BigInt(None)),
174 DataType::Int16 => Some(DataType::SmallInt(None)),
175 DataType::Int32 => Some(DataType::Int(None)),
176 DataType::Int64 => Some(DataType::BigInt(None)),
177 DataType::UInt8 => Some(DataType::TinyIntUnsigned(None)),
178 DataType::UInt16 => Some(DataType::SmallIntUnsigned(None)),
179 DataType::UInt32 => Some(DataType::IntUnsigned(None)),
180 DataType::UInt64 => Some(DataType::BigIntUnsigned(None)),
181 DataType::Float4 => Some(DataType::Float(None)),
182 DataType::Float8 => Some(DataType::Double(ExactNumberInfo::None)),
183 DataType::Float32 => Some(DataType::Float(None)),
184 DataType::Float64 => Some(DataType::Double(ExactNumberInfo::None)),
185 DataType::Bool => Some(DataType::Boolean),
186 DataType::Datetime(_) => Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
187 _ => None,
188 }
189}
190
191pub(crate) fn get_data_type_by_alias_name(name: &str) -> Option<DataType> {
199 match name.to_uppercase().as_ref() {
200 "TIMESTAMP_S" | "TIMESTAMP_SEC" | "TIMESTAMPSECOND" => {
202 Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
203 }
204
205 "TIMESTAMP_MS" | "TIMESTAMPMILLISECOND" => {
206 Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
207 }
208 "TIMESTAMP_US" | "TIMESTAMPMICROSECOND" | "DATETIME" => {
209 Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
210 }
211 "TIMESTAMP_NS" | "TIMESTAMPNANOSECOND" => {
212 Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
213 }
214 "INT2" => Some(DataType::SmallInt(None)),
216 "INT4" => Some(DataType::Int(None)),
217 "INT8" => Some(DataType::BigInt(None)),
218 "INT16" => Some(DataType::SmallInt(None)),
219 "INT32" => Some(DataType::Int(None)),
220 "INT64" => Some(DataType::BigInt(None)),
221 "UINT8" => Some(DataType::TinyIntUnsigned(None)),
222 "UINT16" => Some(DataType::SmallIntUnsigned(None)),
223 "UINT32" => Some(DataType::IntUnsigned(None)),
224 "UINT64" => Some(DataType::BigIntUnsigned(None)),
225 "FLOAT4" => Some(DataType::Float(None)),
226 "FLOAT8" => Some(DataType::Double(ExactNumberInfo::None)),
227 "FLOAT32" => Some(DataType::Float(None)),
228 "FLOAT64" => Some(DataType::Double(ExactNumberInfo::None)),
229 "TINYTEXT" | "MEDIUMTEXT" | "LONGTEXT" => Some(DataType::Text),
231 _ => None,
232 }
233}
234
#[cfg(test)]
mod tests {
    use sqlparser::dialect::GenericDialect;

    use super::*;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::transform_statements;

    /// Checks the alias-name lookup for mixed-case spellings of the known aliases.
    #[test]
    fn test_get_data_type_by_alias_name() {
        let cases = [
            ("float64", DataType::Double(ExactNumberInfo::None)),
            ("Float64", DataType::Double(ExactNumberInfo::None)),
            ("FLOAT64", DataType::Double(ExactNumberInfo::None)),
            ("float32", DataType::Float(None)),
            ("float8", DataType::Double(ExactNumberInfo::None)),
            ("float4", DataType::Float(None)),
            ("int8", DataType::BigInt(None)),
            ("int4", DataType::Int(None)),
            ("int2", DataType::SmallInt(None)),
            ("INT16", DataType::SmallInt(None)),
            ("INT32", DataType::Int(None)),
            ("INT64", DataType::BigInt(None)),
            ("Uint8", DataType::TinyIntUnsigned(None)),
            ("UINT16", DataType::SmallIntUnsigned(None)),
            ("UINT32", DataType::IntUnsigned(None)),
            ("uint64", DataType::BigIntUnsigned(None)),
            (
                "TimestampSecond",
                DataType::Timestamp(Some(0), TimezoneInfo::None),
            ),
            (
                "Timestamp_s",
                DataType::Timestamp(Some(0), TimezoneInfo::None),
            ),
            (
                "Timestamp_sec",
                DataType::Timestamp(Some(0), TimezoneInfo::None),
            ),
            (
                "TimestampMilliSecond",
                DataType::Timestamp(Some(3), TimezoneInfo::None),
            ),
            (
                "Timestamp_ms",
                DataType::Timestamp(Some(3), TimezoneInfo::None),
            ),
            (
                "TimestampMicroSecond",
                DataType::Timestamp(Some(6), TimezoneInfo::None),
            ),
            (
                "Timestamp_us",
                DataType::Timestamp(Some(6), TimezoneInfo::None),
            ),
            (
                "TimestampNanoSecond",
                DataType::Timestamp(Some(9), TimezoneInfo::None),
            ),
            (
                "Timestamp_ns",
                DataType::Timestamp(Some(9), TimezoneInfo::None),
            ),
            ("TinyText", DataType::Text),
            ("MediumText", DataType::Text),
            ("LongText", DataType::Text),
        ];

        for (alias, expected) in cases {
            assert_eq!(
                get_data_type_by_alias_name(alias),
                Some(expected),
                "alias: {alias}"
            );
        }
    }

    /// Casts a timestamp literal to `alias` and checks that the transformed query uses
    /// `arrow_cast` with a timestamp of the `expected` time unit.
    fn test_timestamp_alias(alias: &str, expected: &str) {
        let sql = format!("SELECT TIMESTAMP '2020-01-01 01:23:45.12345678'::{alias}");
        let mut stmts =
            ParserContext::create_with_dialect(&sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        let Statement::Query(query) = &stmts[0] else {
            unreachable!()
        };
        let expected_sql = format!(
            "SELECT arrow_cast(TIMESTAMP '2020-01-01 01:23:45.12345678', 'Timestamp({expected}, None)')"
        );
        assert_eq!(expected_sql, query.to_string());
    }

    /// Same check as [`test_timestamp_alias`], for the `Timestamp(<precision>)` spelling.
    fn test_timestamp_precision_type(precision: i32, expected: &str) {
        let type_name = format!("Timestamp({precision})");
        test_timestamp_alias(&type_name, expected);
    }

    /// `bool` columns are rewritten to `BOOLEAN`.
    #[test]
    fn test_boolean_alias() {
        let sql = "CREATE TABLE test(b bool, ts TIMESTAMP TIME INDEX)";
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        let Statement::CreateTable(create) = &stmts[0] else {
            unreachable!()
        };
        assert_eq!(
            "CREATE TABLE test (\n b BOOLEAN,\n ts TIMESTAMP NOT NULL,\n TIME INDEX (ts)\n)\nENGINE=mito\n",
            create.to_string()
        );
    }

    /// Timestamp aliases and explicit precisions used as cast targets map to the matching
    /// time unit.
    #[test]
    fn test_transform_timestamp_alias() {
        let alias_cases = [
            ("TimestampSecond", "Second"),
            ("Timestamp_s", "Second"),
            ("TimestampMillisecond", "Millisecond"),
            ("Timestamp_ms", "Millisecond"),
            ("TimestampMicrosecond", "Microsecond"),
            ("Timestamp_us", "Microsecond"),
            ("TimestampNanosecond", "Nanosecond"),
            ("Timestamp_ns", "Nanosecond"),
        ];
        for (alias, unit) in alias_cases {
            test_timestamp_alias(alias, unit);
        }

        let precision_cases = [
            (0, "Second"),
            (3, "Millisecond"),
            (6, "Microsecond"),
            (9, "Nanosecond"),
        ];
        for (precision, unit) in precision_cases {
            test_timestamp_precision_type(precision, unit);
        }
    }

    /// A `CREATE TABLE` mixing aliases and regular types is rendered with the aliases
    /// replaced.
    #[test]
    fn test_create_sql_with_type_alias() {
        let sql = r#"
CREATE TABLE data_types (
 s string,
 tt tinytext,
 mt mediumtext,
 lt longtext,
 i2 int2,
 i4 int4,
 i8 int8,
 sint int16,
 i int32,
 bint int64,
 v varchar,
 f4 float4,
 f8 float8,
 f float32,
 d float64,
 b boolean,
 vb varbinary,
 dt date,
 dtt datetime,
 ts0 TimestampSecond,
 ts3 TimestampMillisecond,
 ts6 TimestampMicrosecond,
 ts9 TimestampNanosecond DEFAULT CURRENT_TIMESTAMP TIME INDEX,
 PRIMARY KEY(s));"#;

        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        let Statement::CreateTable(create) = &stmts[0] else {
            unreachable!()
        };

        let expected = r#"CREATE TABLE data_types (
 s STRING,
 tt TINYTEXT,
 mt MEDIUMTEXT,
 lt LONGTEXT,
 i2 SMALLINT,
 i4 INT,
 i8 BIGINT,
 sint SMALLINT,
 i INT,
 bint BIGINT,
 v VARCHAR,
 f4 FLOAT,
 f8 DOUBLE,
 f FLOAT,
 d DOUBLE,
 b BOOLEAN,
 vb VARBINARY,
 dt DATE,
 dtt TIMESTAMP(6),
 ts0 TIMESTAMP(0),
 ts3 TIMESTAMP(3),
 ts6 TIMESTAMP(6),
 ts9 TIMESTAMP(9) DEFAULT CURRENT_TIMESTAMP NOT NULL,
 TIME INDEX (ts9),
 PRIMARY KEY (s)
)
ENGINE=mito
"#;

        assert_eq!(expected, create.to_string());
    }
}