1use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use promql_parser::label::{METRIC_NAME, MatchOp};
20use promql_parser::parser::{
21 AggregateExpr as PromAggregateExpr, BinaryExpr as PromBinaryExpr, Call as PromCall,
22 Expr as PromExpr, MatrixSelector as PromMatrixSelector, ParenExpr as PromParenExpr,
23 SubqueryExpr as PromSubqueryExpr, UnaryExpr as PromUnaryExpr,
24 VectorSelector as PromVectorSelector,
25};
26use serde::Serialize;
27use snafu::ensure;
28use sqlparser::ast::{
29 Array, Expr, Ident, ObjectName, ObjectNamePart, SetExpr, SqlOption, StructField, TableFactor,
30 Value, ValueWithSpan,
31};
32use sqlparser_derive::{Visit, VisitMut};
33
34use crate::ast::ObjectNamePartExt;
35use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
36use crate::parser::ParserContext;
37use crate::parsers::with_tql_parser::CteContent;
38use crate::statements::create::SqlOrTql;
39use crate::statements::query::Query;
40use crate::statements::tql::Tql;
41
/// PromQL label-matcher name which, when compared with `=`, supplies the schema part of a
/// table name extracted from a vector selector.
const SCHEMA_MATCHER: &str = "__schema__";
/// Alternative label-matcher name that is treated exactly like [`SCHEMA_MATCHER`] when
/// extracting table names from a vector selector.
const DATABASE_MATCHER: &str = "__database__";
44
45pub fn format_raw_object_name(name: &ObjectName) -> String {
47 struct Inner<'a> {
48 name: &'a ObjectName,
49 }
50
51 impl Display for Inner<'_> {
52 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53 let mut delim = "";
54 for ident in self.name.0.iter() {
55 write!(f, "{delim}")?;
56 delim = ".";
57 write!(f, "{}", ident.to_string_unquoted())?;
58 }
59 Ok(())
60 }
61 }
62
63 format!("{}", Inner { name })
64}
65
/// The value of a SQL option, stored as the underlying `sqlparser` [`Expr`].
///
/// Instances created through `OptionValue::try_new` only ever wrap a value literal, an
/// identifier, an array or a struct expression.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
68
69impl OptionValue {
70 pub(crate) fn try_new(expr: Expr) -> Result<Self> {
71 ensure!(
72 matches!(
73 expr,
74 Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
75 ),
76 InvalidExprAsOptionValueSnafu {
77 error: format!("{expr} not accepted")
78 }
79 );
80 Ok(Self(expr))
81 }
82
83 fn expr_as_string(expr: &Expr) -> Option<&str> {
84 match expr {
85 Expr::Value(ValueWithSpan { value, .. }) => match value {
86 Value::SingleQuotedString(s)
87 | Value::DoubleQuotedString(s)
88 | Value::TripleSingleQuotedString(s)
89 | Value::TripleDoubleQuotedString(s)
90 | Value::SingleQuotedByteStringLiteral(s)
91 | Value::DoubleQuotedByteStringLiteral(s)
92 | Value::TripleSingleQuotedByteStringLiteral(s)
93 | Value::TripleDoubleQuotedByteStringLiteral(s)
94 | Value::SingleQuotedRawStringLiteral(s)
95 | Value::DoubleQuotedRawStringLiteral(s)
96 | Value::TripleSingleQuotedRawStringLiteral(s)
97 | Value::TripleDoubleQuotedRawStringLiteral(s)
98 | Value::EscapedStringLiteral(s)
99 | Value::UnicodeStringLiteral(s)
100 | Value::NationalStringLiteral(s)
101 | Value::HexStringLiteral(s) => Some(s),
102 Value::DollarQuotedString(s) => Some(&s.value),
103 Value::Number(s, _) => Some(s),
104 Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
105 _ => None,
106 },
107 Expr::Identifier(ident) => Some(&ident.value),
108 _ => None,
109 }
110 }
111
112 pub fn as_string(&self) -> Option<&str> {
116 Self::expr_as_string(&self.0)
117 }
118
119 pub fn as_list(&self) -> Option<Vec<&str>> {
120 let expr = &self.0;
121 match expr {
122 Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
123 Expr::Array(array) => array
124 .elem
125 .iter()
126 .map(Self::expr_as_string)
127 .collect::<Option<Vec<_>>>(),
128 _ => None,
129 }
130 }
131
132 pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
133 match &self.0 {
134 Expr::Struct { fields, .. } => Some(fields),
135 _ => None,
136 }
137 }
138}
139
140impl From<String> for OptionValue {
141 fn from(value: String) -> Self {
142 Self(Expr::Identifier(Ident::new(value)))
143 }
144}
145
146impl From<&str> for OptionValue {
147 fn from(value: &str) -> Self {
148 Self(Expr::Identifier(Ident::new(value)))
149 }
150}
151
152impl From<Vec<&str>> for OptionValue {
153 fn from(value: Vec<&str>) -> Self {
154 Self(Expr::Array(Array {
155 elem: value
156 .into_iter()
157 .map(|x| Expr::Identifier(Ident::new(x)))
158 .collect(),
159 named: false,
160 }))
161 }
162}
163
164impl Display for OptionValue {
165 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
166 if let Some(s) = self.as_string() {
167 write!(f, "'{s}'")
168 } else if let Some(s) = self.as_list() {
169 write!(
170 f,
171 "[{}]",
172 s.into_iter().map(|x| format!("'{x}'")).join(", ")
173 )
174 } else {
175 write!(f, "'{}'", self.0)
176 }
177 }
178}
179
180pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
181 let SqlOption::KeyValue { key, value } = option else {
182 return InvalidSqlSnafu {
183 msg: "Expecting a key-value pair in the option",
184 }
185 .fail();
186 };
187 let v = OptionValue::try_new(value)?;
188 let k = key.value.to_lowercase();
189 Ok((k, v))
190}
191
192pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
194 let mut names = HashSet::new();
195
196 match query {
197 SqlOrTql::Sql(query, _) => {
198 extract_tables_from_set_expr(&query.inner.body, &mut names);
199 extract_tables_from_hybrid_cte_query(query, &mut names);
200 }
201 SqlOrTql::Tql(tql, _) => extract_tables_from_tql(tql, &mut names),
202 }
203
204 names.into_iter()
205}
206
207fn extract_tables_from_hybrid_cte_query(query: &Query, sql_names: &mut HashSet<ObjectName>) {
208 let mut tql_names = HashSet::new();
209 let mut cte_names: HashSet<String> = HashSet::new();
210 if let Some(hybrid_cte) = &query.hybrid_cte {
211 for cte in &hybrid_cte.cte_tables {
212 cte_names.insert(ParserContext::canonicalize_identifier(cte.name.clone()).value);
213 if let CteContent::Tql(tql) = &cte.content {
214 extract_tables_from_tql(tql, &mut tql_names);
215 }
216 }
217 }
218
219 if let Some(with) = &query.inner.with {
220 for cte in &with.cte_tables {
221 cte_names.insert(ParserContext::canonicalize_identifier(cte.alias.name.clone()).value);
222 }
223 }
224
225 remove_cte_names(sql_names, &cte_names);
226
227 sql_names.extend(tql_names);
228}
229
230fn remove_cte_names(names: &mut HashSet<ObjectName>, cte_names: &HashSet<String>) {
231 if cte_names.is_empty() {
232 return;
233 }
234
235 names.retain(|name| {
236 if name.0.len() != 1 {
237 return true;
238 }
239 let Some(ident) = name.0[0].as_ident() else {
240 return true;
241 };
242
243 let canonical = ParserContext::canonicalize_identifier(ident.clone()).value;
244 !cte_names.contains(&canonical)
245 });
246}
247
248fn extract_tables_from_tql(tql: &Tql, names: &mut HashSet<ObjectName>) {
249 let promql = match tql {
250 Tql::Eval(eval) => &eval.query,
251 Tql::Explain(explain) => &explain.query,
252 Tql::Analyze(analyze) => &analyze.query,
253 };
254
255 if let Ok(expr) = promql_parser::parser::parse(promql) {
256 extract_tables_from_prom_expr(&expr, names);
257 }
258}
259
260fn extract_tables_from_prom_expr(expr: &PromExpr, names: &mut HashSet<ObjectName>) {
261 match expr {
262 PromExpr::Aggregate(PromAggregateExpr { expr, .. }) => {
263 extract_tables_from_prom_expr(expr, names);
264 }
265 PromExpr::Unary(PromUnaryExpr { expr, .. }) => {
266 extract_tables_from_prom_expr(expr, names);
267 }
268 PromExpr::Binary(PromBinaryExpr { lhs, rhs, .. }) => {
269 extract_tables_from_prom_expr(lhs, names);
270 extract_tables_from_prom_expr(rhs, names);
271 }
272 PromExpr::Paren(PromParenExpr { expr }) => {
273 extract_tables_from_prom_expr(expr, names);
274 }
275 PromExpr::Subquery(PromSubqueryExpr { expr, .. }) => {
276 extract_tables_from_prom_expr(expr, names);
277 }
278 PromExpr::VectorSelector(selector) => {
279 extract_metric_name_from_vector_selector(selector, names);
280 }
281 PromExpr::MatrixSelector(PromMatrixSelector { vs, .. }) => {
282 extract_metric_name_from_vector_selector(vs, names);
283 }
284 PromExpr::Call(PromCall { args, .. }) => {
285 for arg in &args.args {
286 extract_tables_from_prom_expr(arg, names);
287 }
288 }
289 PromExpr::NumberLiteral(_) | PromExpr::StringLiteral(_) | PromExpr::Extension(_) => {}
290 }
291}
292
293fn extract_metric_name_from_vector_selector(
294 selector: &PromVectorSelector,
295 names: &mut HashSet<ObjectName>,
296) {
297 let metric_name = selector.name.clone().or_else(|| {
298 let mut metric_name_matchers = selector.matchers.find_matchers(METRIC_NAME);
299 if metric_name_matchers.len() == 1 && metric_name_matchers[0].op == MatchOp::Equal {
300 metric_name_matchers.pop().map(|matcher| matcher.value)
301 } else {
302 None
303 }
304 });
305 let Some(metric_name) = metric_name else {
306 return;
307 };
308
309 let schema_matcher = selector.matchers.matchers.iter().rev().find(|matcher| {
310 matcher.op == MatchOp::Equal
311 && (matcher.name == SCHEMA_MATCHER || matcher.name == DATABASE_MATCHER)
312 });
313
314 if let Some(schema) = schema_matcher {
315 names.insert(ObjectName(vec![
316 ObjectNamePart::Identifier(Ident::new(&schema.value)),
317 ObjectNamePart::Identifier(Ident::new(metric_name)),
318 ]));
319 } else {
320 names.insert(ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
321 metric_name,
322 ))]));
323 }
324}
325
326pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
328 let mut index = 0;
329 for (lno, line) in sql.lines().enumerate() {
330 if lno + 1 == location.line as usize {
331 index += location.column as usize;
332 break;
333 } else {
334 index += line.len() + 1; }
336 }
337 index - 1
340}
341
342fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
346 match set_expr {
347 SetExpr::Select(select) => {
348 for from in &select.from {
349 table_factor_to_object_name(&from.relation, names);
350 for join in &from.joins {
351 table_factor_to_object_name(&join.relation, names);
352 }
353 }
354 }
355 SetExpr::Query(query) => {
356 extract_tables_from_set_expr(&query.body, names);
357 }
358 SetExpr::SetOperation { left, right, .. } => {
359 extract_tables_from_set_expr(left, names);
360 extract_tables_from_set_expr(right, names);
361 }
362 _ => {}
363 };
364}
365
366fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
370 if let TableFactor::Table { name, .. } = table_factor {
371 names.insert(name.to_owned());
372 }
373}
374
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql`, takes the last statement (which must be a `CREATE FLOW`) and returns
    /// the sorted, raw-formatted names of the tables its query references.
    fn sorted_flow_tables(sql: &str) -> Vec<String> {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        tables
    }

    /// Every token's span, converted with `location_to_index`, must slice the SQL text
    /// back to the token's own string form.
    #[test]
    fn test_location_to_index() {
        let testcases = vec![
            "SELECT * FROM t WHERE a = 1",
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut ctx = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            loop {
                let token = ctx.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let start = location_to_index(sql, &token.span.start);
                let end = location_to_index(sql, &token.span.end);
                let subslice = &sql[start..end];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }

    /// The metric name is found both as a bare selector name and as a `__name__` matcher.
    #[test]
    fn test_extract_tables_from_tql_query() {
        let testcases = vec![
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests);"#,
                vec!["http_requests".to_string()],
            ),
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", {__name__="http_requests"});"#,
                vec!["http_requests".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let tables = sorted_flow_tables(sql);
            assert_eq!(expected_tables, tables);
        }
    }

    /// A `__schema__` matcher qualifies the extracted table name with that schema.
    #[test]
    fn test_extract_tables_from_tql_query_with_schema_matcher() {
        let sql = r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests{__schema__="greptime_private"});"#;

        let tables = sorted_flow_tables(sql);
        assert_eq!(vec!["greptime_private.http_requests".to_string()], tables);
    }
}