sql/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use promql_parser::label::{METRIC_NAME, MatchOp};
20use promql_parser::parser::{
21    AggregateExpr as PromAggregateExpr, BinaryExpr as PromBinaryExpr, Call as PromCall,
22    Expr as PromExpr, MatrixSelector as PromMatrixSelector, ParenExpr as PromParenExpr,
23    SubqueryExpr as PromSubqueryExpr, UnaryExpr as PromUnaryExpr,
24    VectorSelector as PromVectorSelector,
25};
26use serde::Serialize;
27use snafu::ensure;
28use sqlparser::ast::{
29    Array, Expr, Ident, ObjectName, ObjectNamePart, SetExpr, SqlOption, StructField, TableFactor,
30    Value, ValueWithSpan,
31};
32use sqlparser_derive::{Visit, VisitMut};
33
34use crate::ast::ObjectNamePartExt;
35use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
36use crate::parser::ParserContext;
37use crate::parsers::with_tql_parser::CteContent;
38use crate::statements::create::SqlOrTql;
39use crate::statements::query::Query;
40use crate::statements::tql::Tql;
41
/// Label name of a PromQL matcher whose value, when matched with `=`, is used
/// as the schema part of the table name extracted from a vector selector.
const SCHEMA_MATCHER: &str = "__schema__";
/// Second label name that is handled exactly like [SCHEMA_MATCHER] when the
/// schema part of an extracted table name is looked up.
const DATABASE_MATCHER: &str = "__database__";
44
45/// Format an [ObjectName] without any quote of its idents.
46pub fn format_raw_object_name(name: &ObjectName) -> String {
47    struct Inner<'a> {
48        name: &'a ObjectName,
49    }
50
51    impl Display for Inner<'_> {
52        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53            let mut delim = "";
54            for ident in self.name.0.iter() {
55                write!(f, "{delim}")?;
56                delim = ".";
57                write!(f, "{}", ident.to_string_unquoted())?;
58            }
59            Ok(())
60        }
61    }
62
63    format!("{}", Inner { name })
64}
65
/// The value of an option, stored as the underlying [Expr].
///
/// Use [OptionValue::try_new] to build one from an arbitrary expression; it
/// only accepts literal values, identifiers, arrays and structs.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
68
69impl OptionValue {
70    pub(crate) fn try_new(expr: Expr) -> Result<Self> {
71        ensure!(
72            matches!(
73                expr,
74                Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
75            ),
76            InvalidExprAsOptionValueSnafu {
77                error: format!("{expr} not accepted")
78            }
79        );
80        Ok(Self(expr))
81    }
82
83    fn expr_as_string(expr: &Expr) -> Option<&str> {
84        match expr {
85            Expr::Value(ValueWithSpan { value, .. }) => match value {
86                Value::SingleQuotedString(s)
87                | Value::DoubleQuotedString(s)
88                | Value::TripleSingleQuotedString(s)
89                | Value::TripleDoubleQuotedString(s)
90                | Value::SingleQuotedByteStringLiteral(s)
91                | Value::DoubleQuotedByteStringLiteral(s)
92                | Value::TripleSingleQuotedByteStringLiteral(s)
93                | Value::TripleDoubleQuotedByteStringLiteral(s)
94                | Value::SingleQuotedRawStringLiteral(s)
95                | Value::DoubleQuotedRawStringLiteral(s)
96                | Value::TripleSingleQuotedRawStringLiteral(s)
97                | Value::TripleDoubleQuotedRawStringLiteral(s)
98                | Value::EscapedStringLiteral(s)
99                | Value::UnicodeStringLiteral(s)
100                | Value::NationalStringLiteral(s)
101                | Value::HexStringLiteral(s) => Some(s),
102                Value::DollarQuotedString(s) => Some(&s.value),
103                Value::Number(s, _) => Some(s),
104                Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
105                _ => None,
106            },
107            Expr::Identifier(ident) => Some(&ident.value),
108            _ => None,
109        }
110    }
111
112    /// Convert the option value to a string.
113    ///
114    /// Notes: Not all values can be converted to a string, refer to [Self::expr_as_string] for more details.
115    pub fn as_string(&self) -> Option<&str> {
116        Self::expr_as_string(&self.0)
117    }
118
119    pub fn as_list(&self) -> Option<Vec<&str>> {
120        let expr = &self.0;
121        match expr {
122            Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
123            Expr::Array(array) => array
124                .elem
125                .iter()
126                .map(Self::expr_as_string)
127                .collect::<Option<Vec<_>>>(),
128            _ => None,
129        }
130    }
131
132    pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
133        match &self.0 {
134            Expr::Struct { fields, .. } => Some(fields),
135            _ => None,
136        }
137    }
138}
139
140impl From<String> for OptionValue {
141    fn from(value: String) -> Self {
142        Self(Expr::Identifier(Ident::new(value)))
143    }
144}
145
146impl From<&str> for OptionValue {
147    fn from(value: &str) -> Self {
148        Self(Expr::Identifier(Ident::new(value)))
149    }
150}
151
152impl From<Vec<&str>> for OptionValue {
153    fn from(value: Vec<&str>) -> Self {
154        Self(Expr::Array(Array {
155            elem: value
156                .into_iter()
157                .map(|x| Expr::Identifier(Ident::new(x)))
158                .collect(),
159            named: false,
160        }))
161    }
162}
163
164impl Display for OptionValue {
165    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
166        if let Some(s) = self.as_string() {
167            write!(f, "'{s}'")
168        } else if let Some(s) = self.as_list() {
169            write!(
170                f,
171                "[{}]",
172                s.into_iter().map(|x| format!("'{x}'")).join(", ")
173            )
174        } else {
175            write!(f, "'{}'", self.0)
176        }
177    }
178}
179
180pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
181    let SqlOption::KeyValue { key, value } = option else {
182        return InvalidSqlSnafu {
183            msg: "Expecting a key-value pair in the option",
184        }
185        .fail();
186    };
187    let v = OptionValue::try_new(value)?;
188    let k = key.value.to_lowercase();
189    Ok((k, v))
190}
191
192/// Walk through a [Query] and extract all the tables referenced in it.
193pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
194    let mut names = HashSet::new();
195
196    match query {
197        SqlOrTql::Sql(query, _) => {
198            extract_tables_from_set_expr(&query.inner.body, &mut names);
199            extract_tables_from_hybrid_cte_query(query, &mut names);
200        }
201        SqlOrTql::Tql(tql, _) => extract_tables_from_tql(tql, &mut names),
202    }
203
204    names.into_iter()
205}
206
207fn extract_tables_from_hybrid_cte_query(query: &Query, sql_names: &mut HashSet<ObjectName>) {
208    let mut tql_names = HashSet::new();
209    let mut cte_names: HashSet<String> = HashSet::new();
210    if let Some(hybrid_cte) = &query.hybrid_cte {
211        for cte in &hybrid_cte.cte_tables {
212            cte_names.insert(ParserContext::canonicalize_identifier(cte.name.clone()).value);
213            if let CteContent::Tql(tql) = &cte.content {
214                extract_tables_from_tql(tql, &mut tql_names);
215            }
216        }
217    }
218
219    if let Some(with) = &query.inner.with {
220        for cte in &with.cte_tables {
221            cte_names.insert(ParserContext::canonicalize_identifier(cte.alias.name.clone()).value);
222        }
223    }
224
225    remove_cte_names(sql_names, &cte_names);
226
227    sql_names.extend(tql_names);
228}
229
230fn remove_cte_names(names: &mut HashSet<ObjectName>, cte_names: &HashSet<String>) {
231    if cte_names.is_empty() {
232        return;
233    }
234
235    names.retain(|name| {
236        if name.0.len() != 1 {
237            return true;
238        }
239        let Some(ident) = name.0[0].as_ident() else {
240            return true;
241        };
242
243        let canonical = ParserContext::canonicalize_identifier(ident.clone()).value;
244        !cte_names.contains(&canonical)
245    });
246}
247
248fn extract_tables_from_tql(tql: &Tql, names: &mut HashSet<ObjectName>) {
249    let promql = match tql {
250        Tql::Eval(eval) => &eval.query,
251        Tql::Explain(explain) => &explain.query,
252        Tql::Analyze(analyze) => &analyze.query,
253    };
254
255    if let Ok(expr) = promql_parser::parser::parse(promql) {
256        extract_tables_from_prom_expr(&expr, names);
257    }
258}
259
260fn extract_tables_from_prom_expr(expr: &PromExpr, names: &mut HashSet<ObjectName>) {
261    match expr {
262        PromExpr::Aggregate(PromAggregateExpr { expr, .. }) => {
263            extract_tables_from_prom_expr(expr, names);
264        }
265        PromExpr::Unary(PromUnaryExpr { expr, .. }) => {
266            extract_tables_from_prom_expr(expr, names);
267        }
268        PromExpr::Binary(PromBinaryExpr { lhs, rhs, .. }) => {
269            extract_tables_from_prom_expr(lhs, names);
270            extract_tables_from_prom_expr(rhs, names);
271        }
272        PromExpr::Paren(PromParenExpr { expr }) => {
273            extract_tables_from_prom_expr(expr, names);
274        }
275        PromExpr::Subquery(PromSubqueryExpr { expr, .. }) => {
276            extract_tables_from_prom_expr(expr, names);
277        }
278        PromExpr::VectorSelector(selector) => {
279            extract_metric_name_from_vector_selector(selector, names);
280        }
281        PromExpr::MatrixSelector(PromMatrixSelector { vs, .. }) => {
282            extract_metric_name_from_vector_selector(vs, names);
283        }
284        PromExpr::Call(PromCall { args, .. }) => {
285            for arg in &args.args {
286                extract_tables_from_prom_expr(arg, names);
287            }
288        }
289        PromExpr::NumberLiteral(_) | PromExpr::StringLiteral(_) | PromExpr::Extension(_) => {}
290    }
291}
292
293fn extract_metric_name_from_vector_selector(
294    selector: &PromVectorSelector,
295    names: &mut HashSet<ObjectName>,
296) {
297    let metric_name = selector.name.clone().or_else(|| {
298        let mut metric_name_matchers = selector.matchers.find_matchers(METRIC_NAME);
299        if metric_name_matchers.len() == 1 && metric_name_matchers[0].op == MatchOp::Equal {
300            metric_name_matchers.pop().map(|matcher| matcher.value)
301        } else {
302            None
303        }
304    });
305    let Some(metric_name) = metric_name else {
306        return;
307    };
308
309    let schema_matcher = selector.matchers.matchers.iter().rev().find(|matcher| {
310        matcher.op == MatchOp::Equal
311            && (matcher.name == SCHEMA_MATCHER || matcher.name == DATABASE_MATCHER)
312    });
313
314    if let Some(schema) = schema_matcher {
315        names.insert(ObjectName(vec![
316            ObjectNamePart::Identifier(Ident::new(&schema.value)),
317            ObjectNamePart::Identifier(Ident::new(metric_name)),
318        ]));
319    } else {
320        names.insert(ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
321            metric_name,
322        ))]));
323    }
324}
325
326/// translate the start location to the index in the sql string
327pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
328    let mut index = 0;
329    for (lno, line) in sql.lines().enumerate() {
330        if lno + 1 == location.line as usize {
331            index += location.column as usize;
332            break;
333        } else {
334            index += line.len() + 1; // +1 for the newline
335        }
336    }
337    // -1 because the index is 0-based
338    // and the location is 1-based
339    index - 1
340}
341
342/// Helper function for [extract_tables_from_query].
343///
344/// Handle [SetExpr].
345fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
346    match set_expr {
347        SetExpr::Select(select) => {
348            for from in &select.from {
349                table_factor_to_object_name(&from.relation, names);
350                for join in &from.joins {
351                    table_factor_to_object_name(&join.relation, names);
352                }
353            }
354        }
355        SetExpr::Query(query) => {
356            extract_tables_from_set_expr(&query.body, names);
357        }
358        SetExpr::SetOperation { left, right, .. } => {
359            extract_tables_from_set_expr(left, names);
360            extract_tables_from_set_expr(right, names);
361        }
362        _ => {}
363    };
364}
365
366/// Helper function for [extract_tables_from_query].
367///
368/// Handle [TableFactor].
369fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
370    if let TableFactor::Table { name, .. } = table_factor {
371        names.insert(name.to_owned());
372    }
373}
374
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql`, takes the last statement (which must be a `CREATE FLOW`)
    /// and returns the sorted, unquoted names of the tables its query
    /// references.
    fn sorted_flow_source_tables(sql: &str) -> Vec<String> {
        let mut statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = statements.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        tables
    }

    /// Every token's span, converted to byte indices, must slice the token's
    /// own text out of the SQL string.
    #[test]
    fn test_location_to_index() {
        let testcases = [
            "SELECT * FROM t WHERE a = 1",
            // start or end with newline
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut context = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            loop {
                let token = context.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let start = location_to_index(sql, &token.span.start);
                let end = location_to_index(sql, &token.span.end);
                assert_eq!(token.to_string(), &sql[start..end]);
            }
        }
    }

    /// The metric is found both as a bare selector name and as a `__name__`
    /// equality matcher.
    #[test]
    fn test_extract_tables_from_tql_query() {
        let testcases = [
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests);"#,
                vec!["http_requests".to_string()],
            ),
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", {__name__="http_requests"});"#,
                vec!["http_requests".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            assert_eq!(expected_tables, sorted_flow_source_tables(sql));
        }
    }

    /// A `__schema__` equality matcher qualifies the extracted table name.
    #[test]
    fn test_extract_tables_from_tql_query_with_schema_matcher() {
        let sql = r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests{__schema__="greptime_private"});"#;

        let tables = sorted_flow_source_tables(sql);
        assert_eq!(vec!["greptime_private.http_requests".to_string()], tables);
    }
}
479}