1use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use serde::Serialize;
20use snafu::ensure;
21use sqlparser::ast::{
22 Array, Expr, Ident, ObjectName, SetExpr, SqlOption, StructField, TableFactor, Value,
23 ValueWithSpan,
24};
25use sqlparser_derive::{Visit, VisitMut};
26
27use crate::ast::ObjectNamePartExt;
28use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
29use crate::statements::create::SqlOrTql;
30
31pub fn format_raw_object_name(name: &ObjectName) -> String {
33 struct Inner<'a> {
34 name: &'a ObjectName,
35 }
36
37 impl Display for Inner<'_> {
38 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39 let mut delim = "";
40 for ident in self.name.0.iter() {
41 write!(f, "{delim}")?;
42 delim = ".";
43 write!(f, "{}", ident.to_string_unquoted())?;
44 }
45 Ok(())
46 }
47 }
48
49 format!("{}", Inner { name })
50}
51
/// The value half of a key-value SQL option (as produced by `parse_option_string`).
///
/// Wraps the raw SQL [`Expr`]. Values built through `OptionValue::try_new` are
/// restricted to literal values, identifiers, arrays and struct expressions;
/// the `From` conversions build identifier or array-of-identifier expressions.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
54
55impl OptionValue {
56 pub(crate) fn try_new(expr: Expr) -> Result<Self> {
57 ensure!(
58 matches!(
59 expr,
60 Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
61 ),
62 InvalidExprAsOptionValueSnafu {
63 error: format!("{expr} not accepted")
64 }
65 );
66 Ok(Self(expr))
67 }
68
69 fn expr_as_string(expr: &Expr) -> Option<&str> {
70 match expr {
71 Expr::Value(ValueWithSpan { value, .. }) => match value {
72 Value::SingleQuotedString(s)
73 | Value::DoubleQuotedString(s)
74 | Value::TripleSingleQuotedString(s)
75 | Value::TripleDoubleQuotedString(s)
76 | Value::SingleQuotedByteStringLiteral(s)
77 | Value::DoubleQuotedByteStringLiteral(s)
78 | Value::TripleSingleQuotedByteStringLiteral(s)
79 | Value::TripleDoubleQuotedByteStringLiteral(s)
80 | Value::SingleQuotedRawStringLiteral(s)
81 | Value::DoubleQuotedRawStringLiteral(s)
82 | Value::TripleSingleQuotedRawStringLiteral(s)
83 | Value::TripleDoubleQuotedRawStringLiteral(s)
84 | Value::EscapedStringLiteral(s)
85 | Value::UnicodeStringLiteral(s)
86 | Value::NationalStringLiteral(s)
87 | Value::HexStringLiteral(s) => Some(s),
88 Value::DollarQuotedString(s) => Some(&s.value),
89 Value::Number(s, _) => Some(s),
90 Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
91 _ => None,
92 },
93 Expr::Identifier(ident) => Some(&ident.value),
94 _ => None,
95 }
96 }
97
98 pub fn as_string(&self) -> Option<&str> {
102 Self::expr_as_string(&self.0)
103 }
104
105 pub fn as_list(&self) -> Option<Vec<&str>> {
106 let expr = &self.0;
107 match expr {
108 Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
109 Expr::Array(array) => array
110 .elem
111 .iter()
112 .map(Self::expr_as_string)
113 .collect::<Option<Vec<_>>>(),
114 _ => None,
115 }
116 }
117
118 pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
119 match &self.0 {
120 Expr::Struct { fields, .. } => Some(fields),
121 _ => None,
122 }
123 }
124}
125
126impl From<String> for OptionValue {
127 fn from(value: String) -> Self {
128 Self(Expr::Identifier(Ident::new(value)))
129 }
130}
131
132impl From<&str> for OptionValue {
133 fn from(value: &str) -> Self {
134 Self(Expr::Identifier(Ident::new(value)))
135 }
136}
137
138impl From<Vec<&str>> for OptionValue {
139 fn from(value: Vec<&str>) -> Self {
140 Self(Expr::Array(Array {
141 elem: value
142 .into_iter()
143 .map(|x| Expr::Identifier(Ident::new(x)))
144 .collect(),
145 named: false,
146 }))
147 }
148}
149
150impl Display for OptionValue {
151 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
152 if let Some(s) = self.as_string() {
153 write!(f, "'{s}'")
154 } else if let Some(s) = self.as_list() {
155 write!(
156 f,
157 "[{}]",
158 s.into_iter().map(|x| format!("'{x}'")).join(", ")
159 )
160 } else {
161 write!(f, "'{}'", self.0)
162 }
163 }
164}
165
166pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
167 let SqlOption::KeyValue { key, value } = option else {
168 return InvalidSqlSnafu {
169 msg: "Expecting a key-value pair in the option",
170 }
171 .fail();
172 };
173 let v = OptionValue::try_new(value)?;
174 let k = key.value.to_lowercase();
175 Ok((k, v))
176}
177
178pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
180 let mut names = HashSet::new();
181
182 match query {
183 SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
184 SqlOrTql::Tql(_tql, _) => {
185 }
188 }
189
190 names.into_iter()
191}
192
193pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
195 let mut index = 0;
196 for (lno, line) in sql.lines().enumerate() {
197 if lno + 1 == location.line as usize {
198 index += location.column as usize;
199 break;
200 } else {
201 index += line.len() + 1; }
203 }
204 index - 1
207}
208
209fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
213 match set_expr {
214 SetExpr::Select(select) => {
215 for from in &select.from {
216 table_factor_to_object_name(&from.relation, names);
217 for join in &from.joins {
218 table_factor_to_object_name(&join.relation, names);
219 }
220 }
221 }
222 SetExpr::Query(query) => {
223 extract_tables_from_set_expr(&query.body, names);
224 }
225 SetExpr::SetOperation { left, right, .. } => {
226 extract_tables_from_set_expr(left, names);
227 extract_tables_from_set_expr(right, names);
228 }
229 _ => {}
230 };
231}
232
233fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
237 if let TableFactor::Table { name, .. } = table_factor {
238 names.insert(name.to_owned());
239 }
240}
241
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParserContext;

    /// For each test SQL string, tokenizes it and checks that slicing `sql`
    /// with `location_to_index` applied to every token's span start/end yields
    /// exactly that token's text.
    #[test]
    fn test_location_to_index() {
        let testcases = vec![
            // Single line.
            "SELECT * FROM t WHERE a = 1",
            // Multi-line with both a leading and a trailing newline.
            r"
SELECT *
FROM
t
WHERE a =
1
",
            // Multi-line with a trailing newline only.
            r"SELECT *
FROM
t
WHERE a =
1
",
            // Multi-line with a leading newline only.
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            // Consume tokens until EOF, comparing each token with the source
            // text its span points at.
            loop {
                let token = parser.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let span = token.span;
                let subslice =
                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }
}