1use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use serde::Serialize;
20use snafu::ensure;
21use sqlparser::ast::{
22 Array, Expr, Ident, ObjectName, SetExpr, SqlOption, StructField, TableFactor, Value,
23 ValueWithSpan,
24};
25use sqlparser_derive::{Visit, VisitMut};
26
27use crate::ast::ObjectNamePartExt;
28use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
29use crate::statements::create::SqlOrTql;
30
31pub fn format_raw_object_name(name: &ObjectName) -> String {
33 struct Inner<'a> {
34 name: &'a ObjectName,
35 }
36
37 impl Display for Inner<'_> {
38 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39 let mut delim = "";
40 for ident in self.name.0.iter() {
41 write!(f, "{delim}")?;
42 delim = ".";
43 write!(f, "{}", ident.to_string_unquoted())?;
44 }
45 Ok(())
46 }
47 }
48
49 format!("{}", Inner { name })
50}
51
52#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
53pub struct OptionValue(Expr);
54
55impl OptionValue {
56 pub(crate) fn try_new(expr: Expr) -> Result<Self> {
57 ensure!(
58 matches!(
59 expr,
60 Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
61 ),
62 InvalidExprAsOptionValueSnafu {
63 error: format!("{expr} not accepted")
64 }
65 );
66 Ok(Self(expr))
67 }
68
69 fn expr_as_string(expr: &Expr) -> Option<&str> {
70 match expr {
71 Expr::Value(ValueWithSpan { value, .. }) => match value {
72 Value::SingleQuotedString(s)
73 | Value::DoubleQuotedString(s)
74 | Value::TripleSingleQuotedString(s)
75 | Value::TripleDoubleQuotedString(s)
76 | Value::SingleQuotedByteStringLiteral(s)
77 | Value::DoubleQuotedByteStringLiteral(s)
78 | Value::TripleSingleQuotedByteStringLiteral(s)
79 | Value::TripleDoubleQuotedByteStringLiteral(s)
80 | Value::SingleQuotedRawStringLiteral(s)
81 | Value::DoubleQuotedRawStringLiteral(s)
82 | Value::TripleSingleQuotedRawStringLiteral(s)
83 | Value::TripleDoubleQuotedRawStringLiteral(s)
84 | Value::EscapedStringLiteral(s)
85 | Value::UnicodeStringLiteral(s)
86 | Value::NationalStringLiteral(s)
87 | Value::HexStringLiteral(s) => Some(s),
88 Value::DollarQuotedString(s) => Some(&s.value),
89 Value::Number(s, _) => Some(s),
90 _ => None,
91 },
92 Expr::Identifier(ident) => Some(&ident.value),
93 _ => None,
94 }
95 }
96
97 pub fn as_string(&self) -> Option<&str> {
98 Self::expr_as_string(&self.0)
99 }
100
101 pub fn as_list(&self) -> Option<Vec<&str>> {
102 let expr = &self.0;
103 match expr {
104 Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
105 Expr::Array(array) => array
106 .elem
107 .iter()
108 .map(Self::expr_as_string)
109 .collect::<Option<Vec<_>>>(),
110 _ => None,
111 }
112 }
113
114 pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
115 match &self.0 {
116 Expr::Struct { fields, .. } => Some(fields),
117 _ => None,
118 }
119 }
120}
121
122impl From<String> for OptionValue {
123 fn from(value: String) -> Self {
124 Self(Expr::Identifier(Ident::new(value)))
125 }
126}
127
128impl From<&str> for OptionValue {
129 fn from(value: &str) -> Self {
130 Self(Expr::Identifier(Ident::new(value)))
131 }
132}
133
134impl From<Vec<&str>> for OptionValue {
135 fn from(value: Vec<&str>) -> Self {
136 Self(Expr::Array(Array {
137 elem: value
138 .into_iter()
139 .map(|x| Expr::Identifier(Ident::new(x)))
140 .collect(),
141 named: false,
142 }))
143 }
144}
145
146impl Display for OptionValue {
147 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
148 if let Some(s) = self.as_string() {
149 write!(f, "'{s}'")
150 } else if let Some(s) = self.as_list() {
151 write!(
152 f,
153 "[{}]",
154 s.into_iter().map(|x| format!("'{x}'")).join(", ")
155 )
156 } else {
157 write!(f, "'{}'", self.0)
158 }
159 }
160}
161
162pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
163 let SqlOption::KeyValue { key, value } = option else {
164 return InvalidSqlSnafu {
165 msg: "Expecting a key-value pair in the option",
166 }
167 .fail();
168 };
169 let v = OptionValue::try_new(value)?;
170 let k = key.value.to_lowercase();
171 Ok((k, v))
172}
173
174pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
176 let mut names = HashSet::new();
177
178 match query {
179 SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
180 SqlOrTql::Tql(_tql, _) => {
181 }
184 }
185
186 names.into_iter()
187}
188
189pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
191 let mut index = 0;
192 for (lno, line) in sql.lines().enumerate() {
193 if lno + 1 == location.line as usize {
194 index += location.column as usize;
195 break;
196 } else {
197 index += line.len() + 1; }
199 }
200 index - 1
203}
204
205fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
209 match set_expr {
210 SetExpr::Select(select) => {
211 for from in &select.from {
212 table_factor_to_object_name(&from.relation, names);
213 for join in &from.joins {
214 table_factor_to_object_name(&join.relation, names);
215 }
216 }
217 }
218 SetExpr::Query(query) => {
219 extract_tables_from_set_expr(&query.body, names);
220 }
221 SetExpr::SetOperation { left, right, .. } => {
222 extract_tables_from_set_expr(left, names);
223 extract_tables_from_set_expr(right, names);
224 }
225 _ => {}
226 };
227}
228
229fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
233 if let TableFactor::Table { name, .. } = table_factor {
234 names.insert(name.to_owned());
235 }
236}
237
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParserContext;

    /// For every token produced by the parser (up to EOF), converting its span
    /// back to byte offsets with `location_to_index` must slice out exactly
    /// the token's text.
    #[test]
    fn test_location_to_index() {
        let testcases = [
            // single line
            "SELECT * FROM t WHERE a = 1",
            // leading and trailing newline
            r"
SELECT *
FROM
t
WHERE a =
1
",
            // trailing newline only
            r"SELECT *
FROM
t
WHERE a =
1
",
            // leading newline only
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut ctx = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            let tokens = std::iter::from_fn(|| {
                let next = ctx.parser.next_token();
                (next != Token::EOF).then_some(next)
            });

            for token in tokens {
                let start = location_to_index(sql, &token.span.start);
                let end = location_to_index(sql, &token.span.end);
                assert_eq!(token.to_string(), &sql[start..end]);
            }
        }
    }
}