sql/statements/
insert.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use serde::Serialize;
16use sqlparser::ast::{
17    Insert as SpInsert, ObjectName, Query, SetExpr, Statement, TableObject, UnaryOperator,
18    ValueWithSpan, Values,
19};
20use sqlparser::parser::ParserError;
21use sqlparser_derive::{Visit, VisitMut};
22
23use crate::ast::{Expr, Value};
24use crate::error::{Result, UnsupportedSnafu};
25use crate::statements::query::Query as GtQuery;
26
27#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
28pub struct Insert {
29    // Can only be sqlparser::ast::Statement::Insert variant
30    pub inner: Statement,
31}
32
/// Returns early from the enclosing function with the error built by
/// `crate::error::ParseSqlValueSnafu`, using the `Debug` rendering of the
/// given expression as the error message.
macro_rules! parse_fail {
    ($offending:expr) => {
        return crate::error::ParseSqlValueSnafu {
            msg: format!("{:?}", $offending),
        }
        .fail();
    };
}
41
42impl Insert {
43    pub fn table_name(&self) -> Result<&ObjectName> {
44        match &self.inner {
45            Statement::Insert(insert) => {
46                let TableObject::TableName(name) = &insert.table else {
47                    return UnsupportedSnafu {
48                        keyword: "TABLE FUNCTION".to_string(),
49                    }
50                    .fail();
51                };
52                Ok(name)
53            }
54            _ => unreachable!(),
55        }
56    }
57
58    pub fn columns(&self) -> Vec<&String> {
59        match &self.inner {
60            Statement::Insert(insert) => insert.columns.iter().map(|ident| &ident.value).collect(),
61            _ => unreachable!(),
62        }
63    }
64
65    /// Extracts the literal insert statement body if possible
66    pub fn values_body(&self) -> Result<Vec<Vec<Value>>> {
67        match &self.inner {
68            Statement::Insert(SpInsert {
69                source:
70                    Some(box Query {
71                        body: box SetExpr::Values(Values { rows, .. }),
72                        ..
73                    }),
74                ..
75            }) => sql_exprs_to_values(rows),
76            _ => unreachable!(),
77        }
78    }
79
80    /// Returns true when the insert statement can extract literal values.
81    /// The rules is the same as function `values_body()`.
82    pub fn can_extract_values(&self) -> bool {
83        match &self.inner {
84            Statement::Insert(SpInsert {
85                source:
86                    Some(box Query {
87                        body: box SetExpr::Values(Values { rows, .. }),
88                        ..
89                    }),
90                ..
91            }) => rows.iter().all(|es| {
92                es.iter().all(|expr| match expr {
93                    Expr::Value(_) => true,
94                    Expr::Identifier(ident) => {
95                        if ident.quote_style.is_none() {
96                            ident.value.to_lowercase() == "default"
97                        } else {
98                            ident.quote_style == Some('"')
99                        }
100                    }
101                    Expr::UnaryOp { op, expr } => {
102                        matches!(op, UnaryOperator::Minus | UnaryOperator::Plus)
103                            && matches!(
104                                &**expr,
105                                Expr::Value(ValueWithSpan {
106                                    value: Value::Number(_, _),
107                                    ..
108                                })
109                            )
110                    }
111                    _ => false,
112                })
113            }),
114            _ => false,
115        }
116    }
117
118    pub fn query_body(&self) -> Result<Option<GtQuery>> {
119        Ok(match &self.inner {
120            Statement::Insert(SpInsert {
121                source: Some(box query),
122                ..
123            }) => Some(query.clone().try_into()?),
124            _ => None,
125        })
126    }
127}
128
129fn sql_exprs_to_values(exprs: &[Vec<Expr>]) -> Result<Vec<Vec<Value>>> {
130    let mut values = Vec::with_capacity(exprs.len());
131    for es in exprs.iter() {
132        let mut vs = Vec::with_capacity(es.len());
133        for expr in es.iter() {
134            vs.push(match expr {
135                Expr::Value(v) => v.value.clone(),
136                Expr::Identifier(ident) => {
137                    if ident.quote_style.is_none() {
138                        // Special processing for `default` value
139                        if ident.value.to_lowercase() == "default" {
140                            Value::Placeholder(ident.value.clone())
141                        } else {
142                            parse_fail!(expr);
143                        }
144                    } else {
145                        // Identifiers with double quotes, we treat them as strings.
146                        if ident.quote_style == Some('"') {
147                            Value::SingleQuotedString(ident.value.clone())
148                        } else {
149                            parse_fail!(expr);
150                        }
151                    }
152                }
153                Expr::UnaryOp { op, expr }
154                    if matches!(op, UnaryOperator::Minus | UnaryOperator::Plus) =>
155                {
156                    if let Expr::Value(ValueWithSpan {
157                        value: Value::Number(s, b),
158                        ..
159                    }) = &**expr
160                    {
161                        match op {
162                            UnaryOperator::Minus => Value::Number(format!("-{s}"), *b),
163                            UnaryOperator::Plus => Value::Number(s.to_string(), *b),
164                            _ => unreachable!(),
165                        }
166                    } else {
167                        parse_fail!(expr);
168                    }
169                }
170                _ => {
171                    parse_fail!(expr);
172                }
173            });
174        }
175        values.push(vs);
176    }
177    Ok(values)
178}
179
180impl TryFrom<Statement> for Insert {
181    type Error = ParserError;
182
183    fn try_from(value: Statement) -> std::result::Result<Self, Self::Error> {
184        match value {
185            Statement::Insert { .. } => Ok(Insert { inner: value }),
186            unexp => Err(ParserError::ParserError(format!(
187                "Not expected to be {unexp}"
188            ))),
189        }
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect and default options and
    /// returns the first resulting statement.
    fn parse_first(sql: &str) -> Statement {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap()
            .remove(0)
    }

    /// Parses `sql` as an insert and asserts that its literal body equals `expected`.
    fn assert_insert_values(sql: &str, expected: Vec<Vec<Value>>) {
        let Statement::Insert(insert) = parse_first(sql) else {
            unreachable!()
        };
        assert_eq!(insert.values_body().unwrap(), expected);
    }

    #[test]
    fn test_insert_value_with_unary_op() {
        // insert "-1"
        assert_insert_values(
            "INSERT INTO my_table VALUES(-1)",
            vec![vec![Value::Number("-1".to_string(), false)]],
        );

        // insert "+1": the explicit plus sign is dropped
        assert_insert_values(
            "INSERT INTO my_table VALUES(+1)",
            vec![vec![Value::Number("1".to_string(), false)]],
        );
    }

    #[test]
    fn test_insert_value_with_default() {
        // insert "default"
        assert_insert_values(
            "INSERT INTO my_table VALUES(default)",
            vec![vec![Value::Placeholder("default".to_owned())]],
        );
    }

    #[test]
    fn test_insert_value_with_default_uppercase() {
        // insert "DEFAULT": the original spelling is kept in the placeholder
        assert_insert_values(
            "INSERT INTO my_table VALUES(DEFAULT)",
            vec![vec![Value::Placeholder("DEFAULT".to_owned())]],
        );
    }

    #[test]
    fn test_insert_value_with_quoted_string() {
        // insert 'default'
        assert_insert_values(
            "INSERT INTO my_table VALUES('default')",
            vec![vec![Value::SingleQuotedString("default".to_owned())]],
        );

        // insert "default". Treating double-quoted identifiers as strings.
        assert_insert_values(
            "INSERT INTO my_table VALUES(\"default\")",
            vec![vec![Value::SingleQuotedString("default".to_owned())]],
        );

        // Backtick-quoted identifiers are not accepted as values.
        let Statement::Insert(insert) = parse_first("INSERT INTO my_table VALUES(`default`)")
        else {
            unreachable!()
        };
        assert!(insert.values_body().is_err());
    }

    #[test]
    fn test_insert_select() {
        let Statement::Insert(insert) =
            parse_first("INSERT INTO my_table select * from other_table")
        else {
            unreachable!()
        };
        let q = insert.query_body().unwrap().unwrap();
        assert!(matches!(
            q.inner,
            Query {
                body: box SetExpr::Select { .. },
                ..
            }
        ));
    }
}