1use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_sql::convert::sql_value_to_value;
21use common_time::{Date, Timestamp};
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_expr::LogicalPlan;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::types::TimestampType;
26use datatypes::value::{self, Value};
27use itertools::Itertools;
28use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
29use snafu::ResultExt;
30use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
31use sql::statements::statement::Statement;
32
33use crate::error::{self, DataFusionSnafu, Result};
34
/// Renders the `i`-th positional placeholder in `$N` form, e.g. `$1`.
pub fn format_placeholder(i: usize) -> String {
    format!("${i}")
}
39
40pub fn replace_placeholders(query: &str) -> (String, usize) {
43 let query_parts = query.split('?').collect::<Vec<_>>();
44 let parts_len = query_parts.len();
45 let mut index = 0;
46 let query = query_parts
47 .into_iter()
48 .enumerate()
49 .map(|(i, part)| {
50 if i == parts_len - 1 {
51 return part.to_string();
52 }
53
54 index += 1;
55 format!("{part}{}", format_placeholder(index))
56 })
57 .join("");
58
59 (query, index + 1)
60}
61
62pub fn transform_placeholders(stmt: Statement) -> Statement {
65 match stmt {
66 Statement::Query(mut query) => {
67 visit_placeholders(&mut query.inner);
68 Statement::Query(query)
69 }
70 Statement::Insert(mut insert) => {
71 visit_placeholders(&mut insert.inner);
72 Statement::Insert(insert)
73 }
74 Statement::Delete(mut delete) => {
75 visit_placeholders(&mut delete.inner);
76 Statement::Delete(delete)
77 }
78 stmt => stmt,
79 }
80}
81
82pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
86 let give_placeholder_types = |mut e: datafusion_expr::Expr| {
87 if let datafusion_expr::Expr::Cast(cast) = &mut e {
88 if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
89 if ph.data_type.is_none() {
90 ph.data_type = Some(cast.data_type.clone());
91 common_telemetry::debug!(
92 "give placeholder type {:?} to {:?}",
93 cast.data_type,
94 ph
95 );
96 Ok(Transformed::yes(e))
97 } else {
98 Ok(Transformed::no(e))
99 }
100 } else {
101 Ok(Transformed::no(e))
102 }
103 } else {
104 Ok(Transformed::no(e))
105 }
106 };
107 let give_placeholder_types_recursively =
108 |e: datafusion_expr::Expr| e.transform(give_placeholder_types);
109 *plan = std::mem::take(plan)
110 .transform(|p| p.map_expressions(give_placeholder_types_recursively))
111 .context(DataFusionSnafu)?
112 .data;
113 Ok(())
114}
115
116fn visit_placeholders<V>(v: &mut V)
117where
118 V: VisitMut,
119{
120 let mut index = 1;
121 let _ = visit_expressions_mut(v, |expr| {
122 if let Expr::Value(ValueWithSpan {
123 value: ValueExpr::Placeholder(s),
124 ..
125 }) = expr
126 {
127 *s = format_placeholder(index);
128 index += 1;
129 }
130 ControlFlow::<()>::Continue(())
131 });
132}
133
134pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
137 match param.value.into_inner() {
138 ValueInner::Int(i) => match t {
139 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
140 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
141 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
142 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
143 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
144 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
145 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
146 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
147 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
148 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
149 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
150 ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
151 .try_to_scalar_value(t)
152 .context(error::ConvertScalarValueSnafu),
153
154 _ => error::PreparedStmtTypeMismatchSnafu {
155 expected: t,
156 actual: param.coltype,
157 }
158 .fail(),
159 },
160 ValueInner::UInt(u) => match t {
161 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
162 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
163 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
164 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
165 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
166 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
167 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
168 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
169 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
170 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
171 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
172 ConcreteDataType::Timestamp(ts_type) => {
173 Value::Timestamp(ts_type.create_timestamp(u as i64))
174 .try_to_scalar_value(t)
175 .context(error::ConvertScalarValueSnafu)
176 }
177
178 _ => error::PreparedStmtTypeMismatchSnafu {
179 expected: t,
180 actual: param.coltype,
181 }
182 .fail(),
183 },
184 ValueInner::Double(f) => match t {
185 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
186 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
187 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
188 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
189 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
190 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
191 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
192 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
193 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
194 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
195
196 _ => error::PreparedStmtTypeMismatchSnafu {
197 expected: t,
198 actual: param.coltype,
199 }
200 .fail(),
201 },
202 ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
203 ValueInner::Bytes(b) => match t {
204 ConcreteDataType::String(t) => {
205 let s = String::from_utf8_lossy(b).to_string();
206 if t.is_large() {
207 Ok(ScalarValue::LargeUtf8(Some(s)))
208 } else {
209 Ok(ScalarValue::Utf8(Some(s)))
210 }
211 }
212 ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
213 ConcreteDataType::Timestamp(ts_type) => convert_bytes_to_timestamp(b, ts_type),
214 ConcreteDataType::Date(_) => convert_bytes_to_date(b),
215 _ => error::PreparedStmtTypeMismatchSnafu {
216 expected: t,
217 actual: param.coltype,
218 }
219 .fail(),
220 },
221 ValueInner::Date(_) => {
222 let date: common_time::Date = NaiveDate::from(param.value).into();
223 Ok(ScalarValue::Date32(Some(date.val())))
224 }
225 ValueInner::Datetime(_) => {
226 let timestamp_millis = to_naive_datetime(param.value)
227 .map_err(|e| {
228 error::MysqlValueConversionSnafu {
229 err_msg: e.to_string(),
230 }
231 .build()
232 })?
233 .and_utc()
234 .timestamp_millis();
235
236 match t {
237 ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
238 Some(timestamp_millis),
239 None,
240 )),
241 _ => error::PreparedStmtTypeMismatchSnafu {
242 expected: t,
243 actual: param.coltype,
244 }
245 .fail(),
246 }
247 }
248 ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
249 Duration::from(param.value).as_millis() as i64,
250 ))),
251 }
252}
253
254pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
257 match param {
258 Expr::Value(v) => {
259 let v = sql_value_to_value("", t, &v.value, None, None, true);
260 match v {
261 Ok(v) => v
262 .try_to_scalar_value(t)
263 .context(error::ConvertScalarValueSnafu),
264 Err(e) => error::InvalidParameterSnafu {
265 reason: e.to_string(),
266 }
267 .fail(),
268 }
269 }
270 Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
271 let v = sql_value_to_value("", t, &v.value, None, Some(*op), true);
272 match v {
273 Ok(v) => v
274 .try_to_scalar_value(t)
275 .context(error::ConvertScalarValueSnafu),
276 Err(e) => error::InvalidParameterSnafu {
277 reason: e.to_string(),
278 }
279 .fail(),
280 }
281 }
282 _ => error::InvalidParameterSnafu {
283 reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
284 }
285 .fail(),
286 }
287}
288
289fn convert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
290 let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
291 .map_err(|e| {
292 error::MysqlValueConversionSnafu {
293 err_msg: e.to_string(),
294 }
295 .build()
296 })?
297 .convert_to(ts_type.unit())
298 .ok_or_else(|| {
299 error::MysqlValueConversionSnafu {
300 err_msg: "Overflow when converting timestamp to target unit".to_string(),
301 }
302 .build()
303 })?;
304 match ts_type {
305 TimestampType::Nanosecond(_) => {
306 Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
307 }
308 TimestampType::Microsecond(_) => {
309 Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
310 }
311 TimestampType::Millisecond(_) => {
312 Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
313 }
314 TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
315 }
316}
317
318fn convert_bytes_to_date(bytes: &[u8]) -> Result<ScalarValue> {
319 let date = Date::from_str_utc(&String::from_utf8_lossy(bytes)).map_err(|e| {
320 error::MysqlValueConversionSnafu {
321 err_msg: e.to_string(),
322 }
323 .build()
324 })?;
325
326 Ok(ScalarValue::Date32(Some(date.val())))
327}
328
// Unit tests for the placeholder rewriting and value conversion helpers.
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    #[test]
    fn test_replace_placeholders() {
        // No `?` at all: the SQL is unchanged and the returned index is 1.
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        // Three placeholders: the returned index is one past the last placeholder.
        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        // Placeholders inside a subquery are numbered in textual order.
        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!(
            "select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
            sql
        );
        assert_eq!(4, index);
    }

    // Parses `sql` with the MySQL dialect and returns the first statement.
    // Panics if parsing fails.
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_transform_placeholders() {
        // INSERT statements get their placeholders renumbered.
        let insert = parse_sql("insert into demo values(?,?,?)");
        let Statement::Insert(insert) = transform_placeholders(insert) else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        // DELETE statements get their placeholders renumbered.
        let delete = parse_sql("delete from demo where host=? and idc=?");
        let Statement::Delete(delete) = transform_placeholders(delete) else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        // Queries, including placeholders nested in a subquery, are renumbered.
        let select = parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        );
        let Statement::Query(select) = transform_placeholders(select) else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        // Integer literal -> Int32.
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false).into());
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        // Decimal literal -> Float64.
        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false).into());
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        // String literal -> Date; the expectation is the same string cast to Date32.
        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()).into());
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        // String literal -> microsecond timestamp; the expectation is the same
        // string cast to a timezone-less microsecond timestamp.
        let expr =
            Expr::Value(ValueExpr::SingleQuotedString("2001-01-02 03:04:05".to_string()).into());
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        // String literal -> Utf8.
        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()).into());
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        // NULL literal -> typed null of the target type.
        let expr = Expr::Value(ValueExpr::Null.into());
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        // Each case is (input text, target timestamp type, expected scalar).
        // The same instant is checked against all four units, for inputs with
        // second, millisecond and microsecond precision.
        let test_cases = vec![
            // Input with second precision.
            (
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // Input with millisecond precision; the fraction is absent from the
            // expected value when the target unit is seconds.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // Input with microsecond precision; digits finer than the target
            // unit are absent from the expected value.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = convert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_convert_bytes_to_date() {
        // Each case is (input text, expected Date32 scalar); the expected
        // number is the day offset from 1970-01-01.
        let test_cases = vec![
            // The epoch itself and the day before it.
            ("1970-01-01", ScalarValue::Date32(Some(0))),
            ("1969-12-31", ScalarValue::Date32(Some(-1))),
            // Leap day plus the first and last day of the same year.
            ("2024-02-29", ScalarValue::Date32(Some(19782))),
            ("2024-01-01", ScalarValue::Date32(Some(19723))),
            ("2024-12-31", ScalarValue::Date32(Some(20088))),
            // Assorted other dates.
            ("2001-01-02", ScalarValue::Date32(Some(11324))),
            ("2050-06-14", ScalarValue::Date32(Some(29384))),
            ("2020-03-15", ScalarValue::Date32(Some(18336))),
        ];

        for (input, expected) in test_cases {
            let result = convert_bytes_to_date(input.as_bytes()).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }
}