1use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_sql::convert::sql_value_to_value;
21use common_time::Timestamp;
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_expr::LogicalPlan;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::types::TimestampType;
26use datatypes::value::{self, Value};
27use itertools::Itertools;
28use opensrv_mysql::{to_naive_datetime, ParamValue, ValueInner};
29use snafu::ResultExt;
30use sql::ast::{visit_expressions_mut, Expr, Value as ValueExpr, ValueWithSpan, VisitMut};
31use sql::statements::statement::Statement;
32
33use crate::error::{self, DataFusionSnafu, Result};
34
/// Returns the positional placeholder text for index `i`, i.e. `"$"` followed
/// by the decimal index (`format_placeholder(3)` yields `"$3"`).
pub fn format_placeholder(i: usize) -> String {
    format!("${i}")
}
39
40pub fn replace_placeholders(query: &str) -> (String, usize) {
43 let query_parts = query.split('?').collect::<Vec<_>>();
44 let parts_len = query_parts.len();
45 let mut index = 0;
46 let query = query_parts
47 .into_iter()
48 .enumerate()
49 .map(|(i, part)| {
50 if i == parts_len - 1 {
51 return part.to_string();
52 }
53
54 index += 1;
55 format!("{part}{}", format_placeholder(index))
56 })
57 .join("");
58
59 (query, index + 1)
60}
61
62pub fn transform_placeholders(stmt: Statement) -> Statement {
65 match stmt {
66 Statement::Query(mut query) => {
67 visit_placeholders(&mut query.inner);
68 Statement::Query(query)
69 }
70 Statement::Insert(mut insert) => {
71 visit_placeholders(&mut insert.inner);
72 Statement::Insert(insert)
73 }
74 Statement::Delete(mut delete) => {
75 visit_placeholders(&mut delete.inner);
76 Statement::Delete(delete)
77 }
78 stmt => stmt,
79 }
80}
81
82pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
86 let give_placeholder_types = |mut e: datafusion_expr::Expr| {
87 if let datafusion_expr::Expr::Cast(cast) = &mut e {
88 if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
89 if ph.data_type.is_none() {
90 ph.data_type = Some(cast.data_type.clone());
91 common_telemetry::debug!(
92 "give placeholder type {:?} to {:?}",
93 cast.data_type,
94 ph
95 );
96 Ok(Transformed::yes(e))
97 } else {
98 Ok(Transformed::no(e))
99 }
100 } else {
101 Ok(Transformed::no(e))
102 }
103 } else {
104 Ok(Transformed::no(e))
105 }
106 };
107 let give_placeholder_types_recursively =
108 |e: datafusion_expr::Expr| e.transform(give_placeholder_types);
109 *plan = std::mem::take(plan)
110 .transform(|p| p.map_expressions(give_placeholder_types_recursively))
111 .context(DataFusionSnafu)?
112 .data;
113 Ok(())
114}
115
116fn visit_placeholders<V>(v: &mut V)
117where
118 V: VisitMut,
119{
120 let mut index = 1;
121 let _ = visit_expressions_mut(v, |expr| {
122 if let Expr::Value(ValueWithSpan {
123 value: ValueExpr::Placeholder(s),
124 ..
125 }) = expr
126 {
127 *s = format_placeholder(index);
128 index += 1;
129 }
130 ControlFlow::<()>::Continue(())
131 });
132}
133
134pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
137 match param.value.into_inner() {
138 ValueInner::Int(i) => match t {
139 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
140 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
141 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
142 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
143 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
144 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
145 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
146 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
147 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
148 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
149 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
150 ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
151 .try_to_scalar_value(t)
152 .context(error::ConvertScalarValueSnafu),
153
154 _ => error::PreparedStmtTypeMismatchSnafu {
155 expected: t,
156 actual: param.coltype,
157 }
158 .fail(),
159 },
160 ValueInner::UInt(u) => match t {
161 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
162 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
163 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
164 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
165 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
166 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
167 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
168 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
169 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
170 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
171 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
172 ConcreteDataType::Timestamp(ts_type) => {
173 Value::Timestamp(ts_type.create_timestamp(u as i64))
174 .try_to_scalar_value(t)
175 .context(error::ConvertScalarValueSnafu)
176 }
177
178 _ => error::PreparedStmtTypeMismatchSnafu {
179 expected: t,
180 actual: param.coltype,
181 }
182 .fail(),
183 },
184 ValueInner::Double(f) => match t {
185 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
186 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
187 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
188 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
189 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
190 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
191 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
192 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
193 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
194 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
195
196 _ => error::PreparedStmtTypeMismatchSnafu {
197 expected: t,
198 actual: param.coltype,
199 }
200 .fail(),
201 },
202 ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
203 ValueInner::Bytes(b) => match t {
204 ConcreteDataType::String(_) => Ok(ScalarValue::Utf8(Some(
205 String::from_utf8_lossy(b).to_string(),
206 ))),
207 ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
208 ConcreteDataType::Timestamp(ts_type) => covert_bytes_to_timestamp(b, ts_type),
209 _ => error::PreparedStmtTypeMismatchSnafu {
210 expected: t,
211 actual: param.coltype,
212 }
213 .fail(),
214 },
215 ValueInner::Date(_) => {
216 let date: common_time::Date = NaiveDate::from(param.value).into();
217 Ok(ScalarValue::Date32(Some(date.val())))
218 }
219 ValueInner::Datetime(_) => {
220 let timestamp_millis = to_naive_datetime(param.value)
221 .map_err(|e| {
222 error::MysqlValueConversionSnafu {
223 err_msg: e.to_string(),
224 }
225 .build()
226 })?
227 .and_utc()
228 .timestamp_millis();
229
230 match t {
231 ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
232 Some(timestamp_millis),
233 None,
234 )),
235 _ => error::PreparedStmtTypeMismatchSnafu {
236 expected: t,
237 actual: param.coltype,
238 }
239 .fail(),
240 }
241 }
242 ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
243 Duration::from(param.value).as_millis() as i64,
244 ))),
245 }
246}
247
248pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
251 match param {
252 Expr::Value(v) => {
253 let v = sql_value_to_value("", t, &v.value, None, None, true);
254 match v {
255 Ok(v) => v
256 .try_to_scalar_value(t)
257 .context(error::ConvertScalarValueSnafu),
258 Err(e) => error::InvalidParameterSnafu {
259 reason: e.to_string(),
260 }
261 .fail(),
262 }
263 }
264 Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
265 let v = sql_value_to_value("", t, &v.value, None, Some(*op), true);
266 match v {
267 Ok(v) => v
268 .try_to_scalar_value(t)
269 .context(error::ConvertScalarValueSnafu),
270 Err(e) => error::InvalidParameterSnafu {
271 reason: e.to_string(),
272 }
273 .fail(),
274 }
275 }
276 _ => error::InvalidParameterSnafu {
277 reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
278 }
279 .fail(),
280 }
281}
282
283fn covert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
284 let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
285 .map_err(|e| {
286 error::MysqlValueConversionSnafu {
287 err_msg: e.to_string(),
288 }
289 .build()
290 })?
291 .convert_to(ts_type.unit())
292 .ok_or_else(|| {
293 error::MysqlValueConversionSnafu {
294 err_msg: "Overflow when converting timestamp to target unit".to_string(),
295 }
296 .build()
297 })?;
298 match ts_type {
299 TimestampType::Nanosecond(_) => {
300 Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
301 }
302 TimestampType::Microsecond(_) => {
303 Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
304 }
305 TimestampType::Millisecond(_) => {
306 Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
307 }
308 TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
309 }
310}
311
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    /// Parses `sql` with the MySQL dialect and returns the first statement.
    fn parse_sql(sql: &str) -> Statement {
        ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
            .unwrap()
            .remove(0)
    }

    /// Builds an unsigned numeric literal expression.
    fn number(text: &str) -> Expr {
        Expr::Value(ValueExpr::Number(text.to_string(), false).into())
    }

    /// Builds a single-quoted string literal expression.
    fn quoted(text: &str) -> Expr {
        Expr::Value(ValueExpr::SingleQuotedString(text.to_string()).into())
    }

    #[test]
    fn test_format_placeholder() {
        for (index, expected) in [(1, "$1"), (3, "$3")] {
            assert_eq!(expected, format_placeholder(index));
        }
    }

    #[test]
    fn test_replace_placeholders() {
        // (input SQL, expected rewritten SQL, expected returned index)
        let cases = [
            (
                "create table demo(host string, ts timestamp time index)",
                "create table demo(host string, ts timestamp time index)",
                1,
            ),
            (
                "insert into demo values(?,?,?)",
                "insert into demo values($1,$2,$3)",
                4,
            ),
            (
                "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
                "select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
                4,
            ),
        ];

        for (input, expected_sql, expected_index) in cases {
            let (sql, index) = replace_placeholders(input);
            assert_eq!(expected_sql, sql);
            assert_eq!(expected_index, index);
        }
    }

    #[test]
    fn test_transform_placeholders() {
        let Statement::Insert(insert) =
            transform_placeholders(parse_sql("insert into demo values(?,?,?)"))
        else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        let Statement::Delete(delete) =
            transform_placeholders(parse_sql("delete from demo where host=? and idc=?"))
        else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        let Statement::Query(select) = transform_placeholders(parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        )) else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        // Integer literal.
        let converted =
            convert_expr_to_scalar_value(&number("123"), &ConcreteDataType::int32_datatype())
                .unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), converted);

        // Floating point literal.
        let converted = convert_expr_to_scalar_value(
            &number("123.456789"),
            &ConcreteDataType::float64_datatype(),
        )
        .unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), converted);

        // Date given as a string literal.
        let converted =
            convert_expr_to_scalar_value(&quoted("2001-01-02"), &ConcreteDataType::date_datatype())
                .unwrap();
        let expected = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(expected, converted);

        // Timestamp given as a string literal.
        let converted = convert_expr_to_scalar_value(
            &quoted("2001-01-02 03:04:05"),
            &ConcreteDataType::timestamp_microsecond_datatype(),
        )
        .unwrap();
        let expected = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(expected, converted);

        // Plain string literal.
        let converted =
            convert_expr_to_scalar_value(&quoted("hello"), &ConcreteDataType::string_datatype())
                .unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), converted);

        // NULL literal keeps the target type.
        let null_expr = Expr::Value(ValueExpr::Null.into());
        let converted = convert_expr_to_scalar_value(
            &null_expr,
            &ConcreteDataType::time_microsecond_datatype(),
        )
        .unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), converted);
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        // (input text, expected [nanoseconds, microseconds, milliseconds, seconds])
        let cases = [
            (
                "2024-12-26 12:00:00",
                [
                    1735214400000000000,
                    1735214400000000,
                    1735214400000,
                    1735214400,
                ],
            ),
            (
                "2024-12-26 12:00:00.123",
                [
                    1735214400123000000,
                    1735214400123000,
                    1735214400123,
                    1735214400,
                ],
            ),
            (
                "2024-12-26 12:00:00.123456",
                [
                    1735214400123456000,
                    1735214400123456,
                    1735214400123,
                    1735214400,
                ],
            ),
        ];

        for (input, [nanos, micros, millis, secs]) in cases {
            let expectations = [
                (
                    TimestampType::Nanosecond(TimestampNanosecondType),
                    ScalarValue::TimestampNanosecond(Some(nanos), None),
                ),
                (
                    TimestampType::Microsecond(TimestampMicrosecondType),
                    ScalarValue::TimestampMicrosecond(Some(micros), None),
                ),
                (
                    TimestampType::Millisecond(TimestampMillisecondType),
                    ScalarValue::TimestampMillisecond(Some(millis), None),
                ),
                (
                    TimestampType::Second(TimestampSecondType),
                    ScalarValue::TimestampSecond(Some(secs), None),
                ),
            ];

            for (ts_type, expected) in expectations {
                let result = covert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
                assert_eq!(result, expected);
            }
        }
    }
}