sql/
util.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use serde::Serialize;
20use snafu::ensure;
21use sqlparser::ast::{
22    Array, Expr, Ident, ObjectName, SetExpr, SqlOption, StructField, TableFactor, Value,
23    ValueWithSpan,
24};
25use sqlparser_derive::{Visit, VisitMut};
26
27use crate::ast::ObjectNamePartExt;
28use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
29use crate::statements::create::SqlOrTql;
30
31/// Format an [ObjectName] without any quote of its idents.
32pub fn format_raw_object_name(name: &ObjectName) -> String {
33    struct Inner<'a> {
34        name: &'a ObjectName,
35    }
36
37    impl Display for Inner<'_> {
38        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39            let mut delim = "";
40            for ident in self.name.0.iter() {
41                write!(f, "{delim}")?;
42                delim = ".";
43                write!(f, "{}", ident.to_string_unquoted())?;
44            }
45            Ok(())
46        }
47    }
48
49    format!("{}", Inner { name })
50}
51
/// The value half of a key-value SQL option, kept as the parsed [Expr].
///
/// Values built through [`OptionValue::try_new`] are restricted to literal
/// values, identifiers, arrays and struct expressions; the `From` conversions
/// below build identifier or array expressions directly.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
54
55impl OptionValue {
56    pub(crate) fn try_new(expr: Expr) -> Result<Self> {
57        ensure!(
58            matches!(
59                expr,
60                Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
61            ),
62            InvalidExprAsOptionValueSnafu {
63                error: format!("{expr} not accepted")
64            }
65        );
66        Ok(Self(expr))
67    }
68
69    fn expr_as_string(expr: &Expr) -> Option<&str> {
70        match expr {
71            Expr::Value(ValueWithSpan { value, .. }) => match value {
72                Value::SingleQuotedString(s)
73                | Value::DoubleQuotedString(s)
74                | Value::TripleSingleQuotedString(s)
75                | Value::TripleDoubleQuotedString(s)
76                | Value::SingleQuotedByteStringLiteral(s)
77                | Value::DoubleQuotedByteStringLiteral(s)
78                | Value::TripleSingleQuotedByteStringLiteral(s)
79                | Value::TripleDoubleQuotedByteStringLiteral(s)
80                | Value::SingleQuotedRawStringLiteral(s)
81                | Value::DoubleQuotedRawStringLiteral(s)
82                | Value::TripleSingleQuotedRawStringLiteral(s)
83                | Value::TripleDoubleQuotedRawStringLiteral(s)
84                | Value::EscapedStringLiteral(s)
85                | Value::UnicodeStringLiteral(s)
86                | Value::NationalStringLiteral(s)
87                | Value::HexStringLiteral(s) => Some(s),
88                Value::DollarQuotedString(s) => Some(&s.value),
89                Value::Number(s, _) => Some(s),
90                _ => None,
91            },
92            Expr::Identifier(ident) => Some(&ident.value),
93            _ => None,
94        }
95    }
96
97    pub fn as_string(&self) -> Option<&str> {
98        Self::expr_as_string(&self.0)
99    }
100
101    pub fn as_list(&self) -> Option<Vec<&str>> {
102        let expr = &self.0;
103        match expr {
104            Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
105            Expr::Array(array) => array
106                .elem
107                .iter()
108                .map(Self::expr_as_string)
109                .collect::<Option<Vec<_>>>(),
110            _ => None,
111        }
112    }
113
114    pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
115        match &self.0 {
116            Expr::Struct { fields, .. } => Some(fields),
117            _ => None,
118        }
119    }
120}
121
122impl From<String> for OptionValue {
123    fn from(value: String) -> Self {
124        Self(Expr::Identifier(Ident::new(value)))
125    }
126}
127
128impl From<&str> for OptionValue {
129    fn from(value: &str) -> Self {
130        Self(Expr::Identifier(Ident::new(value)))
131    }
132}
133
134impl From<Vec<&str>> for OptionValue {
135    fn from(value: Vec<&str>) -> Self {
136        Self(Expr::Array(Array {
137            elem: value
138                .into_iter()
139                .map(|x| Expr::Identifier(Ident::new(x)))
140                .collect(),
141            named: false,
142        }))
143    }
144}
145
146impl Display for OptionValue {
147    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
148        if let Some(s) = self.as_string() {
149            write!(f, "'{s}'")
150        } else if let Some(s) = self.as_list() {
151            write!(
152                f,
153                "[{}]",
154                s.into_iter().map(|x| format!("'{x}'")).join(", ")
155            )
156        } else {
157            write!(f, "'{}'", self.0)
158        }
159    }
160}
161
162pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
163    let SqlOption::KeyValue { key, value } = option else {
164        return InvalidSqlSnafu {
165            msg: "Expecting a key-value pair in the option",
166        }
167        .fail();
168    };
169    let v = OptionValue::try_new(value)?;
170    let k = key.value.to_lowercase();
171    Ok((k, v))
172}
173
174/// Walk through a [Query] and extract all the tables referenced in it.
175pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
176    let mut names = HashSet::new();
177
178    match query {
179        SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
180        SqlOrTql::Tql(_tql, _) => {
181            // since tql have sliding time window, so we don't need to extract tables from it
182            // (because we are going to eval it fully anyway)
183        }
184    }
185
186    names.into_iter()
187}
188
189/// translate the start location to the index in the sql string
190pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
191    let mut index = 0;
192    for (lno, line) in sql.lines().enumerate() {
193        if lno + 1 == location.line as usize {
194            index += location.column as usize;
195            break;
196        } else {
197            index += line.len() + 1; // +1 for the newline
198        }
199    }
200    // -1 because the index is 0-based
201    // and the location is 1-based
202    index - 1
203}
204
205/// Helper function for [extract_tables_from_query].
206///
207/// Handle [SetExpr].
208fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
209    match set_expr {
210        SetExpr::Select(select) => {
211            for from in &select.from {
212                table_factor_to_object_name(&from.relation, names);
213                for join in &from.joins {
214                    table_factor_to_object_name(&join.relation, names);
215                }
216            }
217        }
218        SetExpr::Query(query) => {
219            extract_tables_from_set_expr(&query.body, names);
220        }
221        SetExpr::SetOperation { left, right, .. } => {
222            extract_tables_from_set_expr(left, names);
223            extract_tables_from_set_expr(right, names);
224        }
225        _ => {}
226    };
227}
228
229/// Helper function for [extract_tables_from_query].
230///
231/// Handle [TableFactor].
232fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
233    if let TableFactor::Table { name, .. } = table_factor {
234        names.insert(name.to_owned());
235    }
236}
237
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParserContext;

    /// Tokenizes `sql` and checks that slicing `sql` with the byte range derived
    /// from each token's span reproduces that token's textual form.
    fn assert_token_spans_slice_back(sql: &str) {
        let mut ctx = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
        loop {
            let tok = ctx.parser.next_token();
            if tok == Token::EOF {
                break;
            }
            let begin = location_to_index(sql, &tok.span.start);
            let end = location_to_index(sql, &tok.span.end);
            assert_eq!(tok.to_string(), &sql[begin..end]);
        }
    }

    #[test]
    fn test_location_to_index() {
        let cases = [
            "SELECT * FROM t WHERE a = 1",
            // leading and trailing newline
            r"
SELECT *
FROM
t
WHERE a =
1
",
            // trailing newline only
            r"SELECT *
FROM
t
WHERE a =
1
",
            // leading newline only
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in cases {
            assert_token_spans_slice_back(sql);
        }
    }
}