1use std::fmt;
16
17use serde::Serialize;
18use snafu::ResultExt;
19use sqlparser::ast::helpers::attached_token::AttachedToken;
20use sqlparser::ast::{
21 Cte, Ident, ObjectName, Query as SpQuery, TableAlias, TableAliasColumnDef, With,
22};
23use sqlparser::keywords::Keyword;
24use sqlparser::parser::IsOptional;
25use sqlparser::tokenizer::Token;
26use sqlparser_derive::{Visit, VisitMut};
27
28use crate::dialect::GreptimeDbDialect;
29use crate::error::{self, Result};
30use crate::parser::{ParseOptions, ParserContext};
31use crate::parsers::tql_parser;
32use crate::statements::query::Query;
33use crate::statements::statement::Statement;
34use crate::statements::tql::Tql;
35
36#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
38pub enum CteContent {
39 Sql(Box<SpQuery>),
40 Tql(Tql),
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
45pub struct HybridCte {
46 pub name: Ident,
47 pub columns: Vec<ObjectName>,
49 pub content: CteContent,
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
54pub struct HybridCteWith {
55 pub recursive: bool,
56 pub cte_tables: Vec<HybridCte>,
57}
58
59impl fmt::Display for HybridCteWith {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 write!(f, "WITH ")?;
62
63 if self.recursive {
64 write!(f, "RECURSIVE ")?;
65 }
66
67 for (i, cte) in self.cte_tables.iter().enumerate() {
68 if i > 0 {
69 write!(f, ", ")?;
70 }
71 write!(f, "{}", cte.name)?;
72
73 if !cte.columns.is_empty() {
74 write!(f, " (")?;
75 for (j, col) in cte.columns.iter().enumerate() {
76 if j > 0 {
77 write!(f, ", ")?;
78 }
79 write!(f, "{}", col)?;
80 }
81 write!(f, ")")?;
82 }
83
84 write!(f, " AS (")?;
85 match &cte.content {
86 CteContent::Sql(query) => write!(f, "{}", query)?,
87 CteContent::Tql(tql) => write!(f, "{}", tql)?,
88 }
89 write!(f, ")")?;
90 }
91 Ok(())
92 }
93}
94
95impl ParserContext<'_> {
97 pub(crate) fn parse_with_tql(&mut self) -> Result<Statement> {
99 self.parser
101 .expect_keyword(Keyword::WITH)
102 .context(error::SyntaxSnafu)?;
103
104 let recursive = self.parser.parse_keyword(Keyword::RECURSIVE);
106
107 let mut tql_cte_tables = Vec::new();
109 let mut sql_cte_tables = Vec::new();
110
111 loop {
112 let cte = self.parse_hybrid_cte()?;
113 match cte.content {
114 CteContent::Sql(body) => sql_cte_tables.push(Cte {
115 alias: TableAlias {
116 name: cte.name,
117 columns: cte
118 .columns
119 .into_iter()
120 .map(|col| TableAliasColumnDef {
121 name: col.0[0].clone(),
122 data_type: None,
123 })
124 .collect(),
125 },
126 query: body,
127 from: None,
128 materialized: None,
129 closing_paren_token: AttachedToken::empty(),
130 }),
131 CteContent::Tql(_) => tql_cte_tables.push(cte),
132 }
133
134 if !self.parser.consume_token(&Token::Comma) {
135 break;
136 }
137 }
138
139 let main_query = self.parser.parse_query().context(error::SyntaxSnafu)?;
141
142 let hybrid_cte = HybridCteWith {
144 recursive,
145 cte_tables: tql_cte_tables,
146 };
147
148 let mut query = Query::try_from(*main_query)?;
150 query.hybrid_cte = Some(hybrid_cte);
151 query.inner.with = Some(With {
152 recursive,
153 cte_tables: sql_cte_tables,
154 with_token: AttachedToken::empty(),
155 });
156
157 Ok(Statement::Query(Box::new(query)))
158 }
159
160 fn parse_hybrid_cte(&mut self) -> Result<HybridCte> {
162 let name = self.parser.parse_identifier().context(error::SyntaxSnafu)?;
164 let name = Self::canonicalize_identifier(name);
165
166 let columns = self
168 .parser
169 .parse_parenthesized_qualified_column_list(IsOptional::Optional, true)
170 .context(error::SyntaxSnafu)?;
171
172 self.parser
174 .expect_keyword(Keyword::AS)
175 .context(error::SyntaxSnafu)?;
176
177 self.parser
179 .expect_token(&Token::LParen)
180 .context(error::SyntaxSnafu)?;
181
182 let content = self.parse_cte_content()?;
183
184 self.parser
185 .expect_token(&Token::RParen)
186 .context(error::SyntaxSnafu)?;
187
188 Ok(HybridCte {
189 name,
190 columns,
191 content,
192 })
193 }
194
195 fn parse_cte_content(&mut self) -> Result<CteContent> {
197 if let Token::Word(w) = &self.parser.peek_token().token {
199 if w.keyword == Keyword::NoKeyword
200 && w.quote_style.is_none()
201 && w.value.to_uppercase() == tql_parser::TQL
202 {
203 let tql = self.parse_tql_content_in_cte()?;
204 return Ok(CteContent::Tql(tql));
205 }
206 }
207
208 let sql_query = self.parser.parse_query().context(error::SyntaxSnafu)?;
210 Ok(CteContent::Sql(sql_query))
211 }
212
213 fn parse_tql_content_in_cte(&mut self) -> Result<Tql> {
221 let mut collected: Vec<Token> = Vec::new();
222 let mut paren_depth = 0usize;
223
224 loop {
225 let token_with_span = self.parser.peek_token();
226
227 if token_with_span.token == Token::EOF {
229 return Err(error::InvalidSqlSnafu {
230 msg: "Unexpected end of input while parsing TQL inside CTE".to_string(),
231 }
232 .build());
233 }
234
235 if token_with_span.token == Token::RParen && paren_depth == 0 {
237 break;
238 }
239
240 let consumed = self.parser.next_token();
242 match consumed.token {
243 Token::LParen => paren_depth += 1,
244 Token::RParen => {
245 paren_depth = paren_depth.saturating_sub(1);
248 }
249 _ => {}
250 }
251
252 collected.push(consumed.token);
253 }
254
255 let tql_string = collected
257 .iter()
258 .map(|tok| tok.to_string())
259 .collect::<Vec<_>>()
260 .join(" ");
261
262 let mut stmts = ParserContext::create_with_dialect(
264 &tql_string,
265 &GreptimeDbDialect {},
266 ParseOptions::default(),
267 )?;
268
269 if stmts.len() != 1 {
270 return Err(error::InvalidSqlSnafu {
271 msg: "Expected a single TQL statement inside CTE".to_string(),
272 }
273 .build());
274 }
275
276 match stmts.remove(0) {
277 Statement::Tql(Tql::Eval(eval)) => Ok(Tql::Eval(eval)),
278 Statement::Tql(_) => Err(error::InvalidSqlSnafu {
279 msg: "Only TQL EVAL is supported in CTEs".to_string(),
280 }
281 .build()),
282 _ => Err(error::InvalidSqlSnafu {
283 msg: "Expected a TQL statement inside CTE".to_string(),
284 }
285 .build()),
286 }
287 }
288}
289
#[cfg(test)]
mod tests {
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::parsers::with_tql_parser::CteContent;
    use crate::statements::statement::Statement;
    use crate::statements::tql::Tql;

    /// Nested parentheses inside the TQL expression must not be mistaken for
    /// the parenthesis that closes the CTE.
    #[test]
    fn test_parse_hybrid_cte_with_parentheses_in_query() {
        let sql = r#"
            WITH tql_cte AS (
                TQL EVAL (0, 100, '5s')
                sum(rate(http_requests_total[1m])) + (max(cpu_usage) * (1 + 0.5))
            )
            SELECT * FROM tql_cte
        "#;

        let statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(statements.len(), 1);

        let Statement::Query(query) = &statements[0] else {
            panic!("Expected Query statement");
        };
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();
        assert_eq!(hybrid_cte.cte_tables.len(), 1);

        // Fail loudly instead of silently skipping the assertions below when
        // the content is not a TQL EVAL.
        let CteContent::Tql(Tql::Eval(eval)) = &hybrid_cte.cte_tables[0].content else {
            panic!("Expected TQL EVAL content");
        };

        // The query text is rebuilt from tokens joined by single spaces.
        assert!(eval
            .query
            .contains("sum ( rate ( http_requests_total [ 1 m ] ) )"));
        assert!(eval.query.contains("( max ( cpu_usage ) * ( 1 + 0.5 ) )"));
        assert!(eval.query.contains("+ ( max"));
    }

    /// A `WITH` clause mixing a SQL CTE and a TQL CTE keeps only the TQL one in
    /// `hybrid_cte`, together with its column list.
    #[test]
    fn test_parse_hybrid_cte_sql_and_tql() {
        let sql = r#"
            WITH
                sql_cte(ts, value, label) AS (SELECT timestamp, val, name FROM metrics),
                tql_cte(time, metric_value) AS (TQL EVAL (0, 100, '5s') cpu_usage)
            SELECT s.ts, s.value, t.metric_value
            FROM sql_cte s JOIN tql_cte t ON s.ts = t.time
        "#;

        let statements =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        assert_eq!(statements.len(), 1);

        let Statement::Query(query) = &statements[0] else {
            panic!("Expected Query statement");
        };
        let hybrid_cte = query.hybrid_cte.as_ref().unwrap();
        // Only the TQL CTE is stored here; the SQL CTE goes to the inner query.
        assert_eq!(hybrid_cte.cte_tables.len(), 1);

        let second_cte = &hybrid_cte.cte_tables[0];
        assert!(matches!(second_cte.content, CteContent::Tql(_)));
        assert_eq!(second_cte.columns.len(), 2);
        assert_eq!(second_cte.columns[0].0[0].value, "time");
        assert_eq!(second_cte.columns[1].0[0].value, "metric_value");
    }
}