sql/parsers/
with_tql_parser.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use serde::Serialize;
18use snafu::ResultExt;
19use sqlparser::ast::helpers::attached_token::AttachedToken;
20use sqlparser::ast::{
21    Cte, Ident, ObjectName, Query as SpQuery, TableAlias, TableAliasColumnDef, With,
22};
23use sqlparser::keywords::Keyword;
24use sqlparser::parser::IsOptional;
25use sqlparser::tokenizer::Token;
26use sqlparser_derive::{Visit, VisitMut};
27
28use crate::dialect::GreptimeDbDialect;
29use crate::error::{self, Result};
30use crate::parser::ParserContext;
31use crate::parsers::tql_parser;
32use crate::statements::query::Query;
33use crate::statements::statement::Statement;
34use crate::statements::tql::Tql;
35use crate::util::location_to_index;
36
37/// Content of a CTE - either SQL or TQL
38#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
39pub enum CteContent {
40    Sql(Box<SpQuery>),
41    Tql(Tql),
42}
43
44/// A hybrid CTE that can contain either SQL or TQL
45#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
46pub struct HybridCte {
47    pub name: Ident,
48    /// Column aliases for the CTE table. Empty if not specified.
49    pub columns: Vec<ObjectName>,
50    pub content: CteContent,
51}
52
53/// Extended WITH clause that supports hybrid SQL/TQL CTEs
54#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
55pub struct HybridCteWith {
56    pub recursive: bool,
57    pub cte_tables: Vec<HybridCte>,
58}
59
60impl fmt::Display for HybridCteWith {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        write!(f, "WITH ")?;
63
64        if self.recursive {
65            write!(f, "RECURSIVE ")?;
66        }
67
68        for (i, cte) in self.cte_tables.iter().enumerate() {
69            if i > 0 {
70                write!(f, ", ")?;
71            }
72            write!(f, "{}", cte.name)?;
73
74            if !cte.columns.is_empty() {
75                write!(f, " (")?;
76                for (j, col) in cte.columns.iter().enumerate() {
77                    if j > 0 {
78                        write!(f, ", ")?;
79                    }
80                    write!(f, "{}", col)?;
81                }
82                write!(f, ")")?;
83            }
84
85            write!(f, " AS (")?;
86            match &cte.content {
87                CteContent::Sql(query) => write!(f, "{}", query)?,
88                CteContent::Tql(tql) => write!(f, "{}", tql)?,
89            }
90            write!(f, ")")?;
91        }
92        Ok(())
93    }
94}
95
96/// Parser implementation for hybrid WITH clauses containing TQL
97impl ParserContext<'_> {
98    /// Parse a WITH clause that may contain TQL CTEs or SQL CTEs.
99    pub(crate) fn parse_with_tql(&mut self) -> Result<Statement> {
100        self.parse_with_tql_with_now(false)
101    }
102
103    pub(crate) fn parse_with_tql_with_now(&mut self, require_now_expr: bool) -> Result<Statement> {
104        // Consume the WITH token
105        self.parser
106            .expect_keyword(Keyword::WITH)
107            .context(error::SyntaxSnafu)?;
108
109        // Check for RECURSIVE keyword
110        let recursive = self.parser.parse_keyword(Keyword::RECURSIVE);
111
112        // Parse the CTE list
113        let mut tql_cte_tables = Vec::new();
114        let mut sql_cte_tables = Vec::new();
115
116        loop {
117            let cte = self.parse_hybrid_cte(require_now_expr)?;
118            match cte.content {
119                CteContent::Sql(body) => sql_cte_tables.push(Cte {
120                    alias: TableAlias {
121                        explicit: false,
122                        name: cte.name,
123                        columns: cte
124                            .columns
125                            .into_iter()
126                            .flat_map(|col| col.0[0].as_ident().cloned())
127                            .map(|name| TableAliasColumnDef {
128                                name,
129                                data_type: None,
130                            })
131                            .collect(),
132                    },
133                    query: body,
134                    from: None,
135                    materialized: None,
136                    closing_paren_token: AttachedToken::empty(),
137                }),
138                CteContent::Tql(_) => tql_cte_tables.push(cte),
139            }
140
141            if !self.parser.consume_token(&Token::Comma) {
142                break;
143            }
144        }
145
146        // Parse the main query
147        let main_query = self.parser.parse_query().context(error::SyntaxSnafu)?;
148
149        // Convert the hybrid CTEs to a standard query with hybrid metadata
150        let hybrid_cte = HybridCteWith {
151            recursive,
152            cte_tables: tql_cte_tables,
153        };
154
155        // Create a Query statement with hybrid CTE metadata
156        let mut query = Query::try_from(*main_query)?;
157        query.hybrid_cte = Some(hybrid_cte);
158        query.inner.with = Some(With {
159            recursive,
160            cte_tables: sql_cte_tables,
161            with_token: AttachedToken::empty(),
162        });
163
164        Ok(Statement::Query(Box::new(query)))
165    }
166
167    /// Parse a single CTE that can be either SQL or TQL
168    fn parse_hybrid_cte(&mut self, require_now_expr: bool) -> Result<HybridCte> {
169        // Parse CTE name
170        let name = self.parser.parse_identifier().context(error::SyntaxSnafu)?;
171        let name = Self::canonicalize_identifier(name);
172
173        // Parse optional column list
174        let columns = self
175            .parser
176            .parse_parenthesized_qualified_column_list(IsOptional::Optional, true)
177            .context(error::SyntaxSnafu)?;
178
179        // Expect AS keyword
180        self.parser
181            .expect_keyword(Keyword::AS)
182            .context(error::SyntaxSnafu)?;
183
184        // Parse the CTE content
185        self.parser
186            .expect_token(&Token::LParen)
187            .context(error::SyntaxSnafu)?;
188
189        let content = self.parse_cte_content(require_now_expr)?;
190
191        self.parser
192            .expect_token(&Token::RParen)
193            .context(error::SyntaxSnafu)?;
194
195        Ok(HybridCte {
196            name,
197            columns,
198            content,
199        })
200    }
201
202    /// Determine if CTE contains TQL or SQL and parse accordingly
203    fn parse_cte_content(&mut self, require_now_expr: bool) -> Result<CteContent> {
204        // Check if the next token is TQL
205        if let Token::Word(w) = &self.parser.peek_token().token
206            && w.keyword == Keyword::NoKeyword
207            && w.quote_style.is_none()
208            && w.value.to_uppercase() == tql_parser::TQL
209        {
210            let tql = self.parse_tql_content_in_cte(require_now_expr)?;
211            return Ok(CteContent::Tql(tql));
212        }
213
214        // Parse as SQL query
215        let sql_query = self.parser.parse_query().context(error::SyntaxSnafu)?;
216        Ok(CteContent::Sql(sql_query))
217    }
218
219    /// Parse TQL content within a CTE by extracting the raw query string.
220    ///
221    /// This method consumes all tokens that belong to the TQL statement and
222    /// stops right **before** the closing `)` of the CTE so that the caller
223    /// can handle it normally.
224    ///
225    /// Only `TQL EVAL` is supported inside CTEs.
226    fn parse_tql_content_in_cte(&mut self, require_now_expr: bool) -> Result<Tql> {
227        // Consume and get the position of the TQL keyword
228        let tql_token = self.parser.next_token();
229        if tql_token.token == Token::EOF {
230            return Err(error::InvalidSqlSnafu {
231                msg: "Unexpected end of input while parsing TQL inside CTE".to_string(),
232            }
233            .build());
234        }
235
236        let start_location = tql_token.span.start;
237
238        // Track parentheses depth to find the end of the CTE
239        let mut paren_depth = 0usize;
240        let end_location;
241
242        loop {
243            let token_with_span = self.parser.peek_token();
244
245            // Guard against unexpected EOF
246            if token_with_span.token == Token::EOF {
247                return Err(error::InvalidSqlSnafu {
248                    msg: "Unexpected end of input while parsing TQL inside CTE".to_string(),
249                }
250                .build());
251            }
252
253            // Stop **before** the closing parenthesis that ends the CTE
254            if token_with_span.token == Token::RParen && paren_depth == 0 {
255                end_location = token_with_span.span.start;
256                break;
257            }
258
259            // Consume the token and track parentheses depth
260            let consumed = self.parser.next_token();
261            match consumed.token {
262                Token::LParen => paren_depth += 1,
263                Token::RParen => {
264                    // This RParen must belong to a nested expression since
265                    // `paren_depth > 0` here. Decrease depth accordingly.
266                    paren_depth = paren_depth.saturating_sub(1);
267                }
268                _ => {}
269            }
270        }
271
272        // Extract the TQL query string directly from the original SQL
273        let start_index = location_to_index(self.sql, &start_location);
274        let end_index = location_to_index(self.sql, &end_location);
275        let tql_string = &self.sql[start_index..end_index];
276        let tql_string = tql_string.trim();
277
278        let mut parser_ctx = ParserContext::new(&GreptimeDbDialect {}, tql_string)?;
279        let statement = parser_ctx.parse_tql(require_now_expr)?;
280
281        match statement {
282            Statement::Tql(Tql::Eval(eval)) => Ok(Tql::Eval(eval)),
283            Statement::Tql(_) => Err(error::InvalidSqlSnafu {
284                msg: "Only TQL EVAL is supported in CTEs".to_string(),
285            }
286            .build()),
287            _ => Err(error::InvalidSqlSnafu {
288                msg: "Expected a TQL statement inside CTE".to_string(),
289            }
290            .build()),
291        }
292    }
293}
294
#[cfg(test)]
mod tests {
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::parsers::with_tql_parser::CteContent;
    use crate::statements::statement::Statement;
    use crate::statements::tql::Tql;

    /// Parentheses inside the TQL expression must not be mistaken for the
    /// parenthesis that closes the CTE.
    #[test]
    fn test_parse_hybrid_cte_with_parentheses_in_query() {
        let sql = r#"
            WITH tql_cte AS (
                TQL EVAL (0, 100, '5s')
                sum(rate(http_requests_total[1m])) + (max(cpu_usage) * (1 + 0.5))
            )
            SELECT * FROM tql_cte
        "#;

        let statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(statements.len(), 1);

        let Statement::Query(query) = &statements[0] else {
            panic!("Expected Query statement");
        };
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();
        assert_eq!(hybrid_cte.cte_tables.len(), 1);

        // Fail loudly when the content is not `TQL EVAL`, so the assertions
        // below can never be skipped silently.
        let CteContent::Tql(Tql::Eval(eval)) = &hybrid_cte.cte_tables[0].content else {
            panic!("Expected TQL EVAL content");
        };

        // The query text is sliced from the original SQL, so nested
        // parentheses and the original spacing are preserved.
        assert!(eval.query.contains("sum(rate(http_requests_total[1m]))"));
        assert!(eval.query.contains("(max(cpu_usage) * (1 + 0.5))"));
        // Most importantly, the parenthesis counting did not cut the query short.
        assert!(eval.query.contains("+ (max"));
    }

    /// A SQL CTE and a TQL CTE in the same `WITH` clause are routed to
    /// different places: SQL to the inner query, TQL to `hybrid_cte`.
    #[test]
    fn test_parse_hybrid_cte_sql_and_tql() {
        let sql = r#"
            WITH
                sql_cte(ts, value, label) AS (SELECT timestamp, val, name FROM metrics),
                tql_cte(time, metric_value) AS (TQL EVAL (0, 100, '5s') cpu_usage)
            SELECT s.ts, s.value, t.metric_value
            FROM sql_cte s JOIN tql_cte t ON s.ts = t.time
        "#;

        let statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(statements.len(), 1);

        let Statement::Query(query) = &statements[0] else {
            panic!("Expected Query statement");
        };
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();
        // Only the TQL CTE is kept in the hybrid metadata.
        assert_eq!(hybrid_cte.cte_tables.len(), 1);

        // The SQL CTE is attached to the inner query's regular WITH clause.
        assert_eq!(
            query.inner.with.as_ref().map(|with| with.cte_tables.len()),
            Some(1)
        );

        // The single hybrid entry is the TQL CTE (declared second in the SQL),
        // and it keeps its column aliases.
        let tql_cte = &hybrid_cte.cte_tables[0];
        assert!(matches!(tql_cte.content, CteContent::Tql(_)));
        assert_eq!(tql_cte.columns.len(), 2);
        assert_eq!(
            tql_cte
                .columns
                .iter()
                .map(|x| x.to_string())
                .collect::<Vec<_>>()
                .join(" "),
            "time metric_value"
        );
    }
}