1use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_sql::convert::sql_value_to_value;
21use common_time::{Date, Timestamp};
22use datatypes::prelude::ConcreteDataType;
23use datatypes::schema::ColumnSchema;
24use datatypes::types::TimestampType;
25use datatypes::value::{self, Value};
26#[cfg(test)]
27use itertools::Itertools;
28use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
29use snafu::ResultExt;
30use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
31use sql::statements::statement::Statement;
32
33use crate::error::{self, Result};
34
/// Location in the SQL text of a numbered placeholder (`$n`), copied from the span that
/// sqlparser attached to the placeholder's value expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct PlaceholderSpan {
    /// The placeholder number `n` parsed from `$n` (values built by `placeholder_spans`
    /// are always >= 1).
    pub(crate) index: usize,
    /// Line where the placeholder's span starts, as reported by sqlparser.
    pub(crate) start_line: u64,
    /// Column where the placeholder's span starts, as reported by sqlparser.
    pub(crate) start_column: u64,
    /// Line where the placeholder's span ends, as reported by sqlparser.
    pub(crate) end_line: u64,
    /// Column where the placeholder's span ends, as reported by sqlparser.
    pub(crate) end_column: u64,
}
45
/// Renders the `i`-th placeholder in numbered form, e.g. `format_placeholder(2)` yields `"$2"`.
pub fn format_placeholder(i: usize) -> String {
    format!("${i}")
}
50
#[cfg(test)]
/// Textually replaces every `?` in `query` with `$1`, `$2`, ... (left to right).
///
/// This is a naive split on `?` and does not skip string literals. Returns the rewritten
/// query and the next unused placeholder number (number of replacements + 1).
pub fn replace_placeholders(query: &str) -> (String, usize) {
    let segments: Vec<&str> = query.split('?').collect();
    // `split` always yields at least one segment, with exactly one `?` between neighbours.
    let placeholder_count = segments.len() - 1;
    let rewritten = segments
        .into_iter()
        .enumerate()
        .map(|(pos, segment)| {
            if pos < placeholder_count {
                format!("{segment}{}", format_placeholder(pos + 1))
            } else {
                segment.to_string()
            }
        })
        .join("");
    (rewritten, placeholder_count + 1)
}
73
74pub fn transform_placeholders_with_count(mut stmt: Statement) -> (Statement, usize) {
77 let count = visit_placeholders(&mut stmt);
78 (stmt, count)
79}
80
81pub(crate) fn placeholder_spans(mut stmt: Statement) -> Vec<PlaceholderSpan> {
83 let mut spans = Vec::new();
84 collect_placeholder_spans(&mut stmt, &mut spans);
85 spans
86}
87
88fn collect_placeholder_spans<V>(v: &mut V, spans: &mut Vec<PlaceholderSpan>)
89where
90 V: VisitMut,
91{
92 let _ = visit_expressions_mut(v, |expr| {
93 if let Expr::Value(ValueWithSpan {
94 value: ValueExpr::Placeholder(s),
95 span,
96 }) = expr
97 && let Some(index) = placeholder_index(s)
98 {
99 spans.push(PlaceholderSpan {
100 index,
101 start_line: span.start.line,
102 start_column: span.start.column,
103 end_line: span.end.line,
104 end_column: span.end.column,
105 });
106 }
107 ControlFlow::<()>::Continue(())
108 });
109}
110
/// Parses a numbered placeholder such as `$3` into its number.
///
/// Returns `None` unless `s` is `$` followed by text that `usize` parsing accepts and that
/// is non-zero (so `?`, `$`, `$0` and `$abc` all yield `None`).
fn placeholder_index(s: &str) -> Option<usize> {
    let digits = s.strip_prefix('$')?;
    match digits.parse::<usize>() {
        Ok(0) | Err(_) => None,
        Ok(index) => Some(index),
    }
}
117
118fn visit_placeholders<V>(v: &mut V) -> usize
119where
120 V: VisitMut,
121{
122 let mut index = 1;
123 let _ = visit_expressions_mut(v, |expr| {
124 if let Expr::Value(ValueWithSpan {
125 value: ValueExpr::Placeholder(s),
126 ..
127 }) = expr
128 && s == "?"
129 {
130 *s = format_placeholder(index);
131 index += 1;
132 }
133 ControlFlow::<()>::Continue(())
134 });
135 index - 1
136}
137
138pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
141 match param.value.into_inner() {
142 ValueInner::Int(i) => match t {
143 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
144 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
145 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
146 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
147 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
148 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
149 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
150 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
151 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
152 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
153 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
154 ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
155 .try_to_scalar_value(t)
156 .context(error::ConvertScalarValueSnafu),
157
158 _ => error::PreparedStmtTypeMismatchSnafu {
159 expected: t,
160 actual: param.coltype,
161 }
162 .fail(),
163 },
164 ValueInner::UInt(u) => match t {
165 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
166 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
167 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
168 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
169 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
170 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
171 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
172 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
173 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
174 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
175 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
176 ConcreteDataType::Timestamp(ts_type) => {
177 Value::Timestamp(ts_type.create_timestamp(u as i64))
178 .try_to_scalar_value(t)
179 .context(error::ConvertScalarValueSnafu)
180 }
181
182 _ => error::PreparedStmtTypeMismatchSnafu {
183 expected: t,
184 actual: param.coltype,
185 }
186 .fail(),
187 },
188 ValueInner::Double(f) => match t {
189 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
190 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
191 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
192 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
193 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
194 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
195 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
196 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
197 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
198 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
199
200 _ => error::PreparedStmtTypeMismatchSnafu {
201 expected: t,
202 actual: param.coltype,
203 }
204 .fail(),
205 },
206 ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
207 ValueInner::Bytes(b) => match t {
208 ConcreteDataType::String(t) => {
209 let s = String::from_utf8_lossy(b).to_string();
210 if t.is_large() {
211 Ok(ScalarValue::LargeUtf8(Some(s)))
212 } else {
213 Ok(ScalarValue::Utf8(Some(s)))
214 }
215 }
216 ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
217 ConcreteDataType::Timestamp(ts_type) => convert_bytes_to_timestamp(b, ts_type),
218 ConcreteDataType::Date(_) => convert_bytes_to_date(b),
219 _ => error::PreparedStmtTypeMismatchSnafu {
220 expected: t,
221 actual: param.coltype,
222 }
223 .fail(),
224 },
225 ValueInner::Date(_) => {
226 let date: common_time::Date = NaiveDate::from(param.value).into();
227 Ok(ScalarValue::Date32(Some(date.val())))
228 }
229 ValueInner::Datetime(_) => {
230 let timestamp_millis = to_naive_datetime(param.value)
231 .map_err(|e| {
232 error::MysqlValueConversionSnafu {
233 err_msg: e.to_string(),
234 }
235 .build()
236 })?
237 .and_utc()
238 .timestamp_millis();
239
240 match t {
241 ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
242 Some(timestamp_millis),
243 None,
244 )),
245 _ => error::PreparedStmtTypeMismatchSnafu {
246 expected: t,
247 actual: param.coltype,
248 }
249 .fail(),
250 }
251 }
252 ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
253 Duration::from(param.value).as_millis() as i64,
254 ))),
255 }
256}
257
258pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
261 let column_schema = ColumnSchema::new("", t.clone(), true);
262 match param {
263 Expr::Value(v) => {
264 let v = sql_value_to_value(&column_schema, &v.value, None, None, true);
265 match v {
266 Ok(v) => v
267 .try_to_scalar_value(t)
268 .context(error::ConvertScalarValueSnafu),
269 Err(e) => error::InvalidParameterSnafu {
270 reason: e.to_string(),
271 }
272 .fail(),
273 }
274 }
275 Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
276 let v = sql_value_to_value(&column_schema, &v.value, None, Some(*op), true);
277 match v {
278 Ok(v) => v
279 .try_to_scalar_value(t)
280 .context(error::ConvertScalarValueSnafu),
281 Err(e) => error::InvalidParameterSnafu {
282 reason: e.to_string(),
283 }
284 .fail(),
285 }
286 }
287 _ => error::InvalidParameterSnafu {
288 reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
289 }
290 .fail(),
291 }
292}
293
294fn convert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
295 let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
296 .map_err(|e| {
297 error::MysqlValueConversionSnafu {
298 err_msg: e.to_string(),
299 }
300 .build()
301 })?
302 .convert_to(ts_type.unit())
303 .ok_or_else(|| {
304 error::MysqlValueConversionSnafu {
305 err_msg: "Overflow when converting timestamp to target unit".to_string(),
306 }
307 .build()
308 })?;
309 match ts_type {
310 TimestampType::Nanosecond(_) => {
311 Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
312 }
313 TimestampType::Microsecond(_) => {
314 Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
315 }
316 TimestampType::Millisecond(_) => {
317 Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
318 }
319 TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
320 }
321}
322
323fn convert_bytes_to_date(bytes: &[u8]) -> Result<ScalarValue> {
324 let date = Date::from_str_utc(&String::from_utf8_lossy(bytes)).map_err(|e| {
325 error::MysqlValueConversionSnafu {
326 err_msg: e.to_string(),
327 }
328 .build()
329 })?;
330
331 Ok(ScalarValue::Date32(Some(date.val())))
332}
333
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    #[test]
    fn test_replace_placeholders() {
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!(
            "select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
            sql
        );
        assert_eq!(4, index);
    }

    /// Parses `sql` with the MySQL dialect and returns the first statement.
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_transform_placeholders() {
        let insert = parse_sql("insert into demo values(?,?,?)");
        let (stmt, count) = transform_placeholders_with_count(insert);
        let Statement::Insert(insert) = stmt else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );
        assert_eq!(3, count);

        let delete = parse_sql("delete from demo where host=? and idc=?");
        let (stmt, count) = transform_placeholders_with_count(delete);
        let Statement::Delete(delete) = stmt else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );
        assert_eq!(2, count);

        let select = parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        );
        let (stmt, count) = transform_placeholders_with_count(select);
        let Statement::Query(select) = stmt else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
        assert_eq!(3, count);

        // A `?` inside a string literal is not a placeholder and must be left alone.
        let select = parse_sql("select '?', ?");
        let (stmt, count) = transform_placeholders_with_count(select);
        let Statement::Query(select) = stmt else {
            unreachable!()
        };
        assert_eq!("SELECT '?', $1", select.inner.to_string());
        assert_eq!(1, count);

        let set = parse_sql("set time_zone = ?");
        let (stmt, count) = transform_placeholders_with_count(set);
        assert_eq!("SET time_zone = $1", stmt.to_string());
        assert_eq!(1, count);
    }

    #[test]
    fn test_placeholder_index() {
        assert_eq!(Some(1), placeholder_index("$1"));
        assert_eq!(Some(12), placeholder_index("$12"));
        // Numbering starts at 1, so `$0` is rejected.
        assert_eq!(None, placeholder_index("$0"));
        assert_eq!(None, placeholder_index("?"));
        assert_eq!(None, placeholder_index("$"));
        assert_eq!(None, placeholder_index("$abc"));
        assert_eq!(None, placeholder_index("1"));
    }

    #[test]
    fn test_placeholder_spans() {
        let select = parse_sql("select * from demo where host=? and cpu>?");
        let (stmt, _) = transform_placeholders_with_count(select);
        let indexes = placeholder_spans(stmt)
            .into_iter()
            .map(|span| span.index)
            .collect::<Vec<_>>();
        assert_eq!(vec![1, 2], indexes);

        // Un-numbered `?` placeholders are not reported.
        let select = parse_sql("select * from demo where host=?");
        assert!(placeholder_spans(select).is_empty());
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false).into());
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false).into());
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()).into());
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr =
            Expr::Value(ValueExpr::SingleQuotedString("2001-01-02 03:04:05".to_string()).into());
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()).into());
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        let expr = Expr::Value(ValueExpr::Null.into());
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        let test_cases = vec![
            (
                // Whole seconds, nanosecond target.
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            (
                // Whole seconds, microsecond target.
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            (
                // Whole seconds, millisecond target.
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            (
                // Whole seconds, second target.
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            (
                // Millisecond fraction, nanosecond target.
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            (
                // Millisecond fraction, microsecond target.
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            (
                // Millisecond fraction, millisecond target.
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                // Millisecond fraction truncated by a second target.
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            (
                // Microsecond fraction, nanosecond target.
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            (
                // Microsecond fraction, microsecond target.
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            (
                // Microsecond fraction truncated by a millisecond target.
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                // Microsecond fraction truncated by a second target.
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = convert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_convert_bytes_to_date() {
        let test_cases = vec![
            // The epoch and the day before it.
            ("1970-01-01", ScalarValue::Date32(Some(0))),
            ("1969-12-31", ScalarValue::Date32(Some(-1))),
            ("2024-02-29", ScalarValue::Date32(Some(19782))),
            ("2024-01-01", ScalarValue::Date32(Some(19723))),
            ("2024-12-31", ScalarValue::Date32(Some(20088))),
            ("2001-01-02", ScalarValue::Date32(Some(11324))),
            ("2050-06-14", ScalarValue::Date32(Some(29384))),
            ("2020-03-15", ScalarValue::Date32(Some(18336))),
        ];

        for (input, expected) in test_cases {
            let result = convert_bytes_to_date(input.as_bytes()).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }
}