sql/statements/
insert.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use serde::Serialize;
16use sqlparser::ast::{
17    Insert as SpInsert, ObjectName, Query, SetExpr, Statement, TableObject, UnaryOperator, Values,
18};
19use sqlparser::parser::ParserError;
20use sqlparser_derive::{Visit, VisitMut};
21
22use crate::ast::{Expr, Value};
23use crate::error::{Result, UnsupportedSnafu};
24use crate::statements::query::Query as GtQuery;
25
26#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
27pub struct Insert {
28    // Can only be sqlparser::ast::Statement::Insert variant
29    pub inner: Statement,
30}
31
/// Returns early from the enclosing function with a `ParseSqlValue` error whose
/// message is the `Debug` rendering of the given expression.
///
/// Expands to a `return …;` statement, so invoke it in statement position
/// (e.g. `_ => { parse_fail!(x); }`).
macro_rules! parse_fail {
    ($e:expr) => {
        return crate::error::ParseSqlValueSnafu {
            msg: format!("{:?}", $e),
        }
        .fail();
    };
}
40
41impl Insert {
42    pub fn table_name(&self) -> Result<&ObjectName> {
43        match &self.inner {
44            Statement::Insert(insert) => {
45                let TableObject::TableName(name) = &insert.table else {
46                    return UnsupportedSnafu {
47                        keyword: "TABLE FUNCTION".to_string(),
48                    }
49                    .fail();
50                };
51                Ok(name)
52            }
53            _ => unreachable!(),
54        }
55    }
56
57    pub fn columns(&self) -> Vec<&String> {
58        match &self.inner {
59            Statement::Insert(insert) => insert.columns.iter().map(|ident| &ident.value).collect(),
60            _ => unreachable!(),
61        }
62    }
63
64    /// Extracts the literal insert statement body if possible
65    pub fn values_body(&self) -> Result<Vec<Vec<Value>>> {
66        match &self.inner {
67            Statement::Insert(SpInsert {
68                source:
69                    Some(box Query {
70                        body: box SetExpr::Values(Values { rows, .. }),
71                        ..
72                    }),
73                ..
74            }) => sql_exprs_to_values(rows),
75            _ => unreachable!(),
76        }
77    }
78
79    /// Returns true when the insert statement can extract literal values.
80    /// The rules is the same as function `values_body()`.
81    pub fn can_extract_values(&self) -> bool {
82        match &self.inner {
83            Statement::Insert(SpInsert {
84                source:
85                    Some(box Query {
86                        body: box SetExpr::Values(Values { rows, .. }),
87                        ..
88                    }),
89                ..
90            }) => rows.iter().all(|es| {
91                es.iter().all(|expr| match expr {
92                    Expr::Value(_) => true,
93                    Expr::Identifier(ident) => {
94                        if ident.quote_style.is_none() {
95                            ident.value.to_lowercase() == "default"
96                        } else {
97                            ident.quote_style == Some('"')
98                        }
99                    }
100                    Expr::UnaryOp { op, expr } => {
101                        matches!(op, UnaryOperator::Minus | UnaryOperator::Plus)
102                            && matches!(&**expr, Expr::Value(Value::Number(_, _)))
103                    }
104                    _ => false,
105                })
106            }),
107            _ => false,
108        }
109    }
110
111    pub fn query_body(&self) -> Result<Option<GtQuery>> {
112        Ok(match &self.inner {
113            Statement::Insert(SpInsert {
114                source: Some(box query),
115                ..
116            }) => Some(query.clone().try_into()?),
117            _ => None,
118        })
119    }
120}
121
122fn sql_exprs_to_values(exprs: &[Vec<Expr>]) -> Result<Vec<Vec<Value>>> {
123    let mut values = Vec::with_capacity(exprs.len());
124    for es in exprs.iter() {
125        let mut vs = Vec::with_capacity(es.len());
126        for expr in es.iter() {
127            vs.push(match expr {
128                Expr::Value(v) => v.clone(),
129                Expr::Identifier(ident) => {
130                    if ident.quote_style.is_none() {
131                        // Special processing for `default` value
132                        if ident.value.to_lowercase() == "default" {
133                            Value::Placeholder(ident.value.clone())
134                        } else {
135                            parse_fail!(expr);
136                        }
137                    } else {
138                        // Identifiers with double quotes, we treat them as strings.
139                        if ident.quote_style == Some('"') {
140                            Value::SingleQuotedString(ident.value.clone())
141                        } else {
142                            parse_fail!(expr);
143                        }
144                    }
145                }
146                Expr::UnaryOp { op, expr }
147                    if matches!(op, UnaryOperator::Minus | UnaryOperator::Plus) =>
148                {
149                    if let Expr::Value(Value::Number(s, b)) = &**expr {
150                        match op {
151                            UnaryOperator::Minus => Value::Number(format!("-{s}"), *b),
152                            UnaryOperator::Plus => Value::Number(s.to_string(), *b),
153                            _ => unreachable!(),
154                        }
155                    } else {
156                        parse_fail!(expr);
157                    }
158                }
159                _ => {
160                    parse_fail!(expr);
161                }
162            });
163        }
164        values.push(vs);
165    }
166    Ok(values)
167}
168
169impl TryFrom<Statement> for Insert {
170    type Error = ParserError;
171
172    fn try_from(value: Statement) -> std::result::Result<Self, Self::Error> {
173        match value {
174            Statement::Insert { .. } => Ok(Insert { inner: value }),
175            unexp => Err(ParserError::ParserError(format!(
176                "Not expected to be {unexp}"
177            ))),
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect and runs `check` against the
    /// first parsed statement, which must be an `INSERT`.
    fn check_insert(sql: &str, check: impl FnOnce(&Insert)) {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
                .unwrap();
        match stmts.remove(0) {
            Statement::Insert(insert) => {
                let insert: &Insert = &insert;
                check(insert);
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_insert_value_with_unary_op() {
        // insert "-1"
        check_insert("INSERT INTO my_table VALUES(-1)", |insert| {
            assert_eq!(
                insert.values_body().unwrap(),
                vec![vec![Value::Number("-1".to_string(), false)]]
            );
        });

        // insert "+1"
        check_insert("INSERT INTO my_table VALUES(+1)", |insert| {
            assert_eq!(
                insert.values_body().unwrap(),
                vec![vec![Value::Number("1".to_string(), false)]]
            );
        });
    }

    #[test]
    fn test_insert_value_with_default() {
        // insert "default"
        check_insert("INSERT INTO my_table VALUES(default)", |insert| {
            assert_eq!(
                insert.values_body().unwrap(),
                vec![vec![Value::Placeholder("default".to_owned())]]
            );
        });
    }

    #[test]
    fn test_insert_value_with_default_uppercase() {
        // insert "DEFAULT"
        check_insert("INSERT INTO my_table VALUES(DEFAULT)", |insert| {
            assert_eq!(
                insert.values_body().unwrap(),
                vec![vec![Value::Placeholder("DEFAULT".to_owned())]]
            );
        });
    }

    #[test]
    fn test_insert_value_with_quoted_string() {
        let expected = vec![vec![Value::SingleQuotedString("default".to_owned())]];

        // insert 'default'
        check_insert("INSERT INTO my_table VALUES('default')", |insert| {
            assert_eq!(insert.values_body().unwrap(), expected);
        });

        // insert "default". Treating double-quoted identifiers as strings.
        check_insert("INSERT INTO my_table VALUES(\"default\")", |insert| {
            assert_eq!(insert.values_body().unwrap(), expected);
        });

        // Backtick-quoted identifiers are not treated as strings.
        check_insert("INSERT INTO my_table VALUES(`default`)", |insert| {
            assert!(insert.values_body().is_err());
        });
    }

    #[test]
    fn test_insert_select() {
        check_insert("INSERT INTO my_table select * from other_table", |insert| {
            let q = insert.query_body().unwrap().unwrap();
            assert!(matches!(
                q.inner,
                Query {
                    body: box SetExpr::Select { .. },
                    ..
                }
            ));
        });
    }
}