1use serde::Serialize;
16use sqlparser::ast::{
17 Insert as SpInsert, ObjectName, Query, SetExpr, Statement, TableObject, UnaryOperator, Values,
18};
19use sqlparser::parser::ParserError;
20use sqlparser_derive::{Visit, VisitMut};
21
22use crate::ast::{Expr, Value};
23use crate::error::{Result, UnsupportedSnafu};
24use crate::statements::query::Query as GtQuery;
25
26#[derive(Debug, Clone, PartialEq, Eq, Visit, VisitMut, Serialize)]
27pub struct Insert {
28 pub inner: Statement,
30}
31
/// Returns early from the enclosing function with a `ParseSqlValue` error whose
/// message is the `Debug` rendering of the given value.
///
/// The enclosing function must return `crate::error::Result<_>`.
macro_rules! parse_fail {
    ($offending:expr) => {
        return crate::error::ParseSqlValueSnafu {
            msg: format!("{:?}", $offending),
        }
        .fail();
    };
}
40
41impl Insert {
42 pub fn table_name(&self) -> Result<&ObjectName> {
43 match &self.inner {
44 Statement::Insert(insert) => {
45 let TableObject::TableName(name) = &insert.table else {
46 return UnsupportedSnafu {
47 keyword: "TABLE FUNCTION".to_string(),
48 }
49 .fail();
50 };
51 Ok(name)
52 }
53 _ => unreachable!(),
54 }
55 }
56
57 pub fn columns(&self) -> Vec<&String> {
58 match &self.inner {
59 Statement::Insert(insert) => insert.columns.iter().map(|ident| &ident.value).collect(),
60 _ => unreachable!(),
61 }
62 }
63
64 pub fn values_body(&self) -> Result<Vec<Vec<Value>>> {
66 match &self.inner {
67 Statement::Insert(SpInsert {
68 source:
69 Some(box Query {
70 body: box SetExpr::Values(Values { rows, .. }),
71 ..
72 }),
73 ..
74 }) => sql_exprs_to_values(rows),
75 _ => unreachable!(),
76 }
77 }
78
79 pub fn can_extract_values(&self) -> bool {
82 match &self.inner {
83 Statement::Insert(SpInsert {
84 source:
85 Some(box Query {
86 body: box SetExpr::Values(Values { rows, .. }),
87 ..
88 }),
89 ..
90 }) => rows.iter().all(|es| {
91 es.iter().all(|expr| match expr {
92 Expr::Value(_) => true,
93 Expr::Identifier(ident) => {
94 if ident.quote_style.is_none() {
95 ident.value.to_lowercase() == "default"
96 } else {
97 ident.quote_style == Some('"')
98 }
99 }
100 Expr::UnaryOp { op, expr } => {
101 matches!(op, UnaryOperator::Minus | UnaryOperator::Plus)
102 && matches!(&**expr, Expr::Value(Value::Number(_, _)))
103 }
104 _ => false,
105 })
106 }),
107 _ => false,
108 }
109 }
110
111 pub fn query_body(&self) -> Result<Option<GtQuery>> {
112 Ok(match &self.inner {
113 Statement::Insert(SpInsert {
114 source: Some(box query),
115 ..
116 }) => Some(query.clone().try_into()?),
117 _ => None,
118 })
119 }
120}
121
122fn sql_exprs_to_values(exprs: &[Vec<Expr>]) -> Result<Vec<Vec<Value>>> {
123 let mut values = Vec::with_capacity(exprs.len());
124 for es in exprs.iter() {
125 let mut vs = Vec::with_capacity(es.len());
126 for expr in es.iter() {
127 vs.push(match expr {
128 Expr::Value(v) => v.clone(),
129 Expr::Identifier(ident) => {
130 if ident.quote_style.is_none() {
131 if ident.value.to_lowercase() == "default" {
133 Value::Placeholder(ident.value.clone())
134 } else {
135 parse_fail!(expr);
136 }
137 } else {
138 if ident.quote_style == Some('"') {
140 Value::SingleQuotedString(ident.value.clone())
141 } else {
142 parse_fail!(expr);
143 }
144 }
145 }
146 Expr::UnaryOp { op, expr }
147 if matches!(op, UnaryOperator::Minus | UnaryOperator::Plus) =>
148 {
149 if let Expr::Value(Value::Number(s, b)) = &**expr {
150 match op {
151 UnaryOperator::Minus => Value::Number(format!("-{s}"), *b),
152 UnaryOperator::Plus => Value::Number(s.to_string(), *b),
153 _ => unreachable!(),
154 }
155 } else {
156 parse_fail!(expr);
157 }
158 }
159 _ => {
160 parse_fail!(expr);
161 }
162 });
163 }
164 values.push(vs);
165 }
166 Ok(values)
167}
168
169impl TryFrom<Statement> for Insert {
170 type Error = ParserError;
171
172 fn try_from(value: Statement) -> std::result::Result<Self, Self::Error> {
173 match value {
174 Statement::Insert { .. } => Ok(Insert { inner: value }),
175 unexp => Err(ParserError::ParserError(format!(
176 "Not expected to be {unexp}"
177 ))),
178 }
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::GreptimeDbDialect;
    use crate::parser::{ParseOptions, ParserContext};
    use crate::statements::statement::Statement;

    /// Parses `sql` with the GreptimeDB dialect and returns its first statement.
    fn parse_first(sql: &str) -> Statement {
        ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default())
            .unwrap()
            .remove(0)
    }

    /// Parses `sql`, which must be an `INSERT`, and returns its converted `VALUES` rows.
    fn values_of(sql: &str) -> Result<Vec<Vec<Value>>> {
        let Statement::Insert(insert) = parse_first(sql) else {
            unreachable!()
        };
        insert.values_body()
    }

    #[test]
    fn test_insert_value_with_unary_op() {
        let cases = [
            ("INSERT INTO my_table VALUES(-1)", "-1"),
            ("INSERT INTO my_table VALUES(+1)", "1"),
        ];
        for (sql, digits) in cases {
            assert_eq!(
                values_of(sql).unwrap(),
                vec![vec![Value::Number(digits.to_string(), false)]]
            );
        }
    }

    #[test]
    fn test_insert_value_with_default() {
        assert_eq!(
            values_of("INSERT INTO my_table VALUES(default)").unwrap(),
            vec![vec![Value::Placeholder("default".to_owned())]]
        );
    }

    #[test]
    fn test_insert_value_with_default_uppercase() {
        assert_eq!(
            values_of("INSERT INTO my_table VALUES(DEFAULT)").unwrap(),
            vec![vec![Value::Placeholder("DEFAULT".to_owned())]]
        );
    }

    #[test]
    fn test_insert_value_with_quoted_string() {
        let expected = vec![vec![Value::SingleQuotedString("default".to_owned())]];

        assert_eq!(
            values_of("INSERT INTO my_table VALUES('default')").unwrap(),
            expected
        );
        assert_eq!(
            values_of("INSERT INTO my_table VALUES(\"default\")").unwrap(),
            expected
        );
        // Backtick-quoted identifiers are not accepted as values.
        assert!(values_of("INSERT INTO my_table VALUES(`default`)").is_err());
    }

    #[test]
    fn test_insert_select() {
        let Statement::Insert(insert) =
            parse_first("INSERT INTO my_table select * from other_table")
        else {
            unreachable!()
        };
        let q = insert.query_body().unwrap().unwrap();
        assert!(matches!(
            q.inner,
            Query {
                body: box SetExpr::Select { .. },
                ..
            }
        ));
    }
}