servers/mysql/
helper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::ControlFlow;
16use std::sync::Arc;
17use std::time::Duration;
18
19use arrow_schema::Field;
20use chrono::NaiveDate;
21use common_query::prelude::ScalarValue;
22use common_sql::convert::sql_value_to_value;
23use common_time::{Date, Timestamp};
24use datafusion_common::tree_node::{Transformed, TreeNode};
25use datafusion_expr::LogicalPlan;
26use datatypes::prelude::ConcreteDataType;
27use datatypes::schema::ColumnSchema;
28use datatypes::types::TimestampType;
29use datatypes::value::{self, Value};
30use itertools::Itertools;
31use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
32use snafu::ResultExt;
33use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
34use sql::statements::statement::Statement;
35
36use crate::error::{self, DataFusionSnafu, Result};
37
/// Builds the positional placeholder `$i` understood by DataFusion, e.g. `$3` for `i == 3`.
pub fn format_placeholder(i: usize) -> String {
    let mut placeholder = String::from("$");
    placeholder.push_str(&i.to_string());
    placeholder
}
42
43/// Replace all the "?" placeholder into "$i" in SQL,
44/// returns the new SQL and the last placeholder index.
45pub fn replace_placeholders(query: &str) -> (String, usize) {
46    let query_parts = query.split('?').collect::<Vec<_>>();
47    let parts_len = query_parts.len();
48    let mut index = 0;
49    let query = query_parts
50        .into_iter()
51        .enumerate()
52        .map(|(i, part)| {
53            if i == parts_len - 1 {
54                return part.to_string();
55            }
56
57            index += 1;
58            format!("{part}{}", format_placeholder(index))
59        })
60        .join("");
61
62    (query, index + 1)
63}
64
65/// Transform all the "?" placeholder into "$i".
66/// Only works for Insert,Query and Delete statements.
67pub fn transform_placeholders(stmt: Statement) -> Statement {
68    match stmt {
69        Statement::Query(mut query) => {
70            visit_placeholders(&mut query.inner);
71            Statement::Query(query)
72        }
73        Statement::Insert(mut insert) => {
74            visit_placeholders(&mut insert.inner);
75            Statement::Insert(insert)
76        }
77        Statement::Delete(mut delete) => {
78            visit_placeholders(&mut delete.inner);
79            Statement::Delete(delete)
80        }
81        stmt => stmt,
82    }
83}
84
85/// Give placeholder that cast to certain type `data_type` the same data type as is cast to
86///
87/// because it seems datafusion will not give data type to placeholder if it need to be cast to certain type, still unknown if this is a feature or a bug. And if a placeholder expr have no data type, datafusion will fail to extract it using `LogicalPlan::get_parameter_types`
88pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
89    let give_placeholder_types = |mut e: datafusion_expr::Expr| {
90        if let datafusion_expr::Expr::Cast(cast) = &mut e {
91            if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
92                if ph.field.is_none() {
93                    ph.field = Some(Arc::new(Field::new("", cast.data_type.clone(), true)));
94                    common_telemetry::debug!(
95                        "give placeholder type {:?} to {:?}",
96                        cast.data_type,
97                        ph
98                    );
99                    Ok(Transformed::yes(e))
100                } else {
101                    Ok(Transformed::no(e))
102                }
103            } else {
104                Ok(Transformed::no(e))
105            }
106        } else {
107            Ok(Transformed::no(e))
108        }
109    };
110    let give_placeholder_types_recursively =
111        |e: datafusion_expr::Expr| e.transform(give_placeholder_types);
112    *plan = std::mem::take(plan)
113        .transform(|p| p.map_expressions(give_placeholder_types_recursively))
114        .context(DataFusionSnafu)?
115        .data;
116    Ok(())
117}
118
119fn visit_placeholders<V>(v: &mut V)
120where
121    V: VisitMut,
122{
123    let mut index = 1;
124    let _ = visit_expressions_mut(v, |expr| {
125        if let Expr::Value(ValueWithSpan {
126            value: ValueExpr::Placeholder(s),
127            ..
128        }) = expr
129        {
130            *s = format_placeholder(index);
131            index += 1;
132        }
133        ControlFlow::<()>::Continue(())
134    });
135}
136
137/// Convert [`ParamValue`] into [`Value`] according to param type.
138/// It will try it's best to do type conversions if possible
139pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
140    match param.value.into_inner() {
141        ValueInner::Int(i) => match t {
142            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
143            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
144            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
145            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
146            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
147            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
148            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
149            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
150            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
151            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
152            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
153            ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
154                .try_to_scalar_value(t)
155                .context(error::ConvertScalarValueSnafu),
156
157            _ => error::PreparedStmtTypeMismatchSnafu {
158                expected: t,
159                actual: param.coltype,
160            }
161            .fail(),
162        },
163        ValueInner::UInt(u) => match t {
164            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
165            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
166            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
167            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
168            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
169            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
170            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
171            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
172            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
173            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
174            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
175            ConcreteDataType::Timestamp(ts_type) => {
176                Value::Timestamp(ts_type.create_timestamp(u as i64))
177                    .try_to_scalar_value(t)
178                    .context(error::ConvertScalarValueSnafu)
179            }
180
181            _ => error::PreparedStmtTypeMismatchSnafu {
182                expected: t,
183                actual: param.coltype,
184            }
185            .fail(),
186        },
187        ValueInner::Double(f) => match t {
188            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
189            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
190            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
191            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
192            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
193            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
194            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
195            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
196            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
197            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
198
199            _ => error::PreparedStmtTypeMismatchSnafu {
200                expected: t,
201                actual: param.coltype,
202            }
203            .fail(),
204        },
205        ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
206        ValueInner::Bytes(b) => match t {
207            ConcreteDataType::String(t) => {
208                let s = String::from_utf8_lossy(b).to_string();
209                if t.is_large() {
210                    Ok(ScalarValue::LargeUtf8(Some(s)))
211                } else {
212                    Ok(ScalarValue::Utf8(Some(s)))
213                }
214            }
215            ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
216            ConcreteDataType::Timestamp(ts_type) => convert_bytes_to_timestamp(b, ts_type),
217            ConcreteDataType::Date(_) => convert_bytes_to_date(b),
218            _ => error::PreparedStmtTypeMismatchSnafu {
219                expected: t,
220                actual: param.coltype,
221            }
222            .fail(),
223        },
224        ValueInner::Date(_) => {
225            let date: common_time::Date = NaiveDate::from(param.value).into();
226            Ok(ScalarValue::Date32(Some(date.val())))
227        }
228        ValueInner::Datetime(_) => {
229            let timestamp_millis = to_naive_datetime(param.value)
230                .map_err(|e| {
231                    error::MysqlValueConversionSnafu {
232                        err_msg: e.to_string(),
233                    }
234                    .build()
235                })?
236                .and_utc()
237                .timestamp_millis();
238
239            match t {
240                ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
241                    Some(timestamp_millis),
242                    None,
243                )),
244                _ => error::PreparedStmtTypeMismatchSnafu {
245                    expected: t,
246                    actual: param.coltype,
247                }
248                .fail(),
249            }
250        }
251        ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
252            Duration::from(param.value).as_millis() as i64,
253        ))),
254    }
255}
256
257/// Convert an MySQL expression to a scalar value.
258/// It automatically handles the conversion of strings to numeric values.
259pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
260    let column_schema = ColumnSchema::new("", t.clone(), true);
261    match param {
262        Expr::Value(v) => {
263            let v = sql_value_to_value(&column_schema, &v.value, None, None, true);
264            match v {
265                Ok(v) => v
266                    .try_to_scalar_value(t)
267                    .context(error::ConvertScalarValueSnafu),
268                Err(e) => error::InvalidParameterSnafu {
269                    reason: e.to_string(),
270                }
271                .fail(),
272            }
273        }
274        Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
275            let v = sql_value_to_value(&column_schema, &v.value, None, Some(*op), true);
276            match v {
277                Ok(v) => v
278                    .try_to_scalar_value(t)
279                    .context(error::ConvertScalarValueSnafu),
280                Err(e) => error::InvalidParameterSnafu {
281                    reason: e.to_string(),
282                }
283                .fail(),
284            }
285        }
286        _ => error::InvalidParameterSnafu {
287            reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
288        }
289        .fail(),
290    }
291}
292
293fn convert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
294    let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
295        .map_err(|e| {
296            error::MysqlValueConversionSnafu {
297                err_msg: e.to_string(),
298            }
299            .build()
300        })?
301        .convert_to(ts_type.unit())
302        .ok_or_else(|| {
303            error::MysqlValueConversionSnafu {
304                err_msg: "Overflow when converting timestamp to target unit".to_string(),
305            }
306            .build()
307        })?;
308    match ts_type {
309        TimestampType::Nanosecond(_) => {
310            Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
311        }
312        TimestampType::Microsecond(_) => {
313            Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
314        }
315        TimestampType::Millisecond(_) => {
316            Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
317        }
318        TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
319    }
320}
321
322fn convert_bytes_to_date(bytes: &[u8]) -> Result<ScalarValue> {
323    let date = Date::from_str_utc(&String::from_utf8_lossy(bytes)).map_err(|e| {
324        error::MysqlValueConversionSnafu {
325            err_msg: e.to_string(),
326        }
327        .build()
328    })?;
329
330    Ok(ScalarValue::Date32(Some(date.val())))
331}
332
// Unit tests for the placeholder rewriting and value conversion helpers above.
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    // `replace_placeholders` returns the rewritten SQL and the number of placeholders + 1.
    #[test]
    fn test_replace_placeholders() {
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!(
            "select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
            sql
        );
        assert_eq!(4, index);
    }

    // Parses `sql` with the MySQL dialect and returns the first statement (panics on error).
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    // Placeholders are renumbered per statement for INSERT, DELETE and SELECT (incl. subqueries).
    #[test]
    fn test_transform_placeholders() {
        let insert = parse_sql("insert into demo values(?,?,?)");
        let Statement::Insert(insert) = transform_placeholders(insert) else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        let delete = parse_sql("delete from demo where host=? and idc=?");
        let Statement::Delete(delete) = transform_placeholders(delete) else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        let select = parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        );
        let Statement::Query(select) = transform_placeholders(select) else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
    }

    // Each case is checked against the equivalent DataFusion/Arrow value where a cast is handy.
    #[test]
    fn test_convert_expr_to_scalar_value() {
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false).into());
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false).into());
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()).into());
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr =
            Expr::Value(ValueExpr::SingleQuotedString("2001-01-02 03:04:05".to_string()).into());
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()).into());
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        let expr = Expr::Value(ValueExpr::Null.into());
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    // Datetime strings of varying fractional precision converted to every timestamp unit.
    #[test]
    fn test_convert_bytes_to_timestamp() {
        let test_cases = vec![
            // input datetime with second precision -> nanosecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            // input datetime with second precision -> microsecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            // input datetime with second precision -> millisecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            // input datetime with second precision -> second.
            (
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // input datetime with millisecond precision -> nanosecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            // input datetime with millisecond precision -> microsecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            // input datetime with millisecond precision -> millisecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // input datetime with millisecond precision -> second (fraction dropped).
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // input datetime with microsecond precision -> nanosecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            // input datetime with microsecond precision -> microsecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            // input datetime with microsecond precision -> millisecond (fraction truncated).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // input datetime with microsecond precision -> second (fraction dropped).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = convert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected);
        }
    }

    // Date32 values are days since the Unix epoch (1970-01-01 == 0).
    #[test]
    fn test_convert_bytes_to_date() {
        let test_cases = vec![
            // Standard date format: YYYY-MM-DD
            ("1970-01-01", ScalarValue::Date32(Some(0))),
            ("1969-12-31", ScalarValue::Date32(Some(-1))),
            ("2024-02-29", ScalarValue::Date32(Some(19782))),
            ("2024-01-01", ScalarValue::Date32(Some(19723))),
            ("2024-12-31", ScalarValue::Date32(Some(20088))),
            ("2001-01-02", ScalarValue::Date32(Some(11324))),
            ("2050-06-14", ScalarValue::Date32(Some(29384))),
            ("2020-03-15", ScalarValue::Date32(Some(18336))),
        ];

        for (input, expected) in test_cases {
            let result = convert_bytes_to_date(input.as_bytes()).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }
}