1use serde::Serialize;
16use sqlparser::ast::{
17 Insert as SpInsert, ObjectName, Query, SetExpr, Statement, TableObject, UnaryOperator,
18 ValueWithSpan, Values,
19};
20use sqlparser::parser::ParserError;
21use sqlparser_derive::{Visit, VisitMut};
22
23use crate::ast::{Expr, Value};
24use crate::error::{Result, UnsupportedSnafu};
25use crate::statements::query::Query as GtQuery;
26
27#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
28pub struct Insert {
29 pub inner: Statement,
31}
32
/// Returns early from the enclosing function with an error built from
/// `ParseSqlValueSnafu`, whose message is the `Debug` rendering of the given
/// expression.
///
/// The expansion is a `return ...;` statement, so invoke it in statement
/// position (e.g. `parse_fail!(expr);` inside a block).
macro_rules! parse_fail {
    ($value:expr) => {
        return crate::error::ParseSqlValueSnafu {
            msg: format!("{:?}", $value),
        }
        .fail();
    };
}
41
42impl Insert {
43 pub fn table_name(&self) -> Result<&ObjectName> {
44 match &self.inner {
45 Statement::Insert(insert) => {
46 let TableObject::TableName(name) = &insert.table else {
47 return UnsupportedSnafu {
48 keyword: "TABLE FUNCTION".to_string(),
49 }
50 .fail();
51 };
52 Ok(name)
53 }
54 _ => unreachable!(),
55 }
56 }
57
58 pub fn columns(&self) -> Vec<&String> {
59 match &self.inner {
60 Statement::Insert(insert) => insert.columns.iter().map(|ident| &ident.value).collect(),
61 _ => unreachable!(),
62 }
63 }
64
65 pub fn values_body(&self) -> Result<Vec<Vec<Value>>> {
67 match &self.inner {
68 Statement::Insert(SpInsert {
69 source:
70 Some(box Query {
71 body: box SetExpr::Values(Values { rows, .. }),
72 ..
73 }),
74 ..
75 }) => sql_exprs_to_values(rows),
76 _ => unreachable!(),
77 }
78 }
79
80 pub fn can_extract_values(&self) -> bool {
83 match &self.inner {
84 Statement::Insert(SpInsert {
85 source:
86 Some(box Query {
87 body: box SetExpr::Values(Values { rows, .. }),
88 ..
89 }),
90 ..
91 }) => rows.iter().all(|es| {
92 es.iter().all(|expr| match expr {
93 Expr::Value(_) => true,
94 Expr::Identifier(ident) => {
95 if ident.quote_style.is_none() {
96 ident.value.to_lowercase() == "default"
97 } else {
98 ident.quote_style == Some('"')
99 }
100 }
101 Expr::UnaryOp { op, expr } => {
102 matches!(op, UnaryOperator::Minus | UnaryOperator::Plus)
103 && matches!(
104 &**expr,
105 Expr::Value(ValueWithSpan {
106 value: Value::Number(_, _),
107 ..
108 })
109 )
110 }
111 _ => false,
112 })
113 }),
114 _ => false,
115 }
116 }
117
118 pub fn query_body(&self) -> Result<Option<GtQuery>> {
119 Ok(match &self.inner {
120 Statement::Insert(SpInsert {
121 source: Some(box query),
122 ..
123 }) => Some(query.clone().try_into()?),
124 _ => None,
125 })
126 }
127}
128
129fn sql_exprs_to_values(exprs: &[Vec<Expr>]) -> Result<Vec<Vec<Value>>> {
130 let mut values = Vec::with_capacity(exprs.len());
131 for es in exprs.iter() {
132 let mut vs = Vec::with_capacity(es.len());
133 for expr in es.iter() {
134 vs.push(match expr {
135 Expr::Value(v) => v.value.clone(),
136 Expr::Identifier(ident) => {
137 if ident.quote_style.is_none() {
138 if ident.value.to_lowercase() == "default" {
140 Value::Placeholder(ident.value.clone())
141 } else {
142 parse_fail!(expr);
143 }
144 } else {
145 if ident.quote_style == Some('"') {
147 Value::SingleQuotedString(ident.value.clone())
148 } else {
149 parse_fail!(expr);
150 }
151 }
152 }
153 Expr::UnaryOp { op, expr }
154 if matches!(op, UnaryOperator::Minus | UnaryOperator::Plus) =>
155 {
156 if let Expr::Value(ValueWithSpan {
157 value: Value::Number(s, b),
158 ..
159 }) = &**expr
160 {
161 match op {
162 UnaryOperator::Minus => Value::Number(format!("-{s}"), *b),
163 UnaryOperator::Plus => Value::Number(s.to_string(), *b),
164 _ => unreachable!(),
165 }
166 } else {
167 parse_fail!(expr);
168 }
169 }
170 _ => {
171 parse_fail!(expr);
172 }
173 });
174 }
175 values.push(vs);
176 }
177 Ok(values)
178}
179
180impl TryFrom<Statement> for Insert {
181 type Error = ParserError;
182
183 fn try_from(value: Statement) -> std::result::Result<Self, Self::Error> {
184 match value {
185 Statement::Insert { .. } => Ok(Insert { inner: value }),
186 unexp => Err(ParserError::ParserError(format!(
187 "Not expected to be {unexp}"
188 ))),
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect and returns its first statement.
    fn parse_first(sql: &str) -> Statement {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap()
            .remove(0)
    }

    /// Parses `sql`, which must be an `INSERT`, and extracts its `VALUES` body.
    fn values_body_of(sql: &str) -> crate::error::Result<Vec<Vec<Value>>> {
        match parse_first(sql) {
            Statement::Insert(insert) => insert.values_body(),
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_insert_value_with_unary_op() {
        // A leading minus is folded into the number literal.
        let values = values_body_of("INSERT INTO my_table VALUES(-1)").unwrap();
        assert_eq!(values, vec![vec![Value::Number("-1".to_string(), false)]]);

        // A leading plus is dropped.
        let values = values_body_of("INSERT INTO my_table VALUES(+1)").unwrap();
        assert_eq!(values, vec![vec![Value::Number("1".to_string(), false)]]);
    }

    #[test]
    fn test_insert_value_with_default() {
        let values = values_body_of("INSERT INTO my_table VALUES(default)").unwrap();
        assert_eq!(values, vec![vec![Value::Placeholder("default".to_owned())]]);
    }

    #[test]
    fn test_insert_value_with_default_uppercase() {
        // The placeholder keeps the keyword exactly as written.
        let values = values_body_of("INSERT INTO my_table VALUES(DEFAULT)").unwrap();
        assert_eq!(values, vec![vec![Value::Placeholder("DEFAULT".to_owned())]]);
    }

    #[test]
    fn test_insert_value_with_quoted_string() {
        // Single- and double-quoted forms both yield a string value.
        for sql in [
            "INSERT INTO my_table VALUES('default')",
            "INSERT INTO my_table VALUES(\"default\")",
        ] {
            let values = values_body_of(sql).unwrap();
            assert_eq!(
                values,
                vec![vec![Value::SingleQuotedString("default".to_owned())]]
            );
        }

        // Backtick-quoted identifiers are rejected.
        assert!(values_body_of("INSERT INTO my_table VALUES(`default`)").is_err());
    }

    #[test]
    fn test_insert_select() {
        let sql = "INSERT INTO my_table select * from other_table";
        match parse_first(sql) {
            Statement::Insert(insert) => {
                let q = insert.query_body().unwrap().unwrap();
                assert!(matches!(
                    q.inner,
                    Query {
                        body: box SetExpr::Select { .. },
                        ..
                    }
                ));
            }
            _ => unreachable!(),
        }
    }
}