1use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_time::Timestamp;
21use datafusion_common::tree_node::{Transformed, TreeNode};
22use datafusion_expr::LogicalPlan;
23use datatypes::prelude::ConcreteDataType;
24use datatypes::types::TimestampType;
25use datatypes::value::{self, Value};
26use itertools::Itertools;
27use opensrv_mysql::{to_naive_datetime, ParamValue, ValueInner};
28use snafu::ResultExt;
29use sql::ast::{visit_expressions_mut, Expr, Value as ValueExpr, VisitMut};
30use sql::statements::sql_value_to_value;
31use sql::statements::statement::Statement;
32
33use crate::error::{self, DataFusionSnafu, Result};
34
/// Returns the positional placeholder name for the `i`-th parameter, e.g. `$1`.
///
/// This is the placeholder syntax that statements are rewritten to before planning.
pub fn format_placeholder(i: usize) -> String {
    format!("${i}")
}
39
40pub fn replace_placeholders(query: &str) -> (String, usize) {
43 let query_parts = query.split('?').collect::<Vec<_>>();
44 let parts_len = query_parts.len();
45 let mut index = 0;
46 let query = query_parts
47 .into_iter()
48 .enumerate()
49 .map(|(i, part)| {
50 if i == parts_len - 1 {
51 return part.to_string();
52 }
53
54 index += 1;
55 format!("{part}{}", format_placeholder(index))
56 })
57 .join("");
58
59 (query, index + 1)
60}
61
62pub fn transform_placeholders(stmt: Statement) -> Statement {
65 match stmt {
66 Statement::Query(mut query) => {
67 visit_placeholders(&mut query.inner);
68 Statement::Query(query)
69 }
70 Statement::Insert(mut insert) => {
71 visit_placeholders(&mut insert.inner);
72 Statement::Insert(insert)
73 }
74 Statement::Delete(mut delete) => {
75 visit_placeholders(&mut delete.inner);
76 Statement::Delete(delete)
77 }
78 stmt => stmt,
79 }
80}
81
82pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
86 let give_placeholder_types = |mut e: datafusion_expr::Expr| {
87 if let datafusion_expr::Expr::Cast(cast) = &mut e {
88 if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
89 if ph.data_type.is_none() {
90 ph.data_type = Some(cast.data_type.clone());
91 common_telemetry::debug!(
92 "give placeholder type {:?} to {:?}",
93 cast.data_type,
94 ph
95 );
96 Ok(Transformed::yes(e))
97 } else {
98 Ok(Transformed::no(e))
99 }
100 } else {
101 Ok(Transformed::no(e))
102 }
103 } else {
104 Ok(Transformed::no(e))
105 }
106 };
107 let give_placeholder_types_recursively =
108 |e: datafusion_expr::Expr| e.transform(give_placeholder_types);
109 *plan = std::mem::take(plan)
110 .transform(|p| p.map_expressions(give_placeholder_types_recursively))
111 .context(DataFusionSnafu)?
112 .data;
113 Ok(())
114}
115
116fn visit_placeholders<V>(v: &mut V)
117where
118 V: VisitMut,
119{
120 let mut index = 1;
121 let _ = visit_expressions_mut(v, |expr| {
122 if let Expr::Value(ValueExpr::Placeholder(s)) = expr {
123 *s = format_placeholder(index);
124 index += 1;
125 }
126 ControlFlow::<()>::Continue(())
127 });
128}
129
130pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
133 match param.value.into_inner() {
134 ValueInner::Int(i) => match t {
135 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
136 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
137 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
138 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
139 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
140 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
141 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
142 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
143 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
144 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
145 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
146 ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
147 .try_to_scalar_value(t)
148 .context(error::ConvertScalarValueSnafu),
149
150 _ => error::PreparedStmtTypeMismatchSnafu {
151 expected: t,
152 actual: param.coltype,
153 }
154 .fail(),
155 },
156 ValueInner::UInt(u) => match t {
157 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
158 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
159 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
160 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
161 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
162 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
163 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
164 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
165 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
166 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
167 ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
168 ConcreteDataType::Timestamp(ts_type) => {
169 Value::Timestamp(ts_type.create_timestamp(u as i64))
170 .try_to_scalar_value(t)
171 .context(error::ConvertScalarValueSnafu)
172 }
173
174 _ => error::PreparedStmtTypeMismatchSnafu {
175 expected: t,
176 actual: param.coltype,
177 }
178 .fail(),
179 },
180 ValueInner::Double(f) => match t {
181 ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
182 ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
183 ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
184 ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
185 ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
186 ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
187 ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
188 ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
189 ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
190 ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
191
192 _ => error::PreparedStmtTypeMismatchSnafu {
193 expected: t,
194 actual: param.coltype,
195 }
196 .fail(),
197 },
198 ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
199 ValueInner::Bytes(b) => match t {
200 ConcreteDataType::String(_) => Ok(ScalarValue::Utf8(Some(
201 String::from_utf8_lossy(b).to_string(),
202 ))),
203 ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
204 ConcreteDataType::Timestamp(ts_type) => covert_bytes_to_timestamp(b, ts_type),
205 _ => error::PreparedStmtTypeMismatchSnafu {
206 expected: t,
207 actual: param.coltype,
208 }
209 .fail(),
210 },
211 ValueInner::Date(_) => {
212 let date: common_time::Date = NaiveDate::from(param.value).into();
213 Ok(ScalarValue::Date32(Some(date.val())))
214 }
215 ValueInner::Datetime(_) => {
216 let timestamp_millis = to_naive_datetime(param.value)
217 .map_err(|e| {
218 error::MysqlValueConversionSnafu {
219 err_msg: e.to_string(),
220 }
221 .build()
222 })?
223 .and_utc()
224 .timestamp_millis();
225
226 match t {
227 ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
228 Some(timestamp_millis),
229 None,
230 )),
231 _ => error::PreparedStmtTypeMismatchSnafu {
232 expected: t,
233 actual: param.coltype,
234 }
235 .fail(),
236 }
237 }
238 ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
239 Duration::from(param.value).as_millis() as i64,
240 ))),
241 }
242}
243
244pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
247 match param {
248 Expr::Value(v) => {
249 let v = sql_value_to_value("", t, v, None, None, true);
250 match v {
251 Ok(v) => v
252 .try_to_scalar_value(t)
253 .context(error::ConvertScalarValueSnafu),
254 Err(e) => error::InvalidParameterSnafu {
255 reason: e.to_string(),
256 }
257 .fail(),
258 }
259 }
260 Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
261 let v = sql_value_to_value("", t, v, None, Some(*op), true);
262 match v {
263 Ok(v) => v
264 .try_to_scalar_value(t)
265 .context(error::ConvertScalarValueSnafu),
266 Err(e) => error::InvalidParameterSnafu {
267 reason: e.to_string(),
268 }
269 .fail(),
270 }
271 }
272 _ => error::InvalidParameterSnafu {
273 reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
274 }
275 .fail(),
276 }
277}
278
279fn covert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
280 let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
281 .map_err(|e| {
282 error::MysqlValueConversionSnafu {
283 err_msg: e.to_string(),
284 }
285 .build()
286 })?
287 .convert_to(ts_type.unit())
288 .ok_or_else(|| {
289 error::MysqlValueConversionSnafu {
290 err_msg: "Overflow when converting timestamp to target unit".to_string(),
291 }
292 .build()
293 })?;
294 match ts_type {
295 TimestampType::Nanosecond(_) => {
296 Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
297 }
298 TimestampType::Microsecond(_) => {
299 Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
300 }
301 TimestampType::Millisecond(_) => {
302 Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
303 }
304 TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
305 }
306}
307
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    // Placeholders are rendered as `$<n>`.
    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    // `replace_placeholders` rewrites `?` to `$1`, `$2`, ... and returns the number of
    // placeholders plus one (1 when the query has none).
    #[test]
    fn test_replace_placeholders() {
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        // Placeholders inside a subquery are numbered in textual order.
        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!("select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3", sql);
        assert_eq!(4, index);
    }

    /// Parses `sql` with the MySQL dialect and returns the first parsed statement.
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    // Parsed INSERT, DELETE and SELECT statements get their `?` placeholders renamed to
    // sequential `$n` names, including inside subqueries.
    #[test]
    fn test_transform_placeholders() {
        let insert = parse_sql("insert into demo values(?,?,?)");
        let Statement::Insert(insert) = transform_placeholders(insert) else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        let delete = parse_sql("delete from demo where host=? and idc=?");
        let Statement::Delete(delete) = transform_placeholders(delete) else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        let select = parse_sql("select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?");
        let Statement::Query(select) = transform_placeholders(select) else {
            unreachable!()
        };
        assert_eq!("SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3", select.inner.to_string());
    }

    // Literal expressions are converted to scalars of the requested type; date and
    // timestamp results are compared against Arrow's own string casts.
    #[test]
    fn test_convert_expr_to_scalar_value() {
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false));
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false));
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()));
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString(
            "2001-01-02 03:04:05".to_string(),
        ));
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()));
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        // NULL becomes a typed null scalar of the requested type.
        let expr = Expr::Value(ValueExpr::Null);
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    // Each case is (input text, target timestamp type, expected scalar).
    // 1735214400 is 2024-12-26 12:00:00 expressed in seconds since the Unix epoch.
    // Fractions finer than the target unit are dropped.
    #[test]
    fn test_convert_bytes_to_timestamp() {
        let test_cases = vec![
            (
                // whole seconds -> nanoseconds
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            (
                // whole seconds -> microseconds
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            (
                // whole seconds -> milliseconds
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            (
                // whole seconds -> seconds
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            (
                // millisecond fraction -> nanoseconds
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            (
                // millisecond fraction -> microseconds
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            (
                // millisecond fraction -> milliseconds
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                // millisecond fraction -> seconds (fraction dropped)
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            (
                // microsecond fraction -> nanoseconds
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            (
                // microsecond fraction -> microseconds
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            (
                // microsecond fraction -> milliseconds (sub-millisecond part dropped)
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            (
                // microsecond fraction -> seconds (fraction dropped)
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = covert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected);
        }
    }
}