1use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_sql::convert::sql_value_to_value;
21use common_time::{Date, Timestamp};
22use datatypes::prelude::ConcreteDataType;
23use datatypes::schema::ColumnSchema;
24use datatypes::types::TimestampType;
25use datatypes::value::{self, Value};
26use itertools::Itertools;
27use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
28use snafu::ResultExt;
29use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
30use sql::statements::statement::Statement;
31
32use crate::error::{self, Result};
33
/// Renders the positional placeholder for parameter number `i`, e.g. `3` → `"$3"`.
///
/// This is the form every MySQL `?` placeholder is rewritten into before the
/// statement is parsed and bound.
pub fn format_placeholder(i: usize) -> String {
    format!("${i}")
}
38
39pub fn replace_placeholders(query: &str) -> (String, usize) {
42 let query_parts = query.split('?').collect::<Vec<_>>();
43 let parts_len = query_parts.len();
44 let mut index = 0;
45 let query = query_parts
46 .into_iter()
47 .enumerate()
48 .map(|(i, part)| {
49 if i == parts_len - 1 {
50 return part.to_string();
51 }
52
53 index += 1;
54 format!("{part}{}", format_placeholder(index))
55 })
56 .join("");
57
58 (query, index + 1)
59}
60
61pub fn transform_placeholders(stmt: Statement) -> Statement {
64 match stmt {
65 Statement::Query(mut query) => {
66 visit_placeholders(&mut query.inner);
67 Statement::Query(query)
68 }
69 Statement::Insert(mut insert) => {
70 visit_placeholders(&mut insert.inner);
71 Statement::Insert(insert)
72 }
73 Statement::Delete(mut delete) => {
74 visit_placeholders(&mut delete.inner);
75 Statement::Delete(delete)
76 }
77 stmt => stmt,
78 }
79}
80
81fn visit_placeholders<V>(v: &mut V)
82where
83 V: VisitMut,
84{
85 let mut index = 1;
86 let _ = visit_expressions_mut(v, |expr| {
87 if let Expr::Value(ValueWithSpan {
88 value: ValueExpr::Placeholder(s),
89 ..
90 }) = expr
91 {
92 *s = format_placeholder(index);
93 index += 1;
94 }
95 ControlFlow::<()>::Continue(())
96 });
97}
98
99pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
102 match param.value.into_inner() {
103 ValueInner::Int(i) => match t {
104 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
105 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
106 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
107 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
108 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
109 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
110 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
111 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
112 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
113 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
114 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
115 ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
116 .try_to_scalar_value(t)
117 .context(error::ConvertScalarValueSnafu),
118
119 _ => error::PreparedStmtTypeMismatchSnafu {
120 expected: t,
121 actual: param.coltype,
122 }
123 .fail(),
124 },
125 ValueInner::UInt(u) => match t {
126 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
127 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
128 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
129 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
130 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
131 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
132 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
133 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
134 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
135 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
136 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
137 ConcreteDataType::Timestamp(ts_type) => {
138 Value::Timestamp(ts_type.create_timestamp(u as i64))
139 .try_to_scalar_value(t)
140 .context(error::ConvertScalarValueSnafu)
141 }
142
143 _ => error::PreparedStmtTypeMismatchSnafu {
144 expected: t,
145 actual: param.coltype,
146 }
147 .fail(),
148 },
149 ValueInner::Double(f) => match t {
150 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
151 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
152 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
153 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
154 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
155 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
156 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
157 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
158 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
159 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
160
161 _ => error::PreparedStmtTypeMismatchSnafu {
162 expected: t,
163 actual: param.coltype,
164 }
165 .fail(),
166 },
167 ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
168 ValueInner::Bytes(b) => match t {
169 ConcreteDataType::String(t) => {
170 let s = String::from_utf8_lossy(b).to_string();
171 if t.is_large() {
172 Ok(ScalarValue::LargeUtf8(Some(s)))
173 } else {
174 Ok(ScalarValue::Utf8(Some(s)))
175 }
176 }
177 ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
178 ConcreteDataType::Timestamp(ts_type) => convert_bytes_to_timestamp(b, ts_type),
179 ConcreteDataType::Date(_) => convert_bytes_to_date(b),
180 _ => error::PreparedStmtTypeMismatchSnafu {
181 expected: t,
182 actual: param.coltype,
183 }
184 .fail(),
185 },
186 ValueInner::Date(_) => {
187 let date: common_time::Date = NaiveDate::from(param.value).into();
188 Ok(ScalarValue::Date32(Some(date.val())))
189 }
190 ValueInner::Datetime(_) => {
191 let timestamp_millis = to_naive_datetime(param.value)
192 .map_err(|e| {
193 error::MysqlValueConversionSnafu {
194 err_msg: e.to_string(),
195 }
196 .build()
197 })?
198 .and_utc()
199 .timestamp_millis();
200
201 match t {
202 ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
203 Some(timestamp_millis),
204 None,
205 )),
206 _ => error::PreparedStmtTypeMismatchSnafu {
207 expected: t,
208 actual: param.coltype,
209 }
210 .fail(),
211 }
212 }
213 ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
214 Duration::from(param.value).as_millis() as i64,
215 ))),
216 }
217}
218
219pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
222 let column_schema = ColumnSchema::new("", t.clone(), true);
223 match param {
224 Expr::Value(v) => {
225 let v = sql_value_to_value(&column_schema, &v.value, None, None, true);
226 match v {
227 Ok(v) => v
228 .try_to_scalar_value(t)
229 .context(error::ConvertScalarValueSnafu),
230 Err(e) => error::InvalidParameterSnafu {
231 reason: e.to_string(),
232 }
233 .fail(),
234 }
235 }
236 Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
237 let v = sql_value_to_value(&column_schema, &v.value, None, Some(*op), true);
238 match v {
239 Ok(v) => v
240 .try_to_scalar_value(t)
241 .context(error::ConvertScalarValueSnafu),
242 Err(e) => error::InvalidParameterSnafu {
243 reason: e.to_string(),
244 }
245 .fail(),
246 }
247 }
248 _ => error::InvalidParameterSnafu {
249 reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
250 }
251 .fail(),
252 }
253}
254
255fn convert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
256 let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
257 .map_err(|e| {
258 error::MysqlValueConversionSnafu {
259 err_msg: e.to_string(),
260 }
261 .build()
262 })?
263 .convert_to(ts_type.unit())
264 .ok_or_else(|| {
265 error::MysqlValueConversionSnafu {
266 err_msg: "Overflow when converting timestamp to target unit".to_string(),
267 }
268 .build()
269 })?;
270 match ts_type {
271 TimestampType::Nanosecond(_) => {
272 Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
273 }
274 TimestampType::Microsecond(_) => {
275 Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
276 }
277 TimestampType::Millisecond(_) => {
278 Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
279 }
280 TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
281 }
282}
283
284fn convert_bytes_to_date(bytes: &[u8]) -> Result<ScalarValue> {
285 let date = Date::from_str_utc(&String::from_utf8_lossy(bytes)).map_err(|e| {
286 error::MysqlValueConversionSnafu {
287 err_msg: e.to_string(),
288 }
289 .build()
290 })?;
291
292 Ok(ScalarValue::Date32(Some(date.val())))
293}
294
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    #[test]
    fn test_replace_placeholders() {
        // Without any `?` the query is unchanged and the returned index is 1.
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        // Three placeholders become $1..$3; the returned index is count + 1.
        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        // Placeholders inside a subquery are numbered in textual order.
        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!(
            "select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
            sql
        );
        assert_eq!(4, index);
    }

    /// Parses `sql` with the MySQL dialect and returns the first statement.
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_transform_placeholders() {
        // INSERT, DELETE and SELECT statements all get `?` rewritten to `$N`.
        let insert = parse_sql("insert into demo values(?,?,?)");
        let Statement::Insert(insert) = transform_placeholders(insert) else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        let delete = parse_sql("delete from demo where host=? and idc=?");
        let Statement::Delete(delete) = transform_placeholders(delete) else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        let select = parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        );
        let Statement::Query(select) = transform_placeholders(select) else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        // Integer literal into an Int32 column.
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false).into());
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        // Decimal literal into a Float64 column.
        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false).into());
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        // String literal into a Date column; expected value computed via an Arrow cast.
        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()).into());
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        // String literal into a microsecond timestamp column.
        let expr =
            Expr::Value(ValueExpr::SingleQuotedString("2001-01-02 03:04:05".to_string()).into());
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        // String literal into a string column.
        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()).into());
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        // NULL becomes a typed null of the target column type.
        let expr = Expr::Value(ValueExpr::Null.into());
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        // (input text, target unit, expected scalar). 2024-12-26 12:00:00 UTC is
        // epoch second 1735214400. Extra fractional digits beyond the target
        // unit are truncated.
        let test_cases = vec![
            // Whole seconds, every unit.
            (
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // Millisecond fraction, every unit.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // Microsecond fraction, every unit.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = convert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_convert_bytes_to_date() {
        // (input text, expected Date32 day count relative to 1970-01-01).
        let test_cases = vec![
            ("1970-01-01", ScalarValue::Date32(Some(0))),
            // Dates before the epoch are negative.
            ("1969-12-31", ScalarValue::Date32(Some(-1))),
            // Leap day.
            ("2024-02-29", ScalarValue::Date32(Some(19782))),
            ("2024-01-01", ScalarValue::Date32(Some(19723))),
            ("2024-12-31", ScalarValue::Date32(Some(20088))),
            ("2001-01-02", ScalarValue::Date32(Some(11324))),
            ("2050-06-14", ScalarValue::Date32(Some(29384))),
            ("2020-03-15", ScalarValue::Date32(Some(18336))),
        ];

        for (input, expected) in test_cases {
            let result = convert_bytes_to_date(input.as_bytes()).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }
}