1use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use serde::Serialize;
20use snafu::ensure;
21use sqlparser::ast::{
22 Array, Expr, Ident, ObjectName, SetExpr, SqlOption, TableFactor, Value, ValueWithSpan,
23};
24use sqlparser_derive::{Visit, VisitMut};
25
26use crate::ast::ObjectNamePartExt;
27use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
28use crate::statements::create::SqlOrTql;
29
30pub fn format_raw_object_name(name: &ObjectName) -> String {
32 struct Inner<'a> {
33 name: &'a ObjectName,
34 }
35
36 impl Display for Inner<'_> {
37 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38 let mut delim = "";
39 for ident in self.name.0.iter() {
40 write!(f, "{delim}")?;
41 delim = ".";
42 write!(f, "{}", ident.to_string_unquoted())?;
43 }
44 Ok(())
45 }
46 }
47
48 format!("{}", Inner { name })
49}
50
/// The value of a key-value SQL option (e.g. in a `WITH (...)` clause).
///
/// Wraps the parsed expression unchanged. When built via `OptionValue::try_new`, only
/// literal values, bare identifiers and array expressions are accepted.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
53
54impl OptionValue {
55 fn try_new(expr: Expr) -> Result<Self> {
56 ensure!(
57 matches!(expr, Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_)),
58 InvalidExprAsOptionValueSnafu {
59 error: format!("{expr} not accepted")
60 }
61 );
62 Ok(Self(expr))
63 }
64
65 fn expr_as_string(expr: &Expr) -> Option<&str> {
66 match expr {
67 Expr::Value(ValueWithSpan { value, .. }) => match value {
68 Value::SingleQuotedString(s)
69 | Value::DoubleQuotedString(s)
70 | Value::TripleSingleQuotedString(s)
71 | Value::TripleDoubleQuotedString(s)
72 | Value::SingleQuotedByteStringLiteral(s)
73 | Value::DoubleQuotedByteStringLiteral(s)
74 | Value::TripleSingleQuotedByteStringLiteral(s)
75 | Value::TripleDoubleQuotedByteStringLiteral(s)
76 | Value::SingleQuotedRawStringLiteral(s)
77 | Value::DoubleQuotedRawStringLiteral(s)
78 | Value::TripleSingleQuotedRawStringLiteral(s)
79 | Value::TripleDoubleQuotedRawStringLiteral(s)
80 | Value::EscapedStringLiteral(s)
81 | Value::UnicodeStringLiteral(s)
82 | Value::NationalStringLiteral(s)
83 | Value::HexStringLiteral(s) => Some(s),
84 Value::DollarQuotedString(s) => Some(&s.value),
85 Value::Number(s, _) => Some(s),
86 _ => None,
87 },
88 Expr::Identifier(ident) => Some(&ident.value),
89 _ => None,
90 }
91 }
92
93 pub fn as_string(&self) -> Option<&str> {
94 Self::expr_as_string(&self.0)
95 }
96
97 pub fn as_list(&self) -> Option<Vec<&str>> {
98 let expr = &self.0;
99 match expr {
100 Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
101 Expr::Array(array) => array
102 .elem
103 .iter()
104 .map(Self::expr_as_string)
105 .collect::<Option<Vec<_>>>(),
106 _ => None,
107 }
108 }
109}
110
111impl From<String> for OptionValue {
112 fn from(value: String) -> Self {
113 Self(Expr::Identifier(Ident::new(value)))
114 }
115}
116
117impl From<&str> for OptionValue {
118 fn from(value: &str) -> Self {
119 Self(Expr::Identifier(Ident::new(value)))
120 }
121}
122
123impl From<Vec<&str>> for OptionValue {
124 fn from(value: Vec<&str>) -> Self {
125 Self(Expr::Array(Array {
126 elem: value
127 .into_iter()
128 .map(|x| Expr::Identifier(Ident::new(x)))
129 .collect(),
130 named: false,
131 }))
132 }
133}
134
135impl Display for OptionValue {
136 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
137 if let Some(s) = self.as_string() {
138 write!(f, "'{s}'")
139 } else if let Some(s) = self.as_list() {
140 write!(
141 f,
142 "[{}]",
143 s.into_iter().map(|x| format!("'{x}'")).join(", ")
144 )
145 } else {
146 write!(f, "'{}'", self.0)
147 }
148 }
149}
150
151pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
152 let SqlOption::KeyValue { key, value } = option else {
153 return InvalidSqlSnafu {
154 msg: "Expecting a key-value pair in the option",
155 }
156 .fail();
157 };
158 let v = OptionValue::try_new(value)?;
159 let k = key.value.to_lowercase();
160 Ok((k, v))
161}
162
163pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
165 let mut names = HashSet::new();
166
167 match query {
168 SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
169 SqlOrTql::Tql(_tql, _) => {
170 }
173 }
174
175 names.into_iter()
176}
177
178pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
180 let mut index = 0;
181 for (lno, line) in sql.lines().enumerate() {
182 if lno + 1 == location.line as usize {
183 index += location.column as usize;
184 break;
185 } else {
186 index += line.len() + 1; }
188 }
189 index - 1
192}
193
194fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
198 match set_expr {
199 SetExpr::Select(select) => {
200 for from in &select.from {
201 table_factor_to_object_name(&from.relation, names);
202 for join in &from.joins {
203 table_factor_to_object_name(&join.relation, names);
204 }
205 }
206 }
207 SetExpr::Query(query) => {
208 extract_tables_from_set_expr(&query.body, names);
209 }
210 SetExpr::SetOperation { left, right, .. } => {
211 extract_tables_from_set_expr(left, names);
212 extract_tables_from_set_expr(right, names);
213 }
214 _ => {}
215 };
216}
217
218fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
222 if let TableFactor::Table { name, .. } = table_factor {
223 names.insert(name.to_owned());
224 }
225}
226
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParserContext;

    /// For every token produced by the parser, slicing the SQL text with the byte offsets
    /// returned by `location_to_index` for the token's span start/end must give back
    /// exactly the token's own text.
    #[test]
    fn test_location_to_index() {
        let testcases = vec![
            // Single line.
            "SELECT * FROM t WHERE a = 1",
            // Multi-line, starting and ending with a newline.
            r"
SELECT *
FROM
t
WHERE a =
1
",
            // Multi-line, ending with a newline only.
            r"SELECT *
FROM
t
WHERE a =
1
",
            // Multi-line, starting with a newline only.
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            // Consume tokens until EOF, checking each token's span against the source text.
            loop {
                let token = parser.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let span = token.span;
                let subslice =
                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }
}