1use std::ops::ControlFlow;
16
17use datatypes::data_type::DataType as GreptimeDataType;
18use sqlparser::ast::{
19 DataType, ExactNumberInfo, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList,
20 Ident, ObjectName, Value,
21};
22
23use crate::ast::ObjectNamePartExt;
24use crate::error::Result;
25use crate::statements::alter::AlterTableOperation;
26use crate::statements::create::{CreateExternalTable, CreateTable};
27use crate::statements::statement::Statement;
28use crate::statements::transform::TransformRule;
29use crate::statements::{TimezoneInfo, sql_data_type_to_concrete_data_type};
30
/// SQL transform rule that rewrites data type aliases.
///
/// Its [`TransformRule`] implementation replaces aliased column types in
/// `CREATE TABLE`, `CREATE EXTERNAL TABLE` and `ALTER TABLE` statements, and
/// rewrites `CAST` expressions targeting an alias or a `TIMESTAMP` type into
/// `arrow_cast` function calls.
pub(crate) struct TypeAliasTransformRule;
43
44impl TransformRule for TypeAliasTransformRule {
45 fn visit_statement(&self, stmt: &mut Statement) -> Result<()> {
46 match stmt {
47 Statement::CreateTable(CreateTable { columns, .. }) => {
48 columns
49 .iter_mut()
50 .for_each(|column| replace_type_alias(column.mut_data_type()));
51 }
52 Statement::CreateExternalTable(CreateExternalTable { columns, .. }) => {
53 columns
54 .iter_mut()
55 .for_each(|column| replace_type_alias(column.mut_data_type()));
56 }
57 Statement::AlterTable(alter_table) => {
58 if let AlterTableOperation::ModifyColumnType { target_type, .. } =
59 alter_table.alter_operation_mut()
60 {
61 replace_type_alias(target_type)
62 } else if let AlterTableOperation::AddColumns { add_columns, .. } =
63 alter_table.alter_operation_mut()
64 {
65 for add_column in add_columns {
66 replace_type_alias(&mut add_column.column_def.data_type);
67 }
68 }
69 }
70 _ => {}
71 }
72
73 Ok(())
74 }
75
76 fn visit_expr(&self, expr: &mut Expr) -> ControlFlow<()> {
77 fn cast_expr_to_arrow_cast_func(expr: Expr, cast_type: String) -> Function {
78 Function {
79 name: ObjectName::from(vec![Ident::new("arrow_cast")]),
80 args: sqlparser::ast::FunctionArguments::List(FunctionArgumentList {
81 args: vec![
82 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
83 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
84 Value::SingleQuotedString(cast_type).into(),
85 ))),
86 ],
87 duplicate_treatment: None,
88 clauses: vec![],
89 }),
90 filter: None,
91 null_treatment: None,
92 over: None,
93 parameters: sqlparser::ast::FunctionArguments::None,
94 within_group: vec![],
95 uses_odbc_syntax: false,
96 }
97 }
98
99 match expr {
100 Expr::Cast {
105 expr: cast_expr,
106 data_type,
107 ..
108 } if get_type_by_alias(data_type).is_some() => {
109 let new_type = get_type_by_alias(data_type).unwrap();
111 if let Ok(new_type) = sql_data_type_to_concrete_data_type(&new_type) {
112 *expr = Expr::Function(cast_expr_to_arrow_cast_func(
113 (**cast_expr).clone(),
114 new_type.as_arrow_type().to_string(),
115 ));
116 }
117 }
118
119 Expr::Cast {
122 data_type: DataType::Timestamp(precision, zone),
123 expr: cast_expr,
124 ..
125 } => {
126 if let Ok(concrete_type) =
127 sql_data_type_to_concrete_data_type(&DataType::Timestamp(*precision, *zone))
128 {
129 let new_type = concrete_type.as_arrow_type();
130 *expr = Expr::Function(cast_expr_to_arrow_cast_func(
131 (**cast_expr).clone(),
132 new_type.to_string(),
133 ));
134 }
135 }
136
137 _ => {}
139 }
140
141 ControlFlow::<()>::Continue(())
142 }
143}
144
145fn replace_type_alias(data_type: &mut DataType) {
146 if let Some(new_type) = get_type_by_alias(data_type) {
147 *data_type = new_type;
148 }
149}
150
151pub(crate) fn get_type_by_alias(data_type: &DataType) -> Option<DataType> {
155 match data_type {
156 DataType::Custom(name, tokens) if name.0.len() == 1 && tokens.is_empty() => {
160 get_data_type_by_alias_name(name.0[0].to_string_unquoted().as_str())
161 }
162 DataType::Int8(None) => Some(DataType::TinyInt(None)),
163 DataType::Int16 => Some(DataType::SmallInt(None)),
164 DataType::Int32 => Some(DataType::Int(None)),
165 DataType::Int64 => Some(DataType::BigInt(None)),
166 DataType::UInt8 => Some(DataType::TinyIntUnsigned(None)),
167 DataType::UInt16 => Some(DataType::SmallIntUnsigned(None)),
168 DataType::UInt32 => Some(DataType::IntUnsigned(None)),
169 DataType::UInt64 => Some(DataType::BigIntUnsigned(None)),
170 DataType::Float32 => Some(DataType::Float(None)),
171 DataType::Float64 => Some(DataType::Double(ExactNumberInfo::None)),
172 DataType::Bool => Some(DataType::Boolean),
173 DataType::Datetime(_) => Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
174 _ => None,
175 }
176}
177
178pub(crate) fn get_data_type_by_alias_name(name: &str) -> Option<DataType> {
186 match name.to_uppercase().as_ref() {
187 "TIMESTAMP_S" | "TIMESTAMP_SEC" | "TIMESTAMPSECOND" => {
189 Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
190 }
191
192 "TIMESTAMP_MS" | "TIMESTAMPMILLISECOND" => {
193 Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
194 }
195 "TIMESTAMP_US" | "TIMESTAMPMICROSECOND" | "DATETIME" => {
196 Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
197 }
198 "TIMESTAMP_NS" | "TIMESTAMPNANOSECOND" => {
199 Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
200 }
201 "INT8" => Some(DataType::TinyInt(None)),
204 "INT16" => Some(DataType::SmallInt(None)),
205 "INT32" => Some(DataType::Int(None)),
206 "INT64" => Some(DataType::BigInt(None)),
207 "UINT8" => Some(DataType::TinyIntUnsigned(None)),
208 "UINT16" => Some(DataType::SmallIntUnsigned(None)),
209 "UINT32" => Some(DataType::IntUnsigned(None)),
210 "UINT64" => Some(DataType::BigIntUnsigned(None)),
211 "FLOAT32" => Some(DataType::Float(None)),
212 "FLOAT64" => Some(DataType::Double(ExactNumberInfo::None)),
213 "TINYTEXT" | "MEDIUMTEXT" | "LONGTEXT" => Some(DataType::Text),
215 _ => None,
216 }
217}
218
/// Unit tests for the alias lookup table and for the end-to-end statement
/// transformation.
#[cfg(test)]
mod tests {
    use sqlparser::dialect::GenericDialect;

    use super::*;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::transform_statements;

    /// Checks the alias name lookup for numeric, timestamp and text aliases,
    /// using mixed-case spellings to exercise case-insensitive matching.
    #[test]
    fn test_get_data_type_by_alias_name() {
        assert_eq!(
            get_data_type_by_alias_name("float64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Float64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("FLOAT64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );

        assert_eq!(
            get_data_type_by_alias_name("float32"),
            Some(DataType::Float(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("int8"),
            Some(DataType::TinyInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT16"),
            Some(DataType::SmallInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT32"),
            Some(DataType::Int(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT64"),
            Some(DataType::BigInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Uint8"),
            Some(DataType::TinyIntUnsigned(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("UINT16"),
            Some(DataType::SmallIntUnsigned(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("UINT32"),
            Some(DataType::IntUnsigned(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("uint64"),
            Some(DataType::BigIntUnsigned(None))
        );

        assert_eq!(
            get_data_type_by_alias_name("TimestampSecond"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_s"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_sec"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );

        assert_eq!(
            get_data_type_by_alias_name("TimestampMilliSecond"),
            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_ms"),
            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
        );

        assert_eq!(
            get_data_type_by_alias_name("TimestampMicroSecond"),
            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_us"),
            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
        );

        assert_eq!(
            get_data_type_by_alias_name("TimestampNanoSecond"),
            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_ns"),
            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("TinyText"),
            Some(DataType::Text)
        );
        assert_eq!(
            get_data_type_by_alias_name("MediumText"),
            Some(DataType::Text)
        );
        assert_eq!(
            get_data_type_by_alias_name("LongText"),
            Some(DataType::Text)
        );
    }

    /// Helper: parses a `SELECT` that casts a timestamp literal to `alias`
    /// with `::`, runs the statement transforms, and asserts that the query
    /// renders as an `arrow_cast` to `Timestamp({expected}, None)`.
    fn test_timestamp_alias(alias: &str, expected: &str) {
        let sql = format!("SELECT TIMESTAMP '2020-01-01 01:23:45.12345678'::{alias}");
        let mut stmts =
            ParserContext::create_with_dialect(&sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::Query(q) => assert_eq!(
                format!(
                    "SELECT arrow_cast(TIMESTAMP '2020-01-01 01:23:45.12345678', 'Timestamp({expected}, None)')"
                ),
                q.to_string()
            ),
            _ => unreachable!(),
        }
    }

    /// Helper: same check as `test_timestamp_alias`, but the cast target is
    /// spelled `Timestamp(<precision>)` instead of an alias name.
    fn test_timestamp_precision_type(precision: i32, expected: &str) {
        test_timestamp_alias(&format!("Timestamp({precision})"), expected);
    }

    /// A `bool` column in `CREATE TABLE` is rendered as `BOOLEAN` after the
    /// transforms have run.
    #[test]
    fn test_boolean_alias() {
        let sql = "CREATE TABLE test(b bool, ts TIMESTAMP TIME INDEX)";
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::CreateTable(c) => assert_eq!(
                "CREATE TABLE test (\n b BOOLEAN,\n ts TIMESTAMP NOT NULL,\n TIME INDEX (ts)\n)\nENGINE=mito\n",
                c.to_string()
            ),
            _ => unreachable!(),
        }
    }

    /// Casts to timestamp aliases and to explicit `Timestamp(n)` precisions
    /// are rewritten to `arrow_cast` with the matching arrow time unit.
    #[test]
    fn test_transform_timestamp_alias() {
        // Alias spellings.
        test_timestamp_alias("TimestampSecond", "Second");
        test_timestamp_alias("Timestamp_s", "Second");
        test_timestamp_alias("TimestampMillisecond", "Millisecond");
        test_timestamp_alias("Timestamp_ms", "Millisecond");
        test_timestamp_alias("TimestampMicrosecond", "Microsecond");
        test_timestamp_alias("Timestamp_us", "Microsecond");
        test_timestamp_alias("TimestampNanosecond", "Nanosecond");
        test_timestamp_alias("Timestamp_ns", "Nanosecond");
        // Explicit precisions.
        test_timestamp_precision_type(0, "Second");
        test_timestamp_precision_type(3, "Millisecond");
        test_timestamp_precision_type(6, "Microsecond");
        test_timestamp_precision_type(9, "Nanosecond");
    }

    /// A `CREATE TABLE` that mixes aliased and plain column types is
    /// rendered with the aliases replaced, as pinned by `expected` below.
    #[test]
    fn test_create_sql_with_type_alias() {
        let sql = r#"
CREATE TABLE data_types (
 s string,
 tt tinytext,
 mt mediumtext,
 lt longtext,
 tint int8,
 sint int16,
 i int32,
 bint int64,
 v varchar,
 f float32,
 d float64,
 b boolean,
 vb varbinary,
 dt date,
 dtt datetime,
 ts0 TimestampSecond,
 ts3 TimestampMillisecond,
 ts6 TimestampMicrosecond,
 ts9 TimestampNanosecond DEFAULT CURRENT_TIMESTAMP TIME INDEX,
 PRIMARY KEY(s));"#;

        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::CreateTable(c) => {
                let expected = r#"CREATE TABLE data_types (
 s STRING,
 tt TINYTEXT,
 mt MEDIUMTEXT,
 lt LONGTEXT,
 tint TINYINT,
 sint SMALLINT,
 i INT,
 bint BIGINT,
 v VARCHAR,
 f FLOAT,
 d DOUBLE,
 b BOOLEAN,
 vb VARBINARY,
 dt DATE,
 dtt TIMESTAMP(6),
 ts0 TIMESTAMP(0),
 ts3 TIMESTAMP(3),
 ts6 TIMESTAMP(6),
 ts9 TIMESTAMP(9) DEFAULT CURRENT_TIMESTAMP NOT NULL,
 TIME INDEX (ts9),
 PRIMARY KEY (s)
)
ENGINE=mito
"#;

                assert_eq!(expected, c.to_string());
            }
            _ => unreachable!(),
        }
    }
}