1use std::collections::HashSet;
16use std::fmt::{Display, Formatter};
17
18use itertools::Itertools;
19use promql_parser::label::{METRIC_NAME, MatchOp};
20use promql_parser::parser::{
21 AggregateExpr as PromAggregateExpr, BinaryExpr as PromBinaryExpr, Call as PromCall,
22 Expr as PromExpr, MatrixSelector as PromMatrixSelector, ParenExpr as PromParenExpr,
23 SubqueryExpr as PromSubqueryExpr, UnaryExpr as PromUnaryExpr,
24 VectorSelector as PromVectorSelector,
25};
26use serde::Serialize;
27use snafu::ensure;
28use sqlparser::ast::{
29 Array, Expr, Ident, ObjectName, ObjectNamePart, SetExpr, SqlOption, StructField, TableFactor,
30 Value, ValueWithSpan,
31};
32use sqlparser_derive::{Visit, VisitMut};
33
34use crate::ast::ObjectNamePartExt;
35use crate::error::{InvalidExprAsOptionValueSnafu, InvalidSqlSnafu, Result};
36use crate::statements::create::SqlOrTql;
37use crate::statements::tql::Tql;
38
/// Name of the PromQL label matcher whose value is used as the schema part of an extracted
/// table name (see `extract_metric_name_from_vector_selector`).
const SCHEMA_MATCHER: &str = "__schema__";
/// Alternative matcher name that is handled exactly like [`SCHEMA_MATCHER`] when extracting
/// table names from a vector selector.
const DATABASE_MATCHER: &str = "__database__";
41
42pub fn format_raw_object_name(name: &ObjectName) -> String {
44 struct Inner<'a> {
45 name: &'a ObjectName,
46 }
47
48 impl Display for Inner<'_> {
49 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50 let mut delim = "";
51 for ident in self.name.0.iter() {
52 write!(f, "{delim}")?;
53 delim = ".";
54 write!(f, "{}", ident.to_string_unquoted())?;
55 }
56 Ok(())
57 }
58 }
59
60 format!("{}", Inner { name })
61}
62
/// The value side of a SQL option, stored as the wrapped SQL [`Expr`].
///
/// `OptionValue::try_new` only accepts value, identifier, array and struct expressions;
/// the `From` conversions build identifier or array expressions directly.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Visit, VisitMut)]
pub struct OptionValue(Expr);
65
66impl OptionValue {
67 pub(crate) fn try_new(expr: Expr) -> Result<Self> {
68 ensure!(
69 matches!(
70 expr,
71 Expr::Value(_) | Expr::Identifier(_) | Expr::Array(_) | Expr::Struct { .. }
72 ),
73 InvalidExprAsOptionValueSnafu {
74 error: format!("{expr} not accepted")
75 }
76 );
77 Ok(Self(expr))
78 }
79
80 fn expr_as_string(expr: &Expr) -> Option<&str> {
81 match expr {
82 Expr::Value(ValueWithSpan { value, .. }) => match value {
83 Value::SingleQuotedString(s)
84 | Value::DoubleQuotedString(s)
85 | Value::TripleSingleQuotedString(s)
86 | Value::TripleDoubleQuotedString(s)
87 | Value::SingleQuotedByteStringLiteral(s)
88 | Value::DoubleQuotedByteStringLiteral(s)
89 | Value::TripleSingleQuotedByteStringLiteral(s)
90 | Value::TripleDoubleQuotedByteStringLiteral(s)
91 | Value::SingleQuotedRawStringLiteral(s)
92 | Value::DoubleQuotedRawStringLiteral(s)
93 | Value::TripleSingleQuotedRawStringLiteral(s)
94 | Value::TripleDoubleQuotedRawStringLiteral(s)
95 | Value::EscapedStringLiteral(s)
96 | Value::UnicodeStringLiteral(s)
97 | Value::NationalStringLiteral(s)
98 | Value::HexStringLiteral(s) => Some(s),
99 Value::DollarQuotedString(s) => Some(&s.value),
100 Value::Number(s, _) => Some(s),
101 Value::Boolean(b) => Some(if *b { "true" } else { "false" }),
102 _ => None,
103 },
104 Expr::Identifier(ident) => Some(&ident.value),
105 _ => None,
106 }
107 }
108
109 pub fn as_string(&self) -> Option<&str> {
113 Self::expr_as_string(&self.0)
114 }
115
116 pub fn as_list(&self) -> Option<Vec<&str>> {
117 let expr = &self.0;
118 match expr {
119 Expr::Value(_) | Expr::Identifier(_) => self.as_string().map(|s| vec![s]),
120 Expr::Array(array) => array
121 .elem
122 .iter()
123 .map(Self::expr_as_string)
124 .collect::<Option<Vec<_>>>(),
125 _ => None,
126 }
127 }
128
129 pub(crate) fn as_struct_fields(&self) -> Option<&[StructField]> {
130 match &self.0 {
131 Expr::Struct { fields, .. } => Some(fields),
132 _ => None,
133 }
134 }
135}
136
137impl From<String> for OptionValue {
138 fn from(value: String) -> Self {
139 Self(Expr::Identifier(Ident::new(value)))
140 }
141}
142
143impl From<&str> for OptionValue {
144 fn from(value: &str) -> Self {
145 Self(Expr::Identifier(Ident::new(value)))
146 }
147}
148
149impl From<Vec<&str>> for OptionValue {
150 fn from(value: Vec<&str>) -> Self {
151 Self(Expr::Array(Array {
152 elem: value
153 .into_iter()
154 .map(|x| Expr::Identifier(Ident::new(x)))
155 .collect(),
156 named: false,
157 }))
158 }
159}
160
161impl Display for OptionValue {
162 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
163 if let Some(s) = self.as_string() {
164 write!(f, "'{s}'")
165 } else if let Some(s) = self.as_list() {
166 write!(
167 f,
168 "[{}]",
169 s.into_iter().map(|x| format!("'{x}'")).join(", ")
170 )
171 } else {
172 write!(f, "'{}'", self.0)
173 }
174 }
175}
176
177pub fn parse_option_string(option: SqlOption) -> Result<(String, OptionValue)> {
178 let SqlOption::KeyValue { key, value } = option else {
179 return InvalidSqlSnafu {
180 msg: "Expecting a key-value pair in the option",
181 }
182 .fail();
183 };
184 let v = OptionValue::try_new(value)?;
185 let k = key.value.to_lowercase();
186 Ok((k, v))
187}
188
189pub fn extract_tables_from_query(query: &SqlOrTql) -> impl Iterator<Item = ObjectName> {
191 let mut names = HashSet::new();
192
193 match query {
194 SqlOrTql::Sql(query, _) => extract_tables_from_set_expr(&query.body, &mut names),
195 SqlOrTql::Tql(tql, _) => extract_tables_from_tql(tql, &mut names),
196 }
197
198 names.into_iter()
199}
200
201fn extract_tables_from_tql(tql: &Tql, names: &mut HashSet<ObjectName>) {
202 let promql = match tql {
203 Tql::Eval(eval) => &eval.query,
204 Tql::Explain(explain) => &explain.query,
205 Tql::Analyze(analyze) => &analyze.query,
206 };
207
208 if let Ok(expr) = promql_parser::parser::parse(promql) {
209 extract_tables_from_prom_expr(&expr, names);
210 }
211}
212
213fn extract_tables_from_prom_expr(expr: &PromExpr, names: &mut HashSet<ObjectName>) {
214 match expr {
215 PromExpr::Aggregate(PromAggregateExpr { expr, .. }) => {
216 extract_tables_from_prom_expr(expr, names);
217 }
218 PromExpr::Unary(PromUnaryExpr { expr, .. }) => {
219 extract_tables_from_prom_expr(expr, names);
220 }
221 PromExpr::Binary(PromBinaryExpr { lhs, rhs, .. }) => {
222 extract_tables_from_prom_expr(lhs, names);
223 extract_tables_from_prom_expr(rhs, names);
224 }
225 PromExpr::Paren(PromParenExpr { expr }) => {
226 extract_tables_from_prom_expr(expr, names);
227 }
228 PromExpr::Subquery(PromSubqueryExpr { expr, .. }) => {
229 extract_tables_from_prom_expr(expr, names);
230 }
231 PromExpr::VectorSelector(selector) => {
232 extract_metric_name_from_vector_selector(selector, names);
233 }
234 PromExpr::MatrixSelector(PromMatrixSelector { vs, .. }) => {
235 extract_metric_name_from_vector_selector(vs, names);
236 }
237 PromExpr::Call(PromCall { args, .. }) => {
238 for arg in &args.args {
239 extract_tables_from_prom_expr(arg, names);
240 }
241 }
242 PromExpr::NumberLiteral(_) | PromExpr::StringLiteral(_) | PromExpr::Extension(_) => {}
243 }
244}
245
246fn extract_metric_name_from_vector_selector(
247 selector: &PromVectorSelector,
248 names: &mut HashSet<ObjectName>,
249) {
250 let metric_name = selector.name.clone().or_else(|| {
251 let mut metric_name_matchers = selector.matchers.find_matchers(METRIC_NAME);
252 if metric_name_matchers.len() == 1 && metric_name_matchers[0].op == MatchOp::Equal {
253 metric_name_matchers.pop().map(|matcher| matcher.value)
254 } else {
255 None
256 }
257 });
258 let Some(metric_name) = metric_name else {
259 return;
260 };
261
262 let schema_matcher = selector.matchers.matchers.iter().rev().find(|matcher| {
263 matcher.op == MatchOp::Equal
264 && (matcher.name == SCHEMA_MATCHER || matcher.name == DATABASE_MATCHER)
265 });
266
267 if let Some(schema) = schema_matcher {
268 names.insert(ObjectName(vec![
269 ObjectNamePart::Identifier(Ident::new(&schema.value)),
270 ObjectNamePart::Identifier(Ident::new(metric_name)),
271 ]));
272 } else {
273 names.insert(ObjectName(vec![ObjectNamePart::Identifier(Ident::new(
274 metric_name,
275 ))]));
276 }
277}
278
279pub fn location_to_index(sql: &str, location: &sqlparser::tokenizer::Location) -> usize {
281 let mut index = 0;
282 for (lno, line) in sql.lines().enumerate() {
283 if lno + 1 == location.line as usize {
284 index += location.column as usize;
285 break;
286 } else {
287 index += line.len() + 1; }
289 }
290 index - 1
293}
294
295fn extract_tables_from_set_expr(set_expr: &SetExpr, names: &mut HashSet<ObjectName>) {
299 match set_expr {
300 SetExpr::Select(select) => {
301 for from in &select.from {
302 table_factor_to_object_name(&from.relation, names);
303 for join in &from.joins {
304 table_factor_to_object_name(&join.relation, names);
305 }
306 }
307 }
308 SetExpr::Query(query) => {
309 extract_tables_from_set_expr(&query.body, names);
310 }
311 SetExpr::SetOperation { left, right, .. } => {
312 extract_tables_from_set_expr(left, names);
313 extract_tables_from_set_expr(right, names);
314 }
315 _ => {}
316 };
317}
318
319fn table_factor_to_object_name(table_factor: &TableFactor, names: &mut HashSet<ObjectName>) {
323 if let TableFactor::Table { name, .. } = table_factor {
324 names.insert(name.to_owned());
325 }
326}
327
#[cfg(test)]
mod tests {
    use sqlparser::tokenizer::Token;

    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Checks that slicing `sql` with the indices computed from a token's span yields
    /// exactly the token's text, for single-line and multi-line inputs.
    #[test]
    fn test_location_to_index() {
        // Inputs: one line; leading and trailing newline; trailing newline only; leading
        // newline only.
        let testcases = vec![
            "SELECT * FROM t WHERE a = 1",
            r"
SELECT *
FROM
t
WHERE a =
1
",
            r"SELECT *
FROM
t
WHERE a =
1
",
            r"
SELECT *
FROM
t
WHERE a =
1",
        ];

        for sql in testcases {
            let mut parser = ParserContext::new(&GreptimeDbDialect {}, sql).unwrap();
            loop {
                let token = parser.parser.next_token();
                if token == Token::EOF {
                    break;
                }
                let span = token.span;
                // The span's start and end locations must map back to the token's bytes.
                let subslice =
                    &sql[location_to_index(sql, &span.start)..location_to_index(sql, &span.end)];
                assert_eq!(token.to_string(), subslice);
            }
        }
    }

    /// Checks that the metric of a TQL query is extracted both when it is written as a
    /// bare name and when it is given through a `__name__` equality matcher.
    #[test]
    fn test_extract_tables_from_tql_query() {
        let testcases = vec![
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests);"#,
                vec!["http_requests".to_string()],
            ),
            (
                r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", {__name__="http_requests"});"#,
                vec!["http_requests".to_string()],
            ),
        ];

        for (sql, expected_tables) in testcases {
            let mut stmts = ParserContext::create_with_dialect(
                sql,
                &GreptimeDbDialect {},
                ParseOptions::default(),
            )
            .unwrap();
            let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
                unreachable!()
            };

            // Sort because the extracted names come out of a `HashSet` in no fixed order.
            let mut tables = extract_tables_from_query(&create_flow.query)
                .map(|table| format_raw_object_name(&table))
                .collect_vec();
            tables.sort();
            assert_eq!(expected_tables, tables);
        }
    }

    /// Checks that a `__schema__` equality matcher is used as the schema part of the
    /// extracted table name.
    #[test]
    fn test_extract_tables_from_tql_query_with_schema_matcher() {
        let sql = r#"
CREATE FLOW calc_reqs SINK TO cnt_reqs AS
TQL EVAL (now() - '15s'::interval, now(), '5s') count_values("status_code", http_requests{__schema__="greptime_private"});"#;
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        let Statement::CreateFlow(create_flow) = stmts.pop().unwrap() else {
            unreachable!()
        };

        let mut tables = extract_tables_from_query(&create_flow.query)
            .map(|table| format_raw_object_name(&table))
            .collect_vec();
        tables.sort();
        assert_eq!(vec!["greptime_private.http_requests".to_string()], tables);
    }
}