1use std::fmt;
16
17use serde::Serialize;
18use snafu::ResultExt;
19use sqlparser::ast::helpers::attached_token::AttachedToken;
20use sqlparser::ast::{
21    Cte, Ident, ObjectName, Query as SpQuery, TableAlias, TableAliasColumnDef, With,
22};
23use sqlparser::keywords::Keyword;
24use sqlparser::parser::IsOptional;
25use sqlparser::tokenizer::Token;
26use sqlparser_derive::{Visit, VisitMut};
27
28use crate::dialect::GreptimeDbDialect;
29use crate::error::{self, Result};
30use crate::parser::{ParseOptions, ParserContext};
31use crate::parsers::tql_parser;
32use crate::statements::query::Query;
33use crate::statements::statement::Statement;
34use crate::statements::tql::Tql;
35use crate::util::location_to_index;
36
/// The body of a single CTE in a hybrid `WITH` clause: either a regular SQL
/// query or a TQL statement.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub enum CteContent {
    /// A regular SQL query body, e.g. `SELECT ...`.
    Sql(Box<SpQuery>),
    /// A TQL statement body. The parser in this module only produces
    /// `Tql::Eval` here; other TQL statements are rejected.
    Tql(Tql),
}
43
/// One CTE definition in a hybrid `WITH` clause, i.e.
/// `name [(col, ...)] AS (content)`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct HybridCte {
    /// The CTE name; the parser canonicalizes it before storing it here.
    pub name: Ident,
    /// Column aliases from the optional `(col, ...)` list; empty when no
    /// column list is given.
    pub columns: Vec<ObjectName>,
    /// The CTE body, either SQL or TQL.
    pub content: CteContent,
}
52
/// The TQL part of a hybrid `WITH` clause.
///
/// When a `WITH` clause is parsed by `parse_with_tql`, only the TQL CTEs are
/// kept in this struct; the SQL CTEs from the same clause are moved into the
/// inner query's regular `WITH`.
#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
pub struct HybridCteWith {
    /// Whether the clause was written as `WITH RECURSIVE`.
    pub recursive: bool,
    /// The CTE definitions, in source order.
    pub cte_tables: Vec<HybridCte>,
}
59
60impl fmt::Display for HybridCteWith {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        write!(f, "WITH ")?;
63
64        if self.recursive {
65            write!(f, "RECURSIVE ")?;
66        }
67
68        for (i, cte) in self.cte_tables.iter().enumerate() {
69            if i > 0 {
70                write!(f, ", ")?;
71            }
72            write!(f, "{}", cte.name)?;
73
74            if !cte.columns.is_empty() {
75                write!(f, " (")?;
76                for (j, col) in cte.columns.iter().enumerate() {
77                    if j > 0 {
78                        write!(f, ", ")?;
79                    }
80                    write!(f, "{}", col)?;
81                }
82                write!(f, ")")?;
83            }
84
85            write!(f, " AS (")?;
86            match &cte.content {
87                CteContent::Sql(query) => write!(f, "{}", query)?,
88                CteContent::Tql(tql) => write!(f, "{}", tql)?,
89            }
90            write!(f, ")")?;
91        }
92        Ok(())
93    }
94}
95
96impl ParserContext<'_> {
98    pub(crate) fn parse_with_tql(&mut self) -> Result<Statement> {
100        self.parser
102            .expect_keyword(Keyword::WITH)
103            .context(error::SyntaxSnafu)?;
104
105        let recursive = self.parser.parse_keyword(Keyword::RECURSIVE);
107
108        let mut tql_cte_tables = Vec::new();
110        let mut sql_cte_tables = Vec::new();
111
112        loop {
113            let cte = self.parse_hybrid_cte()?;
114            match cte.content {
115                CteContent::Sql(body) => sql_cte_tables.push(Cte {
116                    alias: TableAlias {
117                        name: cte.name,
118                        columns: cte
119                            .columns
120                            .into_iter()
121                            .flat_map(|col| col.0[0].as_ident().cloned())
122                            .map(|name| TableAliasColumnDef {
123                                name,
124                                data_type: None,
125                            })
126                            .collect(),
127                    },
128                    query: body,
129                    from: None,
130                    materialized: None,
131                    closing_paren_token: AttachedToken::empty(),
132                }),
133                CteContent::Tql(_) => tql_cte_tables.push(cte),
134            }
135
136            if !self.parser.consume_token(&Token::Comma) {
137                break;
138            }
139        }
140
141        let main_query = self.parser.parse_query().context(error::SyntaxSnafu)?;
143
144        let hybrid_cte = HybridCteWith {
146            recursive,
147            cte_tables: tql_cte_tables,
148        };
149
150        let mut query = Query::try_from(*main_query)?;
152        query.hybrid_cte = Some(hybrid_cte);
153        query.inner.with = Some(With {
154            recursive,
155            cte_tables: sql_cte_tables,
156            with_token: AttachedToken::empty(),
157        });
158
159        Ok(Statement::Query(Box::new(query)))
160    }
161
162    fn parse_hybrid_cte(&mut self) -> Result<HybridCte> {
164        let name = self.parser.parse_identifier().context(error::SyntaxSnafu)?;
166        let name = Self::canonicalize_identifier(name);
167
168        let columns = self
170            .parser
171            .parse_parenthesized_qualified_column_list(IsOptional::Optional, true)
172            .context(error::SyntaxSnafu)?;
173
174        self.parser
176            .expect_keyword(Keyword::AS)
177            .context(error::SyntaxSnafu)?;
178
179        self.parser
181            .expect_token(&Token::LParen)
182            .context(error::SyntaxSnafu)?;
183
184        let content = self.parse_cte_content()?;
185
186        self.parser
187            .expect_token(&Token::RParen)
188            .context(error::SyntaxSnafu)?;
189
190        Ok(HybridCte {
191            name,
192            columns,
193            content,
194        })
195    }
196
197    fn parse_cte_content(&mut self) -> Result<CteContent> {
199        if let Token::Word(w) = &self.parser.peek_token().token
201            && w.keyword == Keyword::NoKeyword
202            && w.quote_style.is_none()
203            && w.value.to_uppercase() == tql_parser::TQL
204        {
205            let tql = self.parse_tql_content_in_cte()?;
206            return Ok(CteContent::Tql(tql));
207        }
208
209        let sql_query = self.parser.parse_query().context(error::SyntaxSnafu)?;
211        Ok(CteContent::Sql(sql_query))
212    }
213
214    fn parse_tql_content_in_cte(&mut self) -> Result<Tql> {
222        let tql_token = self.parser.next_token();
224        if tql_token.token == Token::EOF {
225            return Err(error::InvalidSqlSnafu {
226                msg: "Unexpected end of input while parsing TQL inside CTE".to_string(),
227            }
228            .build());
229        }
230
231        let start_location = tql_token.span.start;
232
233        let mut paren_depth = 0usize;
235        let end_location;
236
237        loop {
238            let token_with_span = self.parser.peek_token();
239
240            if token_with_span.token == Token::EOF {
242                return Err(error::InvalidSqlSnafu {
243                    msg: "Unexpected end of input while parsing TQL inside CTE".to_string(),
244                }
245                .build());
246            }
247
248            if token_with_span.token == Token::RParen && paren_depth == 0 {
250                end_location = token_with_span.span.start;
251                break;
252            }
253
254            let consumed = self.parser.next_token();
256            match consumed.token {
257                Token::LParen => paren_depth += 1,
258                Token::RParen => {
259                    paren_depth = paren_depth.saturating_sub(1);
262                }
263                _ => {}
264            }
265        }
266
267        let start_index = location_to_index(self.sql, &start_location);
269        let end_index = location_to_index(self.sql, &end_location);
270        let tql_string = &self.sql[start_index..end_index];
271        let tql_string = tql_string.trim();
272
273        let mut stmts = ParserContext::create_with_dialect(
275            tql_string,
276            &GreptimeDbDialect {},
277            ParseOptions::default(),
278        )?;
279
280        if stmts.len() != 1 {
281            return Err(error::InvalidSqlSnafu {
282                msg: "Expected a single TQL statement inside CTE".to_string(),
283            }
284            .build());
285        }
286
287        match stmts.remove(0) {
288            Statement::Tql(Tql::Eval(eval)) => Ok(Tql::Eval(eval)),
289            Statement::Tql(_) => Err(error::InvalidSqlSnafu {
290                msg: "Only TQL EVAL is supported in CTEs".to_string(),
291            }
292            .build()),
293            _ => Err(error::InvalidSqlSnafu {
294                msg: "Expected a TQL statement inside CTE".to_string(),
295            }
296            .build()),
297        }
298    }
299}
300
#[cfg(test)]
mod tests {
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::parsers::with_tql_parser::CteContent;
    use crate::statements::statement::Statement;
    use crate::statements::tql::Tql;

    /// Parses `sql` with the GreptimeDB dialect and returns its only statement.
    fn parse_single(sql: &str) -> Statement {
        let mut statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(statements.len(), 1);
        statements.remove(0)
    }

    #[test]
    fn test_parse_hybrid_cte_with_parentheses_in_query() {
        // Nested parentheses inside the PromQL must not end the CTE body early.
        let sql = r#"
            WITH tql_cte AS (
                TQL EVAL (0, 100, '5s')
                sum(rate(http_requests_total[1m])) + (max(cpu_usage) * (1 + 0.5))
            )
            SELECT * FROM tql_cte
        "#;

        let statement = parse_single(sql);
        let Statement::Query(query) = &statement else {
            panic!("Expected Query statement");
        };
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();
        assert_eq!(hybrid_cte.cte_tables.len(), 1);

        // Fail loudly instead of silently skipping the checks below.
        let CteContent::Tql(Tql::Eval(eval)) = &hybrid_cte.cte_tables[0].content else {
            panic!("Expected TQL EVAL content in CTE");
        };
        assert!(eval.query.contains("sum(rate(http_requests_total[1m]))"));
        assert!(eval.query.contains("(max(cpu_usage) * (1 + 0.5))"));
        assert!(eval.query.contains("+ (max"));
    }

    #[test]
    fn test_parse_hybrid_cte_sql_and_tql() {
        let sql = r#"
            WITH
                sql_cte(ts, value, label) AS (SELECT timestamp, val, name FROM metrics),
                tql_cte(time, metric_value) AS (TQL EVAL (0, 100, '5s') cpu_usage)
            SELECT s.ts, s.value, t.metric_value
            FROM sql_cte s JOIN tql_cte t ON s.ts = t.time
        "#;

        let statement = parse_single(sql);
        let Statement::Query(query) = &statement else {
            panic!("Expected Query statement");
        };

        // SQL CTEs are moved into the inner query's regular WITH clause.
        let with = query
            .inner
            .with
            .as_ref()
            .expect("SQL CTEs should populate the inner WITH clause");
        assert_eq!(with.cte_tables.len(), 1);
        let sql_cte = &with.cte_tables[0];
        assert_eq!(sql_cte.alias.name.value, "sql_cte");
        assert_eq!(
            sql_cte
                .alias
                .columns
                .iter()
                .map(|c| c.name.value.as_str())
                .collect::<Vec<_>>(),
            ["ts", "value", "label"]
        );

        // TQL CTEs are kept in the hybrid CTE metadata.
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();
        assert_eq!(hybrid_cte.cte_tables.len(), 1);
        let tql_cte = &hybrid_cte.cte_tables[0];
        assert_eq!(tql_cte.name.value, "tql_cte");
        assert!(matches!(tql_cte.content, CteContent::Tql(_)));
        assert_eq!(tql_cte.columns.len(), 2);
        assert_eq!(
            tql_cte
                .columns
                .iter()
                .map(|x| x.to_string())
                .collect::<Vec<_>>()
                .join(" "),
            "time metric_value"
        );
    }

    #[test]
    fn test_parse_hybrid_cte_unterminated_tql_is_rejected() {
        // The CTE's closing parenthesis is missing.
        let sql = "WITH t AS (TQL EVAL (0, 10, '5s') up SELECT * FROM t";
        assert!(
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .is_err()
        );
    }

    #[test]
    fn test_parse_hybrid_cte_non_eval_tql_is_rejected() {
        let sql = "WITH t AS (TQL EXPLAIN (0, 10, '5s') up) SELECT * FROM t";
        assert!(
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .is_err()
        );
    }
}