sql/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use sqlparser::ast::{Expr, ObjectName, SetExpr, SqlOption, TableFactor, Value};
19
20use crate::error::{InvalidSqlSnafu, InvalidTableOptionValueSnafu, Result};
21use crate::statements::create::SqlOrTql;
22
23/// Format an [ObjectName] without any quote of its idents.
24pub fn format_raw_object_name(name: &ObjectName) -> String {
25    struct Inner<'a> {
26        name: &'a ObjectName,
27    }
28
29    impl Display for Inner<'_> {
30        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
31            let mut delim = "";
32            for ident in self.name.0.iter() {
33                write!(f, "{delim}")?;
34                delim = ".";
35                write!(f, "{}", ident.value)?;
36            }
37            Ok(())
38        }
39    }
40
41    format!("{}", Inner { name })
42}
43
44pub fn parse_option_string(option: SqlOption) -> Result<(String, String)> {
45    let SqlOption::KeyValue { key, value } = option else {
46        return InvalidSqlSnafu {
47            msg: "Expecting a key-value pair in the option",
48        }
49        .fail();
50    };
51    let v = match value {
52        Expr::Value(Value::SingleQuotedString(v)) | Expr::Value(Value::DoubleQuotedString(v)) => v,
53        Expr::Identifier(v) => v.value,
54        Expr::Value(Value::Number(v, _)) => v.to_string(),
55        value => return InvalidTableOptionValueSnafu { key, value }.fail(),
56    };
57    let k = key.value.to_lowercase();
58    Ok((k, v))
59}
60
61/// Walk through a [Query] and extract all the tables referenced in it.
62pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
63    let mut names = HashSet::new();
64
65    match query {
66        SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
67        SqlOrTql::Tql(_tql, _) => {
68            // since tql have sliding time window, so we don't need to extract tables from it
69            // (because we are going to eval it fully anyway)
70        }
71    }
72
73    names.into_iter()
74}
75
76/// translate the start location to the index in the sql string
77pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
78    let mut index = 0;
79    for (lno, line) in sql.lines().enumerate() {
80        if lno + 1 == location.line as usize {
81            index += location.column as usize;
82            break;
83        } else {
84            index += line.len() + 1; // +1 for the newline
85        }
86    }
87    // -1 because the index is 0-based
88    // and the location is 1-based
89    index - 1
90}
91
92/// Helper function for [extract_tables_from_query].
93///
94/// Handle [SetExpr].
95fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
96    match set_expr {
97        SetExpr::Select(select) => {
98            for from in &select.from {
99                table_factor_to_object_name(&from.relation, names);
100                for join in &from.joins {
101                    table_factor_to_object_name(&join.relation, names);
102                }
103            }
104        }
105        SetExpr::Query(query) => {
106            extract_tables_from_set_expr(&query.body, names);
107        }
108        SetExpr::SetOperation { left, right, .. } => {
109            extract_tables_from_set_expr(left, names);
110            extract_tables_from_set_expr(right, names);
111        }
112        SetExpr::Values(_) | SetExpr::Insert(_) | SetExpr::Update(_) | SetExpr::Table(_) => {}
113    };
114}
115
116/// Helper function for [extract_tables_from_query].
117///
118/// Handle [TableFactor].
119fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
120    if let TableFactor::Table { name, .. } = table_factor {
121        names.insert(name.to_owned());
122    }
123}
124
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::ParserContext;

    /// For every non-EOF token, mapping its span through [location_to_index]
    /// must slice out exactly the token's own text.
    #[test]
    fn test_location_to_index() {
        const CASES: [&str; 4] = [
            "SELECT * FROM t WHERE a = 1",
            // leading and trailing newline
            "\nSELECT *\nFROM\nt\nWHERE a = \n1\n",
            // trailing newline only
            "SELECT *\nFROM\nt\nWHERE a = \n1\n",
            // leading newline only
            "\nSELECT *\nFROM\nt\nWHERE a = \n1",
        ];

        for sql in CASES {
            let mut ctx = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            let tokens = std::iter::from_fn(|| {
                let next = ctx.parser.next_token();
                (next != Token::EOF).then_some(next)
            });

            for token in tokens {
                let start = location_to_index(sql, &token.span.start);
                let end = location_to_index(sql, &token.span.end);
                assert_eq!(token.to_string(), &sql[start..end]);
            }
        }
    }
}