sql/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use promql_parser::label::{METRIC_NAME, MatchOp};
20use promql_parser::parser::{
21    AggregateExpr as PromAggregateExpr, BinaryExpr as PromBinaryExpr, Call as PromCall,
22    Expr as PromExpr, MatrixSelector as PromMatrixSelector, ParenExpr as PromParenExpr,
23    SubqueryExpr as PromSubqueryExpr, UnaryExpr as PromUnaryExpr,
24    VectorSelector as PromVectorSelector,
25};
26use serde::Serialize;
27use snafu::ensure;
28use sqlparser::ast::{
29    Array, Expr, Ident, ObjectName, ObjectNamePart, SetExpr, SqlOption, StructField, TableFactor,
30    Value, ValueWithSpan,
31};
32use sqlparser_derive::{Visit, VisitMut};
33
34use crate::ast::ObjectNamePartExt;
35use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
36use crate::statements::create::SqlOrTql;
37use crate::statements::tql::Tql;
38
/// Label name of a PromQL matcher whose value, when matched with `=`, is used as the
/// leading (schema) part of a table name extracted from a vector selector.
const SCHEMA_MATCHER: &str = "__schema__";
/// Alternative label name that is treated exactly like [SCHEMA_MATCHER].
const DATABASE_MATCHER: &str = "__database__";
41
42/// Format an [ObjectName] without any quote of its idents.
43pub fn format_raw_object_name(name: &ObjectName) -> String {
44    struct Inner<'a> {
45        name: &'a ObjectName,
46    }
47
48    impl Display for Inner<'_> {
49        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50            let mut delim = "";
51            for ident in self.name.0.iter() {
52                write!(f, "{delim}")?;
53                delim = ".";
54                write!(f, "{}", ident.to_string_unquoted())?;
55            }
56            Ok(())
57        }
58    }
59
60    format!("{}", Inner { name })
61}
62
/// The value side of a SQL option, stored as the [Expr] it was parsed from.
///
/// [OptionValue::try_new] restricts the wrapped expression to values, identifiers,
/// arrays and structs.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
65
66impl OptionValue {
67    pub(crate) fn try_new(expr: Expr) -> Result<Self> {
68        ensure!(
69            matches!(
70                expr,
71                Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
72            ),
73            InvalidExprAsOptionValueSnafu {
74                error: format!("{expr} not accepted")
75            }
76        );
77        Ok(Self(expr))
78    }
79
80    fn expr_as_string(expr: &Expr) -> Option<&str> {
81        match expr {
82            Expr::Value(ValueWithSpan { value, .. }) => match value {
83                Value::SingleQuotedString(s)
84                | Value::DoubleQuotedString(s)
85                | Value::TripleSingleQuotedString(s)
86                | Value::TripleDoubleQuotedString(s)
87                | Value::SingleQuotedByteStringLiteral(s)
88                | Value::DoubleQuotedByteStringLiteral(s)
89                | Value::TripleSingleQuotedByteStringLiteral(s)
90                | Value::TripleDoubleQuotedByteStringLiteral(s)
91                | Value::SingleQuotedRawStringLiteral(s)
92                | Value::DoubleQuotedRawStringLiteral(s)
93                | Value::TripleSingleQuotedRawStringLiteral(s)
94                | Value::TripleDoubleQuotedRawStringLiteral(s)
95                | Value::EscapedStringLiteral(s)
96                | Value::UnicodeStringLiteral(s)
97                | Value::NationalStringLiteral(s)
98                | Value::HexStringLiteral(s) => Some(s),
99                Value::DollarQuotedString(s) => Some(&s.value),
100                Value::Number(s, _) => Some(s),
101                Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
102                _ => None,
103            },
104            Expr::Identifier(ident) => Some(&ident.value),
105            _ => None,
106        }
107    }
108
109    /// Convert the option value to a string.
110    ///
111    /// Notes: Not all values can be converted to a string, refer to [Self::expr_as_string] for more details.
112    pub fn as_string(&self) -> Option<&str> {
113        Self::expr_as_string(&self.0)
114    }
115
116    pub fn as_list(&self) -> Option<Vec<&str>> {
117        let expr = &self.0;
118        match expr {
119            Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
120            Expr::Array(array) => array
121                .elem
122                .iter()
123                .map(Self::expr_as_string)
124                .collect::<Option<Vec<_>>>(),
125            _ => None,
126        }
127    }
128
129    pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
130        match &self.0 {
131            Expr::Struct { fields, .. } => Some(fields),
132            _ => None,
133        }
134    }
135}
136
137impl From<String> for OptionValue {
138    fn from(value: String) -> Self {
139        Self(Expr::Identifier(Ident::new(value)))
140    }
141}
142
143impl From<&str> for OptionValue {
144    fn from(value: &str) -> Self {
145        Self(Expr::Identifier(Ident::new(value)))
146    }
147}
148
149impl From<Vec<&str>> for OptionValue {
150    fn from(value: Vec<&str>) -> Self {
151        Self(Expr::Array(Array {
152            elem: value
153                .into_iter()
154                .map(|x| Expr::Identifier(Ident::new(x)))
155                .collect(),
156            named: false,
157        }))
158    }
159}
160
161impl Display for OptionValue {
162    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
163        if let Some(s) = self.as_string() {
164            write!(f, "'{s}'")
165        } else if let Some(s) = self.as_list() {
166            write!(
167                f,
168                "[{}]",
169                s.into_iter().map(|x| format!("'{x}'")).join(", ")
170            )
171        } else {
172            write!(f, "'{}'", self.0)
173        }
174    }
175}
176
177pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
178    let SqlOption::KeyValue { key, value } = option else {
179        return InvalidSqlSnafu {
180            msg: "Expecting a key-value pair in the option",
181        }
182        .fail();
183    };
184    let v = OptionValue::try_new(value)?;
185    let k = key.value.to_lowercase();
186    Ok((k, v))
187}
188
189/// Walk through a [Query] and extract all the tables referenced in it.
190pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
191    let mut names = HashSet::new();
192
193    match query {
194        SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
195        SqlOrTql::Tql(tql, _) => extract_tables_from_tql(tql, &mut names),
196    }
197
198    names.into_iter()
199}
200
201fn extract_tables_from_tql(tql: &Tql, names: &mut HashSet<ObjectName>) {
202    let promql = match tql {
203        Tql::Eval(eval) => &eval.query,
204        Tql::Explain(explain) => &explain.query,
205        Tql::Analyze(analyze) => &analyze.query,
206    };
207
208    if let Ok(expr) = promql_parser::parser::parse(promql) {
209        extract_tables_from_prom_expr(&expr, names);
210    }
211}
212
213fn extract_tables_from_prom_expr(expr: &PromExpr, names: &mut HashSet<ObjectName>) {
214    match expr {
215        PromExpr::Aggregate(PromAggregateExpr { expr, .. }) => {
216            extract_tables_from_prom_expr(expr, names);
217        }
218        PromExpr::Unary(PromUnaryExpr { expr, .. }) => {
219            extract_tables_from_prom_expr(expr, names);
220        }
221        PromExpr::Binary(PromBinaryExpr { lhs, rhs, .. }) => {
222            extract_tables_from_prom_expr(lhs, names);
223            extract_tables_from_prom_expr(rhs, names);
224        }
225        PromExpr::Paren(PromParenExpr { expr }) => {
226            extract_tables_from_prom_expr(expr, names);
227        }
228        PromExpr::Subquery(PromSubqueryExpr { expr, .. }) => {
229            extract_tables_from_prom_expr(expr, names);
230        }
231        PromExpr::VectorSelector(selector) => {
232            extract_metric_name_from_vector_selector(selector, names);
233        }
234        PromExpr::MatrixSelector(PromMatrixSelector { vs, .. }) => {
235            extract_metric_name_from_vector_selector(vs, names);
236        }
237        PromExpr::Call(PromCall { args, .. }) => {
238            for arg in &args.args {
239                extract_tables_from_prom_expr(arg, names);
240            }
241        }
242        PromExpr::NumberLiteral(_) | PromExpr::StringLiteral(_) | PromExpr::Extension(_) => {}
243    }
244}
245
246fn extract_metric_name_from_vector_selector(
247    selector: &PromVectorSelector,
248    names: &mut HashSet<ObjectName>,
249) {
250    let metric_name = selector.name.clone().or_else(|| {
251        let mut metric_name_matchers = selector.matchers.find_matchers(METRIC_NAME);
252        if metric_name_matchers.len() == 1 && metric_name_matchers[0].op == MatchOp::Equal {
253            metric_name_matchers.pop().map(|matcher| matcher.value)
254        } else {
255            None
256        }
257    });
258    let Some(metric_name) = metric_name else {
259        return;
260    };
261
262    let schema_matcher = selector.matchers.matchers.iter().rev().find(|matcher| {
263        matcher.op == MatchOp::Equal
264            && (matcher.name == SCHEMA_MATCHER || matcher.name == DATABASE_MATCHER)
265    });
266
267    if let Some(schema) = schema_matcher {
268        names.insert(ObjectName(vec![
269            ObjectNamePart::Identifier(Ident::new(&schema.value)),
270            ObjectNamePart::Identifier(Ident::new(metric_name)),
271        ]));
272    } else {
273        names.insert(ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
274            metric_name,
275        ))]));
276    }
277}
278
279/// translate the start location to the index in the sql string
280pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
281    let mut index = 0;
282    for (lno, line) in sql.lines().enumerate() {
283        if lno + 1 == location.line as usize {
284            index += location.column as usize;
285            break;
286        } else {
287            index += line.len() + 1; // +1 for the newline
288        }
289    }
290    // -1 because the index is 0-based
291    // and the location is 1-based
292    index - 1
293}
294
295/// Helper function for [extract_tables_from_query].
296///
297/// Handle [SetExpr].
298fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
299    match set_expr {
300        SetExpr::Select(select) => {
301            for from in &select.from {
302                table_factor_to_object_name(&from.relation, names);
303                for join in &from.joins {
304                    table_factor_to_object_name(&join.relation, names);
305                }
306            }
307        }
308        SetExpr::Query(query) => {
309            extract_tables_from_set_expr(&query.body, names);
310        }
311        SetExpr::SetOperation { left, right, .. } => {
312            extract_tables_from_set_expr(left, names);
313            extract_tables_from_set_expr(right, names);
314        }
315        _ => {}
316    };
317}
318
319/// Helper function for [extract_tables_from_query].
320///
321/// Handle [TableFactor].
322fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
323    if let TableFactor::Table { name, .. } = table_factor {
324        names.insert(name.to_owned());
325    }
326}
327
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Slicing the SQL text between the indices computed for a token's span start and
    /// end must give back the token's own text.
    #[test]
    fn test_location_to_index() {
        let testcases = vec![
            "SELECT * FROM t WHERE a = 1",
            // start or end with newline
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            // Check every token up to (but excluding) EOF.
            loop {
                let token = parser.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let span = token.span;
                let subslice =
                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }

    /// The metric of a TQL flow query is extracted both from a bare selector name and
    /// from an explicit `__name__` equality matcher.
    #[test]
    fn test_extract_tables_from_tql_query() {
        let testcases = vec![
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests);"#,
                vec!["http_requests".to_string()],
            ),
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", {__name__="http_requests"});"#,
                vec!["http_requests".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let mut stmts = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
                unreachable!()
            };

            // Sort because the extracted names come back in no particular order.
            let mut tables = extract_tables_from_query(&create_flow.query)
                .map(|table| format_raw_object_name(&table))
                .collect_vec();
            tables.sort();
            assert_eq!(expected_tables, tables);
        }
    }

    /// A `__schema__` equality matcher on the selector becomes the schema part of the
    /// extracted table name.
    #[test]
    fn test_extract_tables_from_tql_query_with_schema_matcher() {
        let sql = r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests{__schema__="greptime_private"});"#;
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        assert_eq!(vec!["greptime_private.http_requests".to_string()], tables);
    }
}