1use std::ops::ControlFlow;
16
17use datatypes::data_type::DataType as GreptimeDataType;
18use sqlparser::ast::{
19 DataType, ExactNumberInfo, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList,
20 Ident, ObjectName, Value,
21};
22
23use crate::ast::ObjectNamePartExt;
24use crate::error::Result;
25use crate::statements::alter::AlterTableOperation;
26use crate::statements::create::{CreateExternalTable, CreateTable};
27use crate::statements::statement::Statement;
28use crate::statements::transform::TransformRule;
29use crate::statements::{sql_data_type_to_concrete_data_type, TimezoneInfo};
30
/// Transform rule that rewrites SQL data type aliases into canonical SQL types.
///
/// - In `CREATE TABLE`, `CREATE EXTERNAL TABLE` and `ALTER TABLE` (column type
///   modification / added columns), alias column types are replaced in place,
///   e.g. `TimestampSecond`, `Timestamp_s`, `Timestamp_sec` become `TIMESTAMP(0)`,
///   `int8` becomes `TINYINT` and `float64` becomes `DOUBLE`.
/// - In expressions, casts whose target is an alias or a `TIMESTAMP[(p)]` type are
///   rewritten into `arrow_cast(expr, '<arrow type>')`.
pub(crate) struct TypeAliasTransformRule;
43
44impl TransformRule for TypeAliasTransformRule {
45 fn visit_statement(&self, stmt: &mut Statement) -> Result<()> {
46 match stmt {
47 Statement::CreateTable(CreateTable { columns, .. }) => {
48 columns
49 .iter_mut()
50 .for_each(|column| replace_type_alias(column.mut_data_type()));
51 }
52 Statement::CreateExternalTable(CreateExternalTable { columns, .. }) => {
53 columns
54 .iter_mut()
55 .for_each(|column| replace_type_alias(column.mut_data_type()));
56 }
57 Statement::AlterTable(alter_table) => {
58 if let AlterTableOperation::ModifyColumnType { target_type, .. } =
59 alter_table.alter_operation_mut()
60 {
61 replace_type_alias(target_type)
62 } else if let AlterTableOperation::AddColumns { add_columns, .. } =
63 alter_table.alter_operation_mut()
64 {
65 for add_column in add_columns {
66 replace_type_alias(&mut add_column.column_def.data_type);
67 }
68 }
69 }
70 _ => {}
71 }
72
73 Ok(())
74 }
75
76 fn visit_expr(&self, expr: &mut Expr) -> ControlFlow<()> {
77 fn cast_expr_to_arrow_cast_func(expr: Expr, cast_type: String) -> Function {
78 Function {
79 name: ObjectName::from(vec![Ident::new("arrow_cast")]),
80 args: sqlparser::ast::FunctionArguments::List(FunctionArgumentList {
81 args: vec![
82 FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
83 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
84 Value::SingleQuotedString(cast_type).into(),
85 ))),
86 ],
87 duplicate_treatment: None,
88 clauses: vec![],
89 }),
90 filter: None,
91 null_treatment: None,
92 over: None,
93 parameters: sqlparser::ast::FunctionArguments::None,
94 within_group: vec![],
95 uses_odbc_syntax: false,
96 }
97 }
98
99 match expr {
100 Expr::Cast {
105 expr: cast_expr,
106 data_type,
107 ..
108 } if get_type_by_alias(data_type).is_some() => {
109 let new_type = get_type_by_alias(data_type).unwrap();
111 if let Ok(new_type) = sql_data_type_to_concrete_data_type(&new_type) {
112 *expr = Expr::Function(cast_expr_to_arrow_cast_func(
113 (**cast_expr).clone(),
114 new_type.as_arrow_type().to_string(),
115 ));
116 }
117 }
118
119 Expr::Cast {
122 data_type: DataType::Timestamp(precision, zone),
123 expr: cast_expr,
124 ..
125 } => {
126 if let Ok(concrete_type) =
127 sql_data_type_to_concrete_data_type(&DataType::Timestamp(*precision, *zone))
128 {
129 let new_type = concrete_type.as_arrow_type();
130 *expr = Expr::Function(cast_expr_to_arrow_cast_func(
131 (**cast_expr).clone(),
132 new_type.to_string(),
133 ));
134 }
135 }
136
137 _ => {}
139 }
140
141 ControlFlow::<()>::Continue(())
142 }
143}
144
145fn replace_type_alias(data_type: &mut DataType) {
146 if let Some(new_type) = get_type_by_alias(data_type) {
147 *data_type = new_type;
148 }
149}
150
151pub(crate) fn get_type_by_alias(data_type: &DataType) -> Option<DataType> {
155 match data_type {
156 DataType::Custom(name, tokens) if name.0.len() == 1 && tokens.is_empty() => {
160 get_data_type_by_alias_name(name.0[0].to_string_unquoted().as_str())
161 }
162 DataType::Int8(None) => Some(DataType::TinyInt(None)),
163 DataType::Int16 => Some(DataType::SmallInt(None)),
164 DataType::Int32 => Some(DataType::Int(None)),
165 DataType::Int64 => Some(DataType::BigInt(None)),
166 DataType::UInt8 => Some(DataType::TinyIntUnsigned(None)),
167 DataType::UInt16 => Some(DataType::SmallIntUnsigned(None)),
168 DataType::UInt32 => Some(DataType::IntUnsigned(None)),
169 DataType::UInt64 => Some(DataType::BigIntUnsigned(None)),
170 DataType::Float32 => Some(DataType::Float(None)),
171 DataType::Float64 => Some(DataType::Double(ExactNumberInfo::None)),
172 DataType::Bool => Some(DataType::Boolean),
173 DataType::Datetime(_) => Some(DataType::Timestamp(Some(6), TimezoneInfo::None)),
174 _ => None,
175 }
176}
177
178pub(crate) fn get_data_type_by_alias_name(name: &str) -> Option<DataType> {
186 match name.to_uppercase().as_ref() {
187 "TIMESTAMP_S" | "TIMESTAMP_SEC" | "TIMESTAMPSECOND" => {
189 Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
190 }
191
192 "TIMESTAMP_MS" | "TIMESTAMPMILLISECOND" => {
193 Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
194 }
195 "TIMESTAMP_US" | "TIMESTAMPMICROSECOND" | "DATETIME" => {
196 Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
197 }
198 "TIMESTAMP_NS" | "TIMESTAMPNANOSECOND" => {
199 Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
200 }
201 "INT8" => Some(DataType::TinyInt(None)),
204 "INT16" => Some(DataType::SmallInt(None)),
205 "INT32" => Some(DataType::Int(None)),
206 "INT64" => Some(DataType::BigInt(None)),
207 "UINT8" => Some(DataType::TinyIntUnsigned(None)),
208 "UINT16" => Some(DataType::SmallIntUnsigned(None)),
209 "UINT32" => Some(DataType::IntUnsigned(None)),
210 "UINT64" => Some(DataType::BigIntUnsigned(None)),
211 "FLOAT32" => Some(DataType::Float(None)),
212 "FLOAT64" => Some(DataType::Double(ExactNumberInfo::None)),
213 "TINYTEXT" | "MEDIUMTEXT" | "LONGTEXT" => Some(DataType::Text),
215 _ => None,
216 }
217}
218
#[cfg(test)]
mod tests {
    use sqlparser::dialect::GenericDialect;

    use super::*;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::transform_statements;

    /// Alias names resolve to their canonical SQL types regardless of letter case.
    #[test]
    fn test_get_data_type_by_alias_name() {
        assert_eq!(
            get_data_type_by_alias_name("float64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Float64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("FLOAT64"),
            Some(DataType::Double(ExactNumberInfo::None))
        );

        assert_eq!(
            get_data_type_by_alias_name("float32"),
            Some(DataType::Float(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("int8"),
            Some(DataType::TinyInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT16"),
            Some(DataType::SmallInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT32"),
            Some(DataType::Int(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("INT64"),
            Some(DataType::BigInt(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Uint8"),
            Some(DataType::TinyIntUnsigned(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("UINT16"),
            Some(DataType::SmallIntUnsigned(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("UINT32"),
            Some(DataType::IntUnsigned(None))
        );
        assert_eq!(
            get_data_type_by_alias_name("uint64"),
            Some(DataType::BigIntUnsigned(None))
        );

        // Timestamp aliases map to the matching fractional-second precision.
        assert_eq!(
            get_data_type_by_alias_name("TimestampSecond"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_s"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_sec"),
            Some(DataType::Timestamp(Some(0), TimezoneInfo::None))
        );

        assert_eq!(
            get_data_type_by_alias_name("TimestampMilliSecond"),
            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_ms"),
            Some(DataType::Timestamp(Some(3), TimezoneInfo::None))
        );

        assert_eq!(
            get_data_type_by_alias_name("TimestampMicroSecond"),
            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_us"),
            Some(DataType::Timestamp(Some(6), TimezoneInfo::None))
        );

        assert_eq!(
            get_data_type_by_alias_name("TimestampNanoSecond"),
            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
        );
        assert_eq!(
            get_data_type_by_alias_name("Timestamp_ns"),
            Some(DataType::Timestamp(Some(9), TimezoneInfo::None))
        );
        // MySQL-style text aliases collapse to plain TEXT.
        assert_eq!(
            get_data_type_by_alias_name("TinyText"),
            Some(DataType::Text)
        );
        assert_eq!(
            get_data_type_by_alias_name("MediumText"),
            Some(DataType::Text)
        );
        assert_eq!(
            get_data_type_by_alias_name("LongText"),
            Some(DataType::Text)
        );
    }

    /// Parses `SELECT TIMESTAMP '...'::<alias>` and runs the statement transforms.
    /// Asserts the cast was rewritten to `arrow_cast(..., 'Timestamp(<expected>, None)')`.
    fn test_timestamp_alias(alias: &str, expected: &str) {
        let sql = format!("SELECT TIMESTAMP '2020-01-01 01:23:45.12345678'::{alias}");
        let mut stmts =
            ParserContext::create_with_dialect(&sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::Query(q) => assert_eq!(format!("SELECT arrow_cast(TIMESTAMP '2020-01-01 01:23:45.12345678', 'Timestamp({expected}, None)')"), q.to_string()),
            _ => unreachable!(),
        }
    }

    /// Same check as `test_timestamp_alias`, with the cast target spelled
    /// `Timestamp(<precision>)`.
    fn test_timestamp_precision_type(precision: i32, expected: &str) {
        test_timestamp_alias(&format!("Timestamp({precision})"), expected);
    }

    /// A `bool` column type is rendered as `BOOLEAN` after the transforms run.
    #[test]
    fn test_boolean_alias() {
        let sql = "CREATE TABLE test(b bool, ts TIMESTAMP TIME INDEX)";
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::CreateTable(c) => assert_eq!("CREATE TABLE test (\n  b BOOLEAN,\n  ts TIMESTAMP NOT NULL,\n  TIME INDEX (ts)\n)\nENGINE=mito\n", c.to_string()),
            _ => unreachable!(),
        }
    }

    /// Every timestamp alias, and every explicit precision 0/3/6/9, casts to the
    /// matching Arrow time unit.
    #[test]
    fn test_transform_timestamp_alias() {
        test_timestamp_alias("TimestampSecond", "Second");
        test_timestamp_alias("Timestamp_s", "Second");
        test_timestamp_alias("TimestampMillisecond", "Millisecond");
        test_timestamp_alias("Timestamp_ms", "Millisecond");
        test_timestamp_alias("TimestampMicrosecond", "Microsecond");
        test_timestamp_alias("Timestamp_us", "Microsecond");
        test_timestamp_alias("TimestampNanosecond", "Nanosecond");
        test_timestamp_alias("Timestamp_ns", "Nanosecond");
        test_timestamp_precision_type(0, "Second");
        test_timestamp_precision_type(3, "Millisecond");
        test_timestamp_precision_type(6, "Microsecond");
        test_timestamp_precision_type(9, "Nanosecond");
    }

    /// End-to-end check: a `CREATE TABLE` mixing aliased and native column types is
    /// transformed and re-rendered. Note that e.g. `tinytext` still renders as
    /// `TINYTEXT` in the output below.
    #[test]
    fn test_create_sql_with_type_alias() {
        let sql = r#"
CREATE TABLE data_types (
    s string,
    tt tinytext,
    mt mediumtext,
    lt longtext,
    tint int8,
    sint int16,
    i int32,
    bint int64,
    v varchar,
    f float32,
    d float64,
    b boolean,
    vb varbinary,
    dt date,
    dtt datetime,
    ts0 TimestampSecond,
    ts3 TimestampMillisecond,
    ts6 TimestampMicrosecond,
    ts9 TimestampNanosecond DEFAULT CURRENT_TIMESTAMP TIME INDEX,
    PRIMARY KEY(s));"#;

        let mut stmts =
            ParserContext::create_with_dialect(sql, &GenericDialect {}, ParseOptions::default())
                .unwrap();
        transform_statements(&mut stmts).unwrap();

        match &stmts[0] {
            Statement::CreateTable(c) => {
                let expected = r#"CREATE TABLE data_types (
  s STRING,
  tt TINYTEXT,
  mt MEDIUMTEXT,
  lt LONGTEXT,
  tint TINYINT,
  sint SMALLINT,
  i INT,
  bint BIGINT,
  v VARCHAR,
  f FLOAT,
  d DOUBLE,
  b BOOLEAN,
  vb VARBINARY,
  dt DATE,
  dtt TIMESTAMP(6),
  ts0 TIMESTAMP(0),
  ts3 TIMESTAMP(3),
  ts6 TIMESTAMP(6),
  ts9 TIMESTAMP(9) DEFAULT CURRENT_TIMESTAMP NOT NULL,
  TIME INDEX (ts9),
  PRIMARY KEY (s)
)
ENGINE=mito
"#;

                assert_eq!(expected, c.to_string());
            }
            _ => unreachable!(),
        }
    }
}