servers/mysql/
helper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_sql::convert::sql_value_to_value;
21use common_time::{Date, Timestamp};
22use datatypes::prelude::ConcreteDataType;
23use datatypes::schema::ColumnSchema;
24use datatypes::types::TimestampType;
25use datatypes::value::{self, Value};
26use itertools::Itertools;
27use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
28use snafu::ResultExt;
29use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
30use sql::statements::statement::Statement;
31
32use crate::error::{self, Result};
33
/// Builds the positional placeholder `$i` (e.g. `3` -> `"$3"`) that replaces a
/// MySQL-style `?` in prepared statements.
pub fn format_placeholder(i: usize) -> String {
    let mut placeholder = String::from("$");
    placeholder.push_str(&i.to_string());
    placeholder
}
38
39/// Replace all the "?" placeholder into "$i" in SQL,
40/// returns the new SQL and the last placeholder index.
41pub fn replace_placeholders(query: &str) -> (String, usize) {
42    let query_parts = query.split('?').collect::<Vec<_>>();
43    let parts_len = query_parts.len();
44    let mut index = 0;
45    let query = query_parts
46        .into_iter()
47        .enumerate()
48        .map(|(i, part)| {
49            if i == parts_len - 1 {
50                return part.to_string();
51            }
52
53            index += 1;
54            format!("{part}{}", format_placeholder(index))
55        })
56        .join("");
57
58    (query, index + 1)
59}
60
61/// Transform all the "?" placeholder into "$i".
62/// Only works for Insert,Query and Delete statements.
63pub fn transform_placeholders(stmt: Statement) -> Statement {
64    match stmt {
65        Statement::Query(mut query) => {
66            visit_placeholders(&mut query.inner);
67            Statement::Query(query)
68        }
69        Statement::Insert(mut insert) => {
70            visit_placeholders(&mut insert.inner);
71            Statement::Insert(insert)
72        }
73        Statement::Delete(mut delete) => {
74            visit_placeholders(&mut delete.inner);
75            Statement::Delete(delete)
76        }
77        stmt => stmt,
78    }
79}
80
81fn visit_placeholders<V>(v: &mut V)
82where
83    V: VisitMut,
84{
85    let mut index = 1;
86    let _ = visit_expressions_mut(v, |expr| {
87        if let Expr::Value(ValueWithSpan {
88            value: ValueExpr::Placeholder(s),
89            ..
90        }) = expr
91        {
92            *s = format_placeholder(index);
93            index += 1;
94        }
95        ControlFlow::<()>::Continue(())
96    });
97}
98
99/// Convert [`ParamValue`] into [`Value`] according to param type.
100/// It will try it's best to do type conversions if possible
101pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
102    match param.value.into_inner() {
103        ValueInner::Int(i) => match t {
104            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
105            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
106            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
107            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
108            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
109            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
110            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
111            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
112            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
113            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
114            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
115            ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
116                .try_to_scalar_value(t)
117                .context(error::ConvertScalarValueSnafu),
118
119            _ => error::PreparedStmtTypeMismatchSnafu {
120                expected: t,
121                actual: param.coltype,
122            }
123            .fail(),
124        },
125        ValueInner::UInt(u) => match t {
126            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
127            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
128            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
129            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
130            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
131            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
132            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
133            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
134            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
135            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
136            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
137            ConcreteDataType::Timestamp(ts_type) => {
138                Value::Timestamp(ts_type.create_timestamp(u as i64))
139                    .try_to_scalar_value(t)
140                    .context(error::ConvertScalarValueSnafu)
141            }
142
143            _ => error::PreparedStmtTypeMismatchSnafu {
144                expected: t,
145                actual: param.coltype,
146            }
147            .fail(),
148        },
149        ValueInner::Double(f) => match t {
150            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
151            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
152            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
153            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
154            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
155            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
156            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
157            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
158            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
159            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
160
161            _ => error::PreparedStmtTypeMismatchSnafu {
162                expected: t,
163                actual: param.coltype,
164            }
165            .fail(),
166        },
167        ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
168        ValueInner::Bytes(b) => match t {
169            ConcreteDataType::String(t) => {
170                let s = String::from_utf8_lossy(b).to_string();
171                if t.is_large() {
172                    Ok(ScalarValue::LargeUtf8(Some(s)))
173                } else {
174                    Ok(ScalarValue::Utf8(Some(s)))
175                }
176            }
177            ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
178            ConcreteDataType::Timestamp(ts_type) => convert_bytes_to_timestamp(b, ts_type),
179            ConcreteDataType::Date(_) => convert_bytes_to_date(b),
180            _ => error::PreparedStmtTypeMismatchSnafu {
181                expected: t,
182                actual: param.coltype,
183            }
184            .fail(),
185        },
186        ValueInner::Date(_) => {
187            let date: common_time::Date = NaiveDate::from(param.value).into();
188            Ok(ScalarValue::Date32(Some(date.val())))
189        }
190        ValueInner::Datetime(_) => {
191            let timestamp_millis = to_naive_datetime(param.value)
192                .map_err(|e| {
193                    error::MysqlValueConversionSnafu {
194                        err_msg: e.to_string(),
195                    }
196                    .build()
197                })?
198                .and_utc()
199                .timestamp_millis();
200
201            match t {
202                ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
203                    Some(timestamp_millis),
204                    None,
205                )),
206                _ => error::PreparedStmtTypeMismatchSnafu {
207                    expected: t,
208                    actual: param.coltype,
209                }
210                .fail(),
211            }
212        }
213        ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
214            Duration::from(param.value).as_millis() as i64,
215        ))),
216    }
217}
218
219/// Convert an MySQL expression to a scalar value.
220/// It automatically handles the conversion of strings to numeric values.
221pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
222    let column_schema = ColumnSchema::new("", t.clone(), true);
223    match param {
224        Expr::Value(v) => {
225            let v = sql_value_to_value(&column_schema, &v.value, None, None, true);
226            match v {
227                Ok(v) => v
228                    .try_to_scalar_value(t)
229                    .context(error::ConvertScalarValueSnafu),
230                Err(e) => error::InvalidParameterSnafu {
231                    reason: e.to_string(),
232                }
233                .fail(),
234            }
235        }
236        Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
237            let v = sql_value_to_value(&column_schema, &v.value, None, Some(*op), true);
238            match v {
239                Ok(v) => v
240                    .try_to_scalar_value(t)
241                    .context(error::ConvertScalarValueSnafu),
242                Err(e) => error::InvalidParameterSnafu {
243                    reason: e.to_string(),
244                }
245                .fail(),
246            }
247        }
248        _ => error::InvalidParameterSnafu {
249            reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
250        }
251        .fail(),
252    }
253}
254
255fn convert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
256    let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
257        .map_err(|e| {
258            error::MysqlValueConversionSnafu {
259                err_msg: e.to_string(),
260            }
261            .build()
262        })?
263        .convert_to(ts_type.unit())
264        .ok_or_else(|| {
265            error::MysqlValueConversionSnafu {
266                err_msg: "Overflow when converting timestamp to target unit".to_string(),
267            }
268            .build()
269        })?;
270    match ts_type {
271        TimestampType::Nanosecond(_) => {
272            Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
273        }
274        TimestampType::Microsecond(_) => {
275            Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
276        }
277        TimestampType::Millisecond(_) => {
278            Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
279        }
280        TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
281    }
282}
283
284fn convert_bytes_to_date(bytes: &[u8]) -> Result<ScalarValue> {
285    let date = Date::from_str_utc(&String::from_utf8_lossy(bytes)).map_err(|e| {
286        error::MysqlValueConversionSnafu {
287            err_msg: e.to_string(),
288        }
289        .build()
290    })?;
291
292    Ok(ScalarValue::Date32(Some(date.val())))
293}
294
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    #[test]
    fn test_replace_placeholders() {
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        let query = "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!(
            "select * from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
            sql
        );
        assert_eq!(4, index);

        // Edge cases: empty query, a lone placeholder and adjacent placeholders.
        assert_eq!((String::new(), 1), replace_placeholders(""));
        assert_eq!(("$1".to_string(), 2), replace_placeholders("?"));
        assert_eq!(("$1$2".to_string(), 3), replace_placeholders("??"));
    }

    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_transform_placeholders() {
        let insert = parse_sql("insert into demo values(?,?,?)");
        let Statement::Insert(insert) = transform_placeholders(insert) else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        let delete = parse_sql("delete from demo where host=? and idc=?");
        let Statement::Delete(delete) = transform_placeholders(delete) else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        let select = parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        );
        let Statement::Query(select) = transform_placeholders(select) else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false).into());
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false).into());
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()).into());
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr =
            Expr::Value(ValueExpr::SingleQuotedString("2001-01-02 03:04:05".to_string()).into());
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()).into());
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        let expr = Expr::Value(ValueExpr::Null.into());
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);

        // Only plain literals (optionally under a unary operator) are accepted.
        let expr = Expr::Nested(Box::new(Expr::Value(ValueExpr::Null.into())));
        let t = ConcreteDataType::int32_datatype();
        assert!(convert_expr_to_scalar_value(&expr, &t).is_err());
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        let test_cases = vec![
            // Second-precision datetime string -> nanosecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            // Second-precision datetime string -> microsecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            // Second-precision datetime string -> millisecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            // Second-precision datetime string -> second.
            (
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // Millisecond-precision datetime string -> nanosecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            // Millisecond-precision datetime string -> microsecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            // Millisecond-precision datetime string -> millisecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // Millisecond-precision datetime string -> second (fraction truncated).
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // Microsecond-precision datetime string -> nanosecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            // Microsecond-precision datetime string -> microsecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            // Microsecond-precision datetime string -> millisecond (fraction truncated).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // Microsecond-precision datetime string -> second (fraction truncated).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = convert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }

    #[test]
    fn test_convert_bytes_to_date() {
        let test_cases = vec![
            // Standard date format: YYYY-MM-DD
            ("1970-01-01", ScalarValue::Date32(Some(0))),
            ("1969-12-31", ScalarValue::Date32(Some(-1))),
            ("2024-02-29", ScalarValue::Date32(Some(19782))),
            ("2024-01-01", ScalarValue::Date32(Some(19723))),
            ("2024-12-31", ScalarValue::Date32(Some(20088))),
            ("2001-01-02", ScalarValue::Date32(Some(11324))),
            ("2050-06-14", ScalarValue::Date32(Some(29384))),
            ("2020-03-15", ScalarValue::Date32(Some(18336))),
        ];

        for (input, expected) in test_cases {
            let result = convert_bytes_to_date(input.as_bytes()).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }
}