1use std::ops::ControlFlow;
16use std::sync::Arc;
17use std::time::Duration;
18
19use arrow_schema::Field;
20use chrono::NaiveDate;
21use common_query::prelude::ScalarValue;
22use common_sql::convert::sql_value_to_value;
23use common_time::{Date, Timestamp};
24use datafusion_common::tree_node::{Transformed, TreeNode};
25use datafusion_expr::LogicalPlan;
26use datatypes::prelude::ConcreteDataType;
27use datatypes::schema::ColumnSchema;
28use datatypes::types::TimestampType;
29use datatypes::value::{self, Value};
30use itertools::Itertools;
31use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
32use snafu::ResultExt;
33use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
34use sql::statements::statement::Statement;
35
36use crate::error::{self, DataFusionSnafu, Result};
37
/// Renders the placeholder with ordinal `i` in the `$i` notation
/// (for example `$1`, `$2`).
pub fn format_placeholder(i: usize) -> String {
    format!("${i}")
}
42
43pub fn replace_placeholders(query: &str) -> (String, usize) {
46 let query_parts = query.split('?').collect::<Vec<_>>();
47 let parts_len = query_parts.len();
48 let mut index = 0;
49 let query = query_parts
50 .into_iter()
51 .enumerate()
52 .map(|(i, part)| {
53 if i == parts_len - 1 {
54 return part.to_string();
55 }
56
57 index += 1;
58 format!("{part}{}", format_placeholder(index))
59 })
60 .join("");
61
62 (query, index + 1)
63}
64
65pub fn transform_placeholders(stmt: Statement) -> Statement {
68 match stmt {
69 Statement::Query(mut query) => {
70 visit_placeholders(&mut query.inner);
71 Statement::Query(query)
72 }
73 Statement::Insert(mut insert) => {
74 visit_placeholders(&mut insert.inner);
75 Statement::Insert(insert)
76 }
77 Statement::Delete(mut delete) => {
78 visit_placeholders(&mut delete.inner);
79 Statement::Delete(delete)
80 }
81 stmt => stmt,
82 }
83}
84
85pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
89 let give_placeholder_types = |mut e: datafusion_expr::Expr| {
90 if let datafusion_expr::Expr::Cast(cast) = &mut e {
91 if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
92 if ph.field.is_none() {
93 ph.field = Some(Arc::new(Field::new("", cast.data_type.clone(), true)));
94 common_telemetry::debug!(
95 "give placeholder type {:?} to {:?}",
96 cast.data_type,
97 ph
98 );
99 Ok(Transformed::yes(e))
100 } else {
101 Ok(Transformed::no(e))
102 }
103 } else {
104 Ok(Transformed::no(e))
105 }
106 } else {
107 Ok(Transformed::no(e))
108 }
109 };
110 let give_placeholder_types_recursively =
111 |e: datafusion_expr::Expr| e.transform(give_placeholder_types);
112 *plan = std::mem::take(plan)
113 .transform(|p| p.map_expressions(give_placeholder_types_recursively))
114 .context(DataFusionSnafu)?
115 .data;
116 Ok(())
117}
118
119fn visit_placeholders<V>(v: &mut V)
120where
121 V: VisitMut,
122{
123 let mut index = 1;
124 let _ = visit_expressions_mut(v, |expr| {
125 if let Expr::Value(ValueWithSpan {
126 value: ValueExpr::Placeholder(s),
127 ..
128 }) = expr
129 {
130 *s = format_placeholder(index);
131 index += 1;
132 }
133 ControlFlow::<()>::Continue(())
134 });
135}
136
137pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
140 match param.value.into_inner() {
141 ValueInner::Int(i) => match t {
142 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
143 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
144 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
145 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
146 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
147 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
148 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
149 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
150 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
151 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
152 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
153 ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
154 .try_to_scalar_value(t)
155 .context(error::ConvertScalarValueSnafu),
156
157 _ => error::PreparedStmtTypeMismatchSnafu {
158 expected: t,
159 actual: param.coltype,
160 }
161 .fail(),
162 },
163 ValueInner::UInt(u) => match t {
164 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
165 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
166 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
167 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
168 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
169 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
170 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
171 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
172 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
173 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
174 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
175 ConcreteDataType::Timestamp(ts_type) => {
176 Value::Timestamp(ts_type.create_timestamp(u as i64))
177 .try_to_scalar_value(t)
178 .context(error::ConvertScalarValueSnafu)
179 }
180
181 _ => error::PreparedStmtTypeMismatchSnafu {
182 expected: t,
183 actual: param.coltype,
184 }
185 .fail(),
186 },
187 ValueInner::Double(f) => match t {
188 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
189 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
190 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
191 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
192 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
193 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
194 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
195 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
196 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
197 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
198
199 _ => error::PreparedStmtTypeMismatchSnafu {
200 expected: t,
201 actual: param.coltype,
202 }
203 .fail(),
204 },
205 ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
206 ValueInner::Bytes(b) => match t {
207 ConcreteDataType::String(t) => {
208 let s = String::from_utf8_lossy(b).to_string();
209 if t.is_large() {
210 Ok(ScalarValue::LargeUtf8(Some(s)))
211 } else {
212 Ok(ScalarValue::Utf8(Some(s)))
213 }
214 }
215 ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
216 ConcreteDataType::Timestamp(ts_type) => convert_bytes_to_timestamp(b, ts_type),
217 ConcreteDataType::Date(_) => convert_bytes_to_date(b),
218 _ => error::PreparedStmtTypeMismatchSnafu {
219 expected: t,
220 actual: param.coltype,
221 }
222 .fail(),
223 },
224 ValueInner::Date(_) => {
225 let date: common_time::Date = NaiveDate::from(param.value).into();
226 Ok(ScalarValue::Date32(Some(date.val())))
227 }
228 ValueInner::Datetime(_) => {
229 let timestamp_millis = to_naive_datetime(param.value)
230 .map_err(|e| {
231 error::MysqlValueConversionSnafu {
232 err_msg: e.to_string(),
233 }
234 .build()
235 })?
236 .and_utc()
237 .timestamp_millis();
238
239 match t {
240 ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
241 Some(timestamp_millis),
242 None,
243 )),
244 _ => error::PreparedStmtTypeMismatchSnafu {
245 expected: t,
246 actual: param.coltype,
247 }
248 .fail(),
249 }
250 }
251 ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
252 Duration::from(param.value).as_millis() as i64,
253 ))),
254 }
255}
256
257pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
260 let column_schema = ColumnSchema::new("", t.clone(), true);
261 match param {
262 Expr::Value(v) => {
263 let v = sql_value_to_value(&column_schema, &v.value, None, None, true);
264 match v {
265 Ok(v) => v
266 .try_to_scalar_value(t)
267 .context(error::ConvertScalarValueSnafu),
268 Err(e) => error::InvalidParameterSnafu {
269 reason: e.to_string(),
270 }
271 .fail(),
272 }
273 }
274 Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
275 let v = sql_value_to_value(&column_schema, &v.value, None, Some(*op), true);
276 match v {
277 Ok(v) => v
278 .try_to_scalar_value(t)
279 .context(error::ConvertScalarValueSnafu),
280 Err(e) => error::InvalidParameterSnafu {
281 reason: e.to_string(),
282 }
283 .fail(),
284 }
285 }
286 _ => error::InvalidParameterSnafu {
287 reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
288 }
289 .fail(),
290 }
291}
292
293fn convert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
294 let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
295 .map_err(|e| {
296 error::MysqlValueConversionSnafu {
297 err_msg: e.to_string(),
298 }
299 .build()
300 })?
301 .convert_to(ts_type.unit())
302 .ok_or_else(|| {
303 error::MysqlValueConversionSnafu {
304 err_msg: "Overflow when converting timestamp to target unit".to_string(),
305 }
306 .build()
307 })?;
308 match ts_type {
309 TimestampType::Nanosecond(_) => {
310 Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
311 }
312 TimestampType::Microsecond(_) => {
313 Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
314 }
315 TimestampType::Millisecond(_) => {
316 Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
317 }
318 TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
319 }
320}
321
322fn convert_bytes_to_date(bytes: &[u8]) -> Result<ScalarValue> {
323 let date = Date::from_str_utc(&String::from_utf8_lossy(bytes)).map_err(|e| {
324 error::MysqlValueConversionSnafu {
325 err_msg: e.to_string(),
326 }
327 .build()
328 })?;
329
330 Ok(ScalarValue::Date32(Some(date.val())))
331}
332
/// Unit tests for the placeholder rewriting and value conversion helpers.
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    #[test]
    fn test_replace_placeholders() {
        // No `?` at all: the SQL is returned unchanged and the index is 1.
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        // Three `?` become `$1..$3`; the returned index is one past the last.
        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        // Placeholders inside a subquery are numbered in textual order.
        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!(
            "select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
            sql
        );
        assert_eq!(4, index);
    }

    /// Parses `sql` with the MySQL dialect and returns the first statement.
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_transform_placeholders() {
        // INSERT
        let insert = parse_sql("insert into demo values(?,?,?)");
        let Statement::Insert(insert) = transform_placeholders(insert) else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        // DELETE
        let delete = parse_sql("delete from demo where host=? and idc=?");
        let Statement::Delete(delete) = transform_placeholders(delete) else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        // SELECT with a nested subquery
        let select = parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        );
        let Statement::Query(select) = transform_placeholders(select) else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        // Integer literal -> Int32
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false).into());
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        // Decimal literal -> Float64
        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false).into());
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        // Quoted date string -> Date32 (compared against an Arrow cast of the same text)
        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()).into());
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        // Quoted datetime string -> microsecond timestamp
        let expr =
            Expr::Value(ValueExpr::SingleQuotedString("2001-01-02 03:04:05".to_string()).into());
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        // Quoted string -> Utf8
        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()).into());
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        // NULL -> typed null of the target type
        let expr = Expr::Value(ValueExpr::Null.into());
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        // Each case: (input text, target timestamp type, expected scalar).
        let test_cases = vec![
            // Whole seconds, converted to every unit.
            (
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            (
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // Millisecond fraction; the second-unit target drops the fraction.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // Microsecond fraction; coarser targets drop the extra digits.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = convert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_convert_bytes_to_date() {
        // Each case: (input text, expected `Date32` as days relative to 1970-01-01).
        let test_cases = vec![
            // The epoch itself and the day before it.
            ("1970-01-01", ScalarValue::Date32(Some(0))),
            ("1969-12-31", ScalarValue::Date32(Some(-1))),
            // Leap day.
            ("2024-02-29", ScalarValue::Date32(Some(19782))),
            ("2024-01-01", ScalarValue::Date32(Some(19723))),
            ("2024-12-31", ScalarValue::Date32(Some(20088))),
            ("2001-01-02", ScalarValue::Date32(Some(11324))),
            ("2050-06-14", ScalarValue::Date32(Some(29384))),
            ("2020-03-15", ScalarValue::Date32(Some(18336))),
        ];

        for (input, expected) in test_cases {
            let result = convert_bytes_to_date(input.as_bytes()).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }
}