1use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use sqlparser::ast::{Expr, ObjectName, SetExpr, SqlOption, TableFactor, Value, ValueWithSpan};
19
20use crate::ast::ObjectNamePartExt;
21use crate::error::{InvalidSqlSnafu, InvalidTableOptionValueSnafu, Result};
22use crate::statements::create::SqlOrTql;
23
24pub fn format_raw_object_name(name: &ObjectName) -> String {
26 struct Inner<'a> {
27 name: &'a ObjectName,
28 }
29
30 impl Display for Inner<'_> {
31 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
32 let mut delim = "";
33 for ident in self.name.0.iter() {
34 write!(f, "{delim}")?;
35 delim = ".";
36 write!(f, "{}", ident.to_string_unquoted())?;
37 }
38 Ok(())
39 }
40 }
41
42 format!("{}", Inner { name })
43}
44
45pub fn parse_option_string(option: SqlOption) -> Result<(String, String)> {
46 let SqlOption::KeyValue { key, value } = option else {
47 return InvalidSqlSnafu {
48 msg: "Expecting a key-value pair in the option",
49 }
50 .fail();
51 };
52 let v = match value {
53 Expr::Value(ValueWithSpan {
54 value: Value::SingleQuotedString(v),
55 ..
56 })
57 | Expr::Value(ValueWithSpan {
58 value: Value::DoubleQuotedString(v),
59 ..
60 }) => v,
61 Expr::Identifier(v) => v.value,
62 Expr::Value(ValueWithSpan {
63 value: Value::Number(v, _),
64 ..
65 }) => v.to_string(),
66 value => return InvalidTableOptionValueSnafu { key, value }.fail(),
67 };
68 let k = key.value.to_lowercase();
69 Ok((k, v))
70}
71
72pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
74 let mut names = HashSet::new();
75
76 match query {
77 SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
78 SqlOrTql::Tql(_tql, _) => {
79 }
82 }
83
84 names.into_iter()
85}
86
87pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
89 let mut index = 0;
90 for (lno, line) in sql.lines().enumerate() {
91 if lno + 1 == location.line as usize {
92 index += location.column as usize;
93 break;
94 } else {
95 index += line.len() + 1; }
97 }
98 index - 1
101}
102
103fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
107 match set_expr {
108 SetExpr::Select(select) => {
109 for from in &select.from {
110 table_factor_to_object_name(&from.relation, names);
111 for join in &from.joins {
112 table_factor_to_object_name(&join.relation, names);
113 }
114 }
115 }
116 SetExpr::Query(query) => {
117 extract_tables_from_set_expr(&query.body, names);
118 }
119 SetExpr::SetOperation { left, right, .. } => {
120 extract_tables_from_set_expr(left, names);
121 extract_tables_from_set_expr(right, names);
122 }
123 SetExpr::Values(_) | SetExpr::Insert(_) | SetExpr::Update(_) | SetExpr::Table(_) => {}
124 };
125}
126
127fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
131 if let TableFactor::Table { name, .. } = table_factor {
132 names.insert(name.to_owned());
133 }
134}
135
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParserContext;

    /// Every non-EOF token the parser yields must be reproducible from the
    /// original SQL. Slicing the text between the offsets that
    /// `location_to_index` computes for the token's span start and end must
    /// give back the token's own text.
    #[test]
    fn test_location_to_index() {
        let testcases = [
            "SELECT * FROM t WHERE a = 1",
            // starts and ends with a newline
            r"
SELECT *
FROM
t
WHERE a =
1
",
            // ends with a newline
            r"SELECT *
FROM
t
WHERE a =
1
",
            // starts with a newline
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut ctx = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            let mut token = ctx.parser.next_token();
            while token != Token::EOF {
                let start = location_to_index(sql, &token.span.start);
                let end = location_to_index(sql, &token.span.end);
                let subslice = &sql[start..end];
                assert_eq!(token.to_string(), subslice);
                token = ctx.parser.next_token();
            }
        }
    }
}