sql/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use serde::Serialize;
20use snafu::ensure;
21use sqlparser::ast::{
22    Array, Expr, Ident, ObjectName, SetExpr, SqlOption, TableFactor, Value, ValueWithSpan,
23};
24use sqlparser_derive::{Visit, VisitMut};
25
26use crate::ast::ObjectNamePartExt;
27use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
28use crate::statements::create::SqlOrTql;
29
30/// Format an [ObjectName] without any quote of its idents.
31pub fn format_raw_object_name(name: &ObjectName) -> String {
32    struct Inner<'a> {
33        name: &'a ObjectName,
34    }
35
36    impl Display for Inner<'_> {
37        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
38            let mut delim = "";
39            for ident in self.name.0.iter() {
40                write!(f, "{delim}")?;
41                delim = ".";
42                write!(f, "{}", ident.to_string_unquoted())?;
43            }
44            Ok(())
45        }
46    }
47
48    format!("{}", Inner { name })
49}
50
/// The value side of a SQL option, stored as the underlying [Expr].
///
/// Values built through `OptionValue::try_new` are limited to [Expr::Value],
/// [Expr::Identifier] and [Expr::Array]; the `From` conversions in this file
/// produce identifiers or arrays of identifiers.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
53
54impl OptionValue {
55    fn try_new(expr: Expr) -> Result<Self> {
56        ensure!(
57            matches!(expr, Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_)),
58            InvalidExprAsOptionValueSnafu {
59                error: format!("{expr} not accepted")
60            }
61        );
62        Ok(Self(expr))
63    }
64
65    fn expr_as_string(expr: &Expr) -> Option<&str> {
66        match expr {
67            Expr::Value(ValueWithSpan { value, .. }) => match value {
68                Value::SingleQuotedString(s)
69                | Value::DoubleQuotedString(s)
70                | Value::TripleSingleQuotedString(s)
71                | Value::TripleDoubleQuotedString(s)
72                | Value::SingleQuotedByteStringLiteral(s)
73                | Value::DoubleQuotedByteStringLiteral(s)
74                | Value::TripleSingleQuotedByteStringLiteral(s)
75                | Value::TripleDoubleQuotedByteStringLiteral(s)
76                | Value::SingleQuotedRawStringLiteral(s)
77                | Value::DoubleQuotedRawStringLiteral(s)
78                | Value::TripleSingleQuotedRawStringLiteral(s)
79                | Value::TripleDoubleQuotedRawStringLiteral(s)
80                | Value::EscapedStringLiteral(s)
81                | Value::UnicodeStringLiteral(s)
82                | Value::NationalStringLiteral(s)
83                | Value::HexStringLiteral(s) => Some(s),
84                Value::DollarQuotedString(s) => Some(&s.value),
85                Value::Number(s, _) => Some(s),
86                _ => None,
87            },
88            Expr::Identifier(ident) => Some(&ident.value),
89            _ => None,
90        }
91    }
92
93    pub fn as_string(&self) -> Option<&str> {
94        Self::expr_as_string(&self.0)
95    }
96
97    pub fn as_list(&self) -> Option<Vec<&str>> {
98        let expr = &self.0;
99        match expr {
100            Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
101            Expr::Array(array) => array
102                .elem
103                .iter()
104                .map(Self::expr_as_string)
105                .collect::<Option<Vec<_>>>(),
106            _ => None,
107        }
108    }
109}
110
111impl From<String> for OptionValue {
112    fn from(value: String) -> Self {
113        Self(Expr::Identifier(Ident::new(value)))
114    }
115}
116
117impl From<&str> for OptionValue {
118    fn from(value: &str) -> Self {
119        Self(Expr::Identifier(Ident::new(value)))
120    }
121}
122
123impl From<Vec<&str>> for OptionValue {
124    fn from(value: Vec<&str>) -> Self {
125        Self(Expr::Array(Array {
126            elem: value
127                .into_iter()
128                .map(|x| Expr::Identifier(Ident::new(x)))
129                .collect(),
130            named: false,
131        }))
132    }
133}
134
135impl Display for OptionValue {
136    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
137        if let Some(s) = self.as_string() {
138            write!(f, "'{s}'")
139        } else if let Some(s) = self.as_list() {
140            write!(
141                f,
142                "[{}]",
143                s.into_iter().map(|x| format!("'{x}'")).join(", ")
144            )
145        } else {
146            write!(f, "'{}'", self.0)
147        }
148    }
149}
150
151pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
152    let SqlOption::KeyValue { key, value } = option else {
153        return InvalidSqlSnafu {
154            msg: "Expecting a key-value pair in the option",
155        }
156        .fail();
157    };
158    let v = OptionValue::try_new(value)?;
159    let k = key.value.to_lowercase();
160    Ok((k, v))
161}
162
163/// Walk through a [Query] and extract all the tables referenced in it.
164pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
165    let mut names = HashSet::new();
166
167    match query {
168        SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
169        SqlOrTql::Tql(_tql, _) => {
170            // since tql have sliding time window, so we don't need to extract tables from it
171            // (because we are going to eval it fully anyway)
172        }
173    }
174
175    names.into_iter()
176}
177
178/// translate the start location to the index in the sql string
179pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
180    let mut index = 0;
181    for (lno, line) in sql.lines().enumerate() {
182        if lno + 1 == location.line as usize {
183            index += location.column as usize;
184            break;
185        } else {
186            index += line.len() + 1; // +1 for the newline
187        }
188    }
189    // -1 because the index is 0-based
190    // and the location is 1-based
191    index - 1
192}
193
194/// Helper function for [extract_tables_from_query].
195///
196/// Handle [SetExpr].
197fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
198    match set_expr {
199        SetExpr::Select(select) => {
200            for from in &select.from {
201                table_factor_to_object_name(&from.relation, names);
202                for join in &from.joins {
203                    table_factor_to_object_name(&join.relation, names);
204                }
205            }
206        }
207        SetExpr::Query(query) => {
208            extract_tables_from_set_expr(&query.body, names);
209        }
210        SetExpr::SetOperation { left, right, .. } => {
211            extract_tables_from_set_expr(left, names);
212            extract_tables_from_set_expr(right, names);
213        }
214        _ => {}
215    };
216}
217
218/// Helper function for [extract_tables_from_query].
219///
220/// Handle [TableFactor].
221fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
222    if let TableFactor::Table { name, .. } = table_factor {
223        names.insert(name.to_owned());
224    }
225}
226
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParserContext;

    /// For every token of every test query, slicing the SQL text between the
    /// indices computed from the token's span must give back the token's text.
    #[test]
    fn test_location_to_index() {
        let testcases = vec![
            "SELECT * FROM t WHERE a = 1",
            // start or end with newline
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut ctx = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            loop {
                let token = ctx.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let start = location_to_index(sql, &token.span.start);
                let end = location_to_index(sql, &token.span.end);
                let subslice = &sql[start..end];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }
}