1use std::ops::ControlFlow;
16
17use datatypes::data_type::DataType as GreptimeDataType;
18use sqlparser::ast::{
19 DataType, ExactNumberInfo, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList,
20 Ident, ObjectName, Value,
21};
22
23use crate::error::Result;
24use crate::statements::alter::AlterTableOperation;
25use crate::statements::create::{CreateExternalTable, CreateTable};
26use crate::statements::statement::Statement;
27use crate::statements::transform::TransformRule;
28use crate::statements::{sql_data_type_to_concrete_data_type, TimezoneInfo};
29
/// Transform rule that rewrites SQL type aliases into their canonical types.
///
/// - In `CREATE TABLE`, `CREATE EXTERNAL TABLE` and `ALTER TABLE` (modify column type /
///   add columns) statements, aliased column types such as `INT64`, `FLOAT64`,
///   `TimestampMillisecond` or `DATETIME` are replaced in place.
/// - In expressions, a `CAST` to an aliased type or to a `TIMESTAMP(p)` is rewritten into an
///   `arrow_cast(expr, '<arrow type>')` function call.
pub(crate) struct TypeAliasTransformRule;
42
43impl TransformRule for TypeAliasTransformRule {
44 fn visit_statement(&self, stmt: &mut Statement) -> Result<()> {
45 match stmt {
46 Statement::CreateTable(CreateTable { columns, .. }) => {
47 columns
48 .iter_mut()
49 .for_each(|column| replace_type_alias(column.mut_data_type()));
50 }
51 Statement::CreateExternalTable(CreateExternalTable { columns, .. }) => {
52 columns
53 .iter_mut()
54 .for_each(|column| replace_type_alias(column.mut_data_type()));
55 }
56 Statement::AlterTable(alter_table) => {
57 if let AlterTableOperation::ModifyColumnType { target_type, .. } =
58 alter_table.alter_operation_mut()
59 {
60 replace_type_alias(target_type)
61 } else if let AlterTableOperation::AddColumns { add_columns, .. } =
62 alter_table.alter_operation_mut()
63 {
64 for add_column in add_columns {
65 replace_type_alias(&mut add_column.column_def.data_type);
66 }
67 }
68 }
69 _ => {}
70 }
71
72 Ok(())
73 }
74
75 fn visit_expr(&self, expr: &mut Expr) -> ControlFlow<()> {
76 fn cast_expr_to_arrow_cast_func(expr: Expr, cast_type: String) -> Function {
77 Function {
78 name: ObjectName(vec![Ident::new("arrow_cast")]),
79 args: sqlparser::ast::FunctionArguments::List(FunctionArgumentList {
80 args: vec![
81 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
82 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
83 Value::SingleQuotedString(cast_type),
84 ))),
85 ],
86 duplicate_treatment: None,
87 clauses: vec![],
88 }),
89 filter: None,
90 null_treatment: None,
91 over: None,
92 parameters: sqlparser::ast::FunctionArguments::None,
93 within_group: vec![],
94 uses_odbc_syntax: false,
95 }
96 }
97
98 match expr {
99 Expr::Cast {
104 expr: cast_expr,
105 data_type,
106 ..
107 } if get_type_by_alias(data_type).is_some() => {
108 let new_type = get_type_by_alias(data_type).unwrap();
110 if let Ok(new_type) = sql_data_type_to_concrete_data_type(&new_type) {
111 *expr = Expr::Function(cast_expr_to_arrow_cast_func(
112 (**cast_expr).clone(),
113 new_type.as_arrow_type().to_string(),
114 ));
115 }
116 }
117
118 Expr::Cast {
121 data_type: DataType::Timestamp(precision, zone),
122 expr: cast_expr,
123 ..
124 } => {
125 if let Ok(concrete_type) =
126 sql_data_type_to_concrete_data_type(&DataType::Timestamp(*precision, *zone))
127 {
128 let new_type = concrete_type.as_arrow_type();
129 *expr = Expr::Function(cast_expr_to_arrow_cast_func(
130 (**cast_expr).clone(),
131 new_type.to_string(),
132 ));
133 }
134 }
135
136 _ => {}
138 }
139
140 ControlFlow::<()>::Continue(())
141 }
142}
143
144fn replace_type_alias(data_type: &mut DataType) {
145 if let Some(new_type) = get_type_by_alias(data_type) {
146 *data_type = new_type;
147 }
148}
149
150pub(crate) fn get_type_by_alias(data_type: &DataType) -> Option<DataType> {
154 match data_type {
155 DataType::Custom(name, tokens) if name.0.len() == 1 && tokens.is_empty() => {
159 get_data_type_by_alias_name(name.0[0].value.as_str())
160 }
161 DataType::Int8(None) => Some(DataType::TinyInt(None)),
162 DataType::Int16 => Some(DataType::SmallInt(None)),
163 DataType::Int32 => Some(DataType::Int(None)),
164 DataType::Int64 => Some(DataType::BigInt(None)),
165 DataType::UInt8 => Some(DataType::UnsignedTinyInt(None)),
166 DataType::UInt16 => Some(DataType::UnsignedSmallInt(None)),
167 DataType::UInt32 => Some(DataType::UnsignedInt(None)),
168 DataType::UInt64 => Some(DataType::UnsignedBigInt(None)),
169 DataType::Float32 => Some(DataType::Float(None)),
170 DataType::Float64 => Some(DataType::Double(ExactNumberInfo::None)),
171 DataType::Bool => Some(DataType::Boolean),
172 DataType::Datetime(_) => Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
173 _ => None,
174 }
175}
176
177pub(crate) fn get_data_type_by_alias_name(name: &str) -> Option<DataType> {
185 match name.to_uppercase().as_ref() {
186 "TIMESTAMP_S" | "TIMESTAMP_SEC" | "TIMESTAMPSECOND" => {
188 Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
189 }
190
191 "TIMESTAMP_MS" | "TIMESTAMPMILLISECOND" => {
192 Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
193 }
194 "TIMESTAMP_US" | "TIMESTAMPMICROSECOND" | "DATETIME" => {
195 Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
196 }
197 "TIMESTAMP_NS" | "TIMESTAMPNANOSECOND" => {
198 Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
199 }
200 "INT8" => Some(DataType::TinyInt(None)),
203 "INT16" => Some(DataType::SmallInt(None)),
204 "INT32" => Some(DataType::Int(None)),
205 "INT64" => Some(DataType::BigInt(None)),
206 "UINT8" => Some(DataType::UnsignedTinyInt(None)),
207 "UINT16" => Some(DataType::UnsignedSmallInt(None)),
208 "UINT32" => Some(DataType::UnsignedInt(None)),
209 "UINT64" => Some(DataType::UnsignedBigInt(None)),
210 "FLOAT32" => Some(DataType::Float(None)),
211 "FLOAT64" => Some(DataType::Double(ExactNumberInfo::None)),
212 "TINYTEXT" | "MEDIUMTEXT" | "LONGTEXT" => Some(DataType::Text),
214 _ => None,
215 }
216}
217
#[cfg(test)]
mod tests {
    use sqlparser::dialect::GenericDialect;

    use super::*;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::transform_statements;

    /// Every alias name resolves to its canonical type, regardless of letter case.
    #[test]
    fn test_get_data_type_by_alias_name() {
        // Floating-point aliases, in several casings.
        assert_eq!(
            get_data_type_by_alias_name("float64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Float64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("FLOAT64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );

        // Fixed-width integer aliases.
        assert_eq!(
            get_data_type_by_alias_name("float32"),
            Some(DataType::Float(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("int8"),
            Some(DataType::TinyInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT16"),
            Some(DataType::SmallInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT32"),
            Some(DataType::Int(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT64"),
            Some(DataType::BigInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Uint8"),
            Some(DataType::UnsignedTinyInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("UINT16"),
            Some(DataType::UnsignedSmallInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("UINT32"),
            Some(DataType::UnsignedInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("uint64"),
            Some(DataType::UnsignedBigInt(None))
        );

        // Second-precision timestamp aliases.
        assert_eq!(
            get_data_type_by_alias_name("TimestampSecond"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_s"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_sec"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );

        // Millisecond-precision timestamp aliases.
        assert_eq!(
            get_data_type_by_alias_name("TimestampMilliSecond"),
            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_ms"),
            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
        );

        // Microsecond-precision timestamp aliases.
        assert_eq!(
            get_data_type_by_alias_name("TimestampMicroSecond"),
            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_us"),
            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
        );

        // Nanosecond-precision timestamp aliases.
        assert_eq!(
            get_data_type_by_alias_name("TimestampNanoSecond"),
            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_ns"),
            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
        );
        // MySQL-style text aliases all collapse to TEXT.
        assert_eq!(
            get_data_type_by_alias_name("TinyText"),
            Some(DataType::Text)
        );
        assert_eq!(
            get_data_type_by_alias_name("MediumText"),
            Some(DataType::Text)
        );
        assert_eq!(
            get_data_type_by_alias_name("LongText"),
            Some(DataType::Text)
        );
    }

    /// Parses `SELECT TIMESTAMP '...'::<alias>` with the generic dialect and runs all
    /// transform rules. Asserts that the cast was rewritten to
    /// `arrow_cast(..., 'Timestamp(<expected>, None)')`.
    fn test_timestamp_alias(alias: &str, expected: &str) {
        let sql = format!("SELECT TIMESTAMP '2020-01-01 01:23:45.12345678'::{alias}");
        let mut stmts =
            ParserContext::create_with_dialect(&sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::Query(q) => assert_eq!(format!("SELECT arrow_cast(TIMESTAMP '2020-01-01 01:23:45.12345678', 'Timestamp({expected}, None)')"), q.to_string()),
            _ => unreachable!(),
        }
    }

    /// Same check as `test_timestamp_alias`, but casts to an explicit `Timestamp(<precision>)`.
    fn test_timestamp_precision_type(precision: i32, expected: &str) {
        test_timestamp_alias(&format!("Timestamp({precision})"), expected);
    }

    /// A `bool` column type is rewritten to `BOOLEAN` in `CREATE TABLE`.
    #[test]
    fn test_boolean_alias() {
        let sql = "CREATE TABLE test(b bool, ts TIMESTAMP TIME INDEX)";
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::CreateTable(c) => assert_eq!("CREATE TABLE test (\n b BOOLEAN,\n ts TIMESTAMP NOT NULL,\n TIME INDEX (ts)\n)\nENGINE=mito\n", c.to_string()),
            _ => unreachable!(),
        }
    }

    /// Casts to timestamp aliases, and to explicit timestamp precisions, become `arrow_cast`
    /// calls with the matching Arrow time unit.
    #[test]
    fn test_transform_timestamp_alias() {
        test_timestamp_alias("TimestampSecond", "Second");
        test_timestamp_alias("Timestamp_s", "Second");
        test_timestamp_alias("TimestampMillisecond", "Millisecond");
        test_timestamp_alias("Timestamp_ms", "Millisecond");
        test_timestamp_alias("TimestampMicrosecond", "Microsecond");
        test_timestamp_alias("Timestamp_us", "Microsecond");
        test_timestamp_alias("TimestampNanosecond", "Nanosecond");
        test_timestamp_alias("Timestamp_ns", "Nanosecond");
        test_timestamp_precision_type(0, "Second");
        test_timestamp_precision_type(3, "Millisecond");
        test_timestamp_precision_type(6, "Microsecond");
        test_timestamp_precision_type(9, "Nanosecond");
    }

    /// End-to-end check of alias rewriting on a `CREATE TABLE` covering many column types.
    /// NOTE(review): per the expected output, `tinytext`/`mediumtext`/`longtext` keep their
    /// own names here, presumably because the parser recognizes them as built-in types
    /// rather than custom names. Confirm against the sqlparser version in use.
    #[test]
    fn test_create_sql_with_type_alias() {
        let sql = r#"
CREATE TABLE data_types (
 s string,
 tt tinytext,
 mt mediumtext,
 lt longtext,
 tint int8,
 sint int16,
 i int32,
 bint int64,
 v varchar,
 f float32,
 d float64,
 b boolean,
 vb varbinary,
 dt date,
 dtt datetime,
 ts0 TimestampSecond,
 ts3 TimestampMillisecond,
 ts6 TimestampMicrosecond,
 ts9 TimestampNanosecond DEFAULT CURRENT_TIMESTAMP TIME INDEX,
 PRIMARY KEY(s));"#;

        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::CreateTable(c) => {
                let expected = r#"CREATE TABLE data_types (
 s STRING,
 tt TINYTEXT,
 mt MEDIUMTEXT,
 lt LONGTEXT,
 tint TINYINT,
 sint SMALLINT,
 i INT,
 bint BIGINT,
 v VARCHAR,
 f FLOAT,
 d DOUBLE,
 b BOOLEAN,
 vb VARBINARY,
 dt DATE,
 dtt TIMESTAMP(6),
 ts0 TIMESTAMP(0),
 ts3 TIMESTAMP(3),
 ts6 TIMESTAMP(6),
 ts9 TIMESTAMP(9) DEFAULT CURRENT_TIMESTAMP NOT NULL,
 TIME INDEX (ts9),
 PRIMARY KEY (s)
)
ENGINE=mito
"#;

                assert_eq!(expected, c.to_string());
            }
            _ => unreachable!(),
        }
    }
}