sql/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use serde::Serialize;
20use snafu::ensure;
21use sqlparser::ast::{
22    Array, Expr, Ident, ObjectName, SetExpr, SqlOption, StructField, TableFactor, Value,
23    ValueWithSpan,
24};
25use sqlparser_derive::{Visit, VisitMut};
26
27use crate::ast::ObjectNamePartExt;
28use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
29use crate::statements::create::SqlOrTql;
30
31/// Format an [ObjectName] without any quote of its idents.
32pub fn format_raw_object_name(name: &ObjectName) -> String {
33    struct Inner<'a> {
34        name: &'a ObjectName,
35    }
36
37    impl Display for Inner<'_> {
38        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39            let mut delim = "";
40            for ident in self.name.0.iter() {
41                write!(f, "{delim}")?;
42                delim = ".";
43                write!(f, "{}", ident.to_string_unquoted())?;
44            }
45            Ok(())
46        }
47    }
48
49    format!("{}", Inner { name })
50}
51
/// The value side of a SQL option, stored as the [Expr] it wraps.
///
/// The inner expression is private to this module; callers read it through the
/// accessor methods defined on this type.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
54
55impl OptionValue {
56    pub(crate) fn try_new(expr: Expr) -> Result<Self> {
57        ensure!(
58            matches!(
59                expr,
60                Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
61            ),
62            InvalidExprAsOptionValueSnafu {
63                error: format!("{expr} not accepted")
64            }
65        );
66        Ok(Self(expr))
67    }
68
69    fn expr_as_string(expr: &Expr) -> Option<&str> {
70        match expr {
71            Expr::Value(ValueWithSpan { value, .. }) => match value {
72                Value::SingleQuotedString(s)
73                | Value::DoubleQuotedString(s)
74                | Value::TripleSingleQuotedString(s)
75                | Value::TripleDoubleQuotedString(s)
76                | Value::SingleQuotedByteStringLiteral(s)
77                | Value::DoubleQuotedByteStringLiteral(s)
78                | Value::TripleSingleQuotedByteStringLiteral(s)
79                | Value::TripleDoubleQuotedByteStringLiteral(s)
80                | Value::SingleQuotedRawStringLiteral(s)
81                | Value::DoubleQuotedRawStringLiteral(s)
82                | Value::TripleSingleQuotedRawStringLiteral(s)
83                | Value::TripleDoubleQuotedRawStringLiteral(s)
84                | Value::EscapedStringLiteral(s)
85                | Value::UnicodeStringLiteral(s)
86                | Value::NationalStringLiteral(s)
87                | Value::HexStringLiteral(s) => Some(s),
88                Value::DollarQuotedString(s) => Some(&s.value),
89                Value::Number(s, _) => Some(s),
90                Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
91                _ => None,
92            },
93            Expr::Identifier(ident) => Some(&ident.value),
94            _ => None,
95        }
96    }
97
98    /// Convert the option value to a string.
99    ///
100    /// Notes: Not all values can be converted to a string, refer to [Self::expr_as_string] for more details.
101    pub fn as_string(&self) -> Option<&str> {
102        Self::expr_as_string(&self.0)
103    }
104
105    pub fn as_list(&self) -> Option<Vec<&str>> {
106        let expr = &self.0;
107        match expr {
108            Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
109            Expr::Array(array) => array
110                .elem
111                .iter()
112                .map(Self::expr_as_string)
113                .collect::<Option<Vec<_>>>(),
114            _ => None,
115        }
116    }
117
118    pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
119        match &self.0 {
120            Expr::Struct { fields, .. } => Some(fields),
121            _ => None,
122        }
123    }
124}
125
126impl From<String> for OptionValue {
127    fn from(value: String) -> Self {
128        Self(Expr::Identifier(Ident::new(value)))
129    }
130}
131
132impl From<&str> for OptionValue {
133    fn from(value: &str) -> Self {
134        Self(Expr::Identifier(Ident::new(value)))
135    }
136}
137
138impl From<Vec<&str>> for OptionValue {
139    fn from(value: Vec<&str>) -> Self {
140        Self(Expr::Array(Array {
141            elem: value
142                .into_iter()
143                .map(|x| Expr::Identifier(Ident::new(x)))
144                .collect(),
145            named: false,
146        }))
147    }
148}
149
150impl Display for OptionValue {
151    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
152        if let Some(s) = self.as_string() {
153            write!(f, "'{s}'")
154        } else if let Some(s) = self.as_list() {
155            write!(
156                f,
157                "[{}]",
158                s.into_iter().map(|x| format!("'{x}'")).join(", ")
159            )
160        } else {
161            write!(f, "'{}'", self.0)
162        }
163    }
164}
165
166pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
167    let SqlOption::KeyValue { key, value } = option else {
168        return InvalidSqlSnafu {
169            msg: "Expecting a key-value pair in the option",
170        }
171        .fail();
172    };
173    let v = OptionValue::try_new(value)?;
174    let k = key.value.to_lowercase();
175    Ok((k, v))
176}
177
178/// Walk through a [Query] and extract all the tables referenced in it.
179pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
180    let mut names = HashSet::new();
181
182    match query {
183        SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
184        SqlOrTql::Tql(_tql, _) => {
185            // since tql have sliding time window, so we don't need to extract tables from it
186            // (because we are going to eval it fully anyway)
187        }
188    }
189
190    names.into_iter()
191}
192
193/// translate the start location to the index in the sql string
194pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
195    let mut index = 0;
196    for (lno, line) in sql.lines().enumerate() {
197        if lno + 1 == location.line as usize {
198            index += location.column as usize;
199            break;
200        } else {
201            index += line.len() + 1; // +1 for the newline
202        }
203    }
204    // -1 because the index is 0-based
205    // and the location is 1-based
206    index - 1
207}
208
209/// Helper function for [extract_tables_from_query].
210///
211/// Handle [SetExpr].
212fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
213    match set_expr {
214        SetExpr::Select(select) => {
215            for from in &select.from {
216                table_factor_to_object_name(&from.relation, names);
217                for join in &from.joins {
218                    table_factor_to_object_name(&join.relation, names);
219                }
220            }
221        }
222        SetExpr::Query(query) => {
223            extract_tables_from_set_expr(&query.body, names);
224        }
225        SetExpr::SetOperation { left, right, .. } => {
226            extract_tables_from_set_expr(left, names);
227            extract_tables_from_set_expr(right, names);
228        }
229        _ => {}
230    };
231}
232
233/// Helper function for [extract_tables_from_query].
234///
235/// Handle [TableFactor].
236fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
237    if let TableFactor::Table { name, .. } = table_factor {
238        names.insert(name.to_owned());
239    }
240}
241
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParserContext;

    /// Checks that, for every token of each test query, slicing the SQL between the indexes
    /// computed from the token's span start and end gives back the token's text.
    #[test]
    fn test_location_to_index() {
        // One single-line query plus multi-line variants that begin and/or end with a newline.
        let testcases = vec![
            "SELECT * FROM t WHERE a = 1",
            // start or end with newline
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            // Consume tokens one by one until the end of input.
            loop {
                let token = parser.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let span = token.span;
                // The half-open byte range [start, end) must cover exactly this token.
                let subslice =
                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }
}