servers/mysql/
helper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_sql::convert::sql_value_to_value;
21use common_time::{Date, Timestamp};
22use datafusion_common::tree_node::{Transformed, TreeNode};
23use datafusion_expr::LogicalPlan;
24use datatypes::prelude::ConcreteDataType;
25use datatypes::types::TimestampType;
26use datatypes::value::{self, Value};
27use itertools::Itertools;
28use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
29use snafu::ResultExt;
30use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
31use sql::statements::statement::Statement;
32
33use crate::error::{self, DataFusionSnafu, Result};
34
/// Builds the positional placeholder `$<i>` (e.g. `$1`) that DataFusion understands.
pub fn format_placeholder(i: usize) -> String {
    let mut placeholder = String::from("$");
    placeholder.push_str(&i.to_string());
    placeholder
}
39
40/// Replace all the "?" placeholder into "$i" in SQL,
41/// returns the new SQL and the last placeholder index.
42pub fn replace_placeholders(query: &str) -> (String, usize) {
43    let query_parts = query.split('?').collect::<Vec<_>>();
44    let parts_len = query_parts.len();
45    let mut index = 0;
46    let query = query_parts
47        .into_iter()
48        .enumerate()
49        .map(|(i, part)| {
50            if i == parts_len - 1 {
51                return part.to_string();
52            }
53
54            index += 1;
55            format!("{part}{}", format_placeholder(index))
56        })
57        .join("");
58
59    (query, index + 1)
60}
61
62/// Transform all the "?" placeholder into "$i".
63/// Only works for Insert,Query and Delete statements.
64pub fn transform_placeholders(stmt: Statement) -> Statement {
65    match stmt {
66        Statement::Query(mut query) => {
67            visit_placeholders(&mut query.inner);
68            Statement::Query(query)
69        }
70        Statement::Insert(mut insert) => {
71            visit_placeholders(&mut insert.inner);
72            Statement::Insert(insert)
73        }
74        Statement::Delete(mut delete) => {
75            visit_placeholders(&mut delete.inner);
76            Statement::Delete(delete)
77        }
78        stmt => stmt,
79    }
80}
81
82/// Give placeholder that cast to certain type `data_type` the same data type as is cast to
83///
84/// because it seems datafusion will not give data type to placeholder if it need to be cast to certain type, still unknown if this is a feature or a bug. And if a placeholder expr have no data type, datafusion will fail to extract it using `LogicalPlan::get_parameter_types`
85pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
86    let give_placeholder_types = |mut e: datafusion_expr::Expr| {
87        if let datafusion_expr::Expr::Cast(cast) = &mut e {
88            if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
89                if ph.data_type.is_none() {
90                    ph.data_type = Some(cast.data_type.clone());
91                    common_telemetry::debug!(
92                        "give placeholder type {:?} to {:?}",
93                        cast.data_type,
94                        ph
95                    );
96                    Ok(Transformed::yes(e))
97                } else {
98                    Ok(Transformed::no(e))
99                }
100            } else {
101                Ok(Transformed::no(e))
102            }
103        } else {
104            Ok(Transformed::no(e))
105        }
106    };
107    let give_placeholder_types_recursively =
108        |e: datafusion_expr::Expr| e.transform(give_placeholder_types);
109    *plan = std::mem::take(plan)
110        .transform(|p| p.map_expressions(give_placeholder_types_recursively))
111        .context(DataFusionSnafu)?
112        .data;
113    Ok(())
114}
115
116fn visit_placeholders<V>(v: &mut V)
117where
118    V: VisitMut,
119{
120    let mut index = 1;
121    let _ = visit_expressions_mut(v, |expr| {
122        if let Expr::Value(ValueWithSpan {
123            value: ValueExpr::Placeholder(s),
124            ..
125        }) = expr
126        {
127            *s = format_placeholder(index);
128            index += 1;
129        }
130        ControlFlow::<()>::Continue(())
131    });
132}
133
134/// Convert [`ParamValue`] into [`Value`] according to param type.
135/// It will try it's best to do type conversions if possible
136pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
137    match param.value.into_inner() {
138        ValueInner::Int(i) => match t {
139            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
140            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
141            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
142            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
143            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
144            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
145            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
146            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
147            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
148            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
149            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
150            ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
151                .try_to_scalar_value(t)
152                .context(error::ConvertScalarValueSnafu),
153
154            _ => error::PreparedStmtTypeMismatchSnafu {
155                expected: t,
156                actual: param.coltype,
157            }
158            .fail(),
159        },
160        ValueInner::UInt(u) => match t {
161            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
162            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
163            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
164            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
165            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
166            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
167            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
168            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
169            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
170            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
171            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
172            ConcreteDataType::Timestamp(ts_type) => {
173                Value::Timestamp(ts_type.create_timestamp(u as i64))
174                    .try_to_scalar_value(t)
175                    .context(error::ConvertScalarValueSnafu)
176            }
177
178            _ => error::PreparedStmtTypeMismatchSnafu {
179                expected: t,
180                actual: param.coltype,
181            }
182            .fail(),
183        },
184        ValueInner::Double(f) => match t {
185            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
186            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
187            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
188            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
189            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
190            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
191            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
192            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
193            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
194            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
195
196            _ => error::PreparedStmtTypeMismatchSnafu {
197                expected: t,
198                actual: param.coltype,
199            }
200            .fail(),
201        },
202        ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
203        ValueInner::Bytes(b) => match t {
204            ConcreteDataType::String(t) => {
205                let s = String::from_utf8_lossy(b).to_string();
206                if t.is_large() {
207                    Ok(ScalarValue::LargeUtf8(Some(s)))
208                } else {
209                    Ok(ScalarValue::Utf8(Some(s)))
210                }
211            }
212            ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
213            ConcreteDataType::Timestamp(ts_type) => convert_bytes_to_timestamp(b, ts_type),
214            ConcreteDataType::Date(_) => convert_bytes_to_date(b),
215            _ => error::PreparedStmtTypeMismatchSnafu {
216                expected: t,
217                actual: param.coltype,
218            }
219            .fail(),
220        },
221        ValueInner::Date(_) => {
222            let date: common_time::Date = NaiveDate::from(param.value).into();
223            Ok(ScalarValue::Date32(Some(date.val())))
224        }
225        ValueInner::Datetime(_) => {
226            let timestamp_millis = to_naive_datetime(param.value)
227                .map_err(|e| {
228                    error::MysqlValueConversionSnafu {
229                        err_msg: e.to_string(),
230                    }
231                    .build()
232                })?
233                .and_utc()
234                .timestamp_millis();
235
236            match t {
237                ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
238                    Some(timestamp_millis),
239                    None,
240                )),
241                _ => error::PreparedStmtTypeMismatchSnafu {
242                    expected: t,
243                    actual: param.coltype,
244                }
245                .fail(),
246            }
247        }
248        ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
249            Duration::from(param.value).as_millis() as i64,
250        ))),
251    }
252}
253
254/// Convert an MySQL expression to a scalar value.
255/// It automatically handles the conversion of strings to numeric values.
256pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
257    match param {
258        Expr::Value(v) => {
259            let v = sql_value_to_value("", t, &v.value, None, None, true);
260            match v {
261                Ok(v) => v
262                    .try_to_scalar_value(t)
263                    .context(error::ConvertScalarValueSnafu),
264                Err(e) => error::InvalidParameterSnafu {
265                    reason: e.to_string(),
266                }
267                .fail(),
268            }
269        }
270        Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
271            let v = sql_value_to_value("", t, &v.value, None, Some(*op), true);
272            match v {
273                Ok(v) => v
274                    .try_to_scalar_value(t)
275                    .context(error::ConvertScalarValueSnafu),
276                Err(e) => error::InvalidParameterSnafu {
277                    reason: e.to_string(),
278                }
279                .fail(),
280            }
281        }
282        _ => error::InvalidParameterSnafu {
283            reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
284        }
285        .fail(),
286    }
287}
288
289fn convert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
290    let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
291        .map_err(|e| {
292            error::MysqlValueConversionSnafu {
293                err_msg: e.to_string(),
294            }
295            .build()
296        })?
297        .convert_to(ts_type.unit())
298        .ok_or_else(|| {
299            error::MysqlValueConversionSnafu {
300                err_msg: "Overflow when converting timestamp to target unit".to_string(),
301            }
302            .build()
303        })?;
304    match ts_type {
305        TimestampType::Nanosecond(_) => {
306            Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
307        }
308        TimestampType::Microsecond(_) => {
309            Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
310        }
311        TimestampType::Millisecond(_) => {
312            Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
313        }
314        TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
315    }
316}
317
318fn convert_bytes_to_date(bytes: &[u8]) -> Result<ScalarValue> {
319    let date = Date::from_str_utc(&String::from_utf8_lossy(bytes)).map_err(|e| {
320        error::MysqlValueConversionSnafu {
321            err_msg: e.to_string(),
322        }
323        .build()
324    })?;
325
326    Ok(ScalarValue::Date32(Some(date.val())))
327}
328
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    #[test]
    fn test_replace_placeholders() {
        // The returned index is the one the *next* placeholder would get, so a query
        // without any `?` yields 1 and a query with three `?` yields 4.
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        // Placeholders are numbered in textual order, including those inside subqueries.
        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!(
            "select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
            sql
        );
        assert_eq!(4, index);
    }

    /// Parses `sql` with the MySQL dialect and returns its first statement; panics if
    /// parsing fails.
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_transform_placeholders() {
        let insert = parse_sql("insert into demo values(?,?,?)");
        let Statement::Insert(insert) = transform_placeholders(insert) else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        let delete = parse_sql("delete from demo where host=? and idc=?");
        let Statement::Delete(delete) = transform_placeholders(delete) else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        let select = parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        );
        let Statement::Query(select) = transform_placeholders(select) else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false).into());
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false).into());
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        // String literals targeting temporal types must match Arrow's own string casts.
        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()).into());
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr =
            Expr::Value(ValueExpr::SingleQuotedString("2001-01-02 03:04:05".to_string()).into());
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()).into());
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        // NULL becomes a typed null of the requested type.
        let expr = Expr::Value(ValueExpr::Null.into());
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        // Each case: datetime string, target unit, expected scalar. Precision finer than the
        // target unit is dropped (see the cases targeting seconds/milliseconds).
        let test_cases = vec![
            // datetime string with second precision -> nanosecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            // datetime string with second precision -> microsecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            // datetime string with second precision -> millisecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            // datetime string with second precision -> second.
            (
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // datetime string with millisecond precision -> nanosecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            // datetime string with millisecond precision -> microsecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            // datetime string with millisecond precision -> millisecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // datetime string with millisecond precision -> second (fraction dropped).
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // datetime string with microsecond precision -> nanosecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            // datetime string with microsecond precision -> microsecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            // datetime string with microsecond precision -> millisecond (sub-ms dropped).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // datetime string with microsecond precision -> second (fraction dropped).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = convert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_convert_bytes_to_date() {
        // Expected values are day counts relative to 1970-01-01 (which maps to 0).
        let test_cases = vec![
            // Standard date format: YYYY-MM-DD
            ("1970-01-01", ScalarValue::Date32(Some(0))),
            ("1969-12-31", ScalarValue::Date32(Some(-1))),
            ("2024-02-29", ScalarValue::Date32(Some(19782))),
            ("2024-01-01", ScalarValue::Date32(Some(19723))),
            ("2024-12-31", ScalarValue::Date32(Some(20088))),
            ("2001-01-02", ScalarValue::Date32(Some(11324))),
            ("2050-06-14", ScalarValue::Date32(Some(29384))),
            ("2020-03-15", ScalarValue::Date32(Some(18336))),
        ];

        for (input, expected) in test_cases {
            let result = convert_bytes_to_date(input.as_bytes()).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }
}