Skip to main content

servers/mysql/
helper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_sql::convert::sql_value_to_value;
21use common_time::{Date, Timestamp};
22use datatypes::prelude::ConcreteDataType;
23use datatypes::schema::ColumnSchema;
24use datatypes::types::TimestampType;
25use datatypes::value::{self, Value};
26#[cfg(test)]
27use itertools::Itertools;
28use opensrv_mysql::{ParamValue, ValueInner, to_naive_datetime};
29use snafu::ResultExt;
30use sql::ast::{Expr, Value as ValueExpr, ValueWithSpan, VisitMut, visit_expressions_mut};
31use sql::statements::statement::Statement;
32
33use crate::error::{self, Result};
34
/// Location of a prepared-statement placeholder in the original SQL text.
///
/// The four position fields are copied verbatim from the `start`/`end`
/// locations of the span attached to the placeholder expression by the SQL
/// parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct PlaceholderSpan {
    /// 1-based placeholder index.
    pub(crate) index: usize,
    /// Line on which the placeholder starts.
    /// NOTE(review): presumably 1-based, as reported by the parser — confirm.
    pub(crate) start_line: u64,
    /// Column at which the placeholder starts.
    pub(crate) start_column: u64,
    /// Line on which the placeholder ends.
    pub(crate) end_line: u64,
    /// Column at which the placeholder ends.
    /// NOTE(review): inclusive vs. exclusive is decided by the parser — confirm.
    pub(crate) end_column: u64,
}
45
/// Builds the numbered placeholder text for index `i`, i.e. `"$i"`
/// (for example `format_placeholder(3)` yields `"$3"`).
pub fn format_placeholder(i: usize) -> String {
    format!("${i}")
}
50
/// Rewrites every `?` in `query` as a numbered `$i` placeholder (1-based, in
/// order of appearance).
///
/// This is a purely textual replacement: a `?` inside a string literal or a
/// comment is rewritten as well.
///
/// Returns the rewritten SQL together with the index the *next* placeholder
/// would get, i.e. the number of replaced placeholders plus one (`1` when the
/// query contains no `?` at all).
#[cfg(test)]
pub fn replace_placeholders(query: &str) -> (String, usize) {
    // Splitting on '?' always yields at least one segment, and exactly one
    // more segment than there are '?' characters.
    let segments: Vec<&str> = query.split('?').collect();
    let placeholder_count = segments.len() - 1;

    let rewritten = segments
        .iter()
        .enumerate()
        .map(|(pos, segment)| {
            if pos < placeholder_count {
                // Every segment except the last was followed by a '?'.
                format!("{segment}{}", format_placeholder(pos + 1))
            } else {
                (*segment).to_string()
            }
        })
        .join("");

    (rewritten, placeholder_count + 1)
}
73
74/// Transform all the "?" placeholders into "$i" and return the number of
75/// transformed placeholders.
76pub fn transform_placeholders_with_count(mut stmt: Statement) -> (Statement, usize) {
77    let count = visit_placeholders(&mut stmt);
78    (stmt, count)
79}
80
81/// Collect spans of "$i" placeholders in a statement.
82pub(crate) fn placeholder_spans(mut stmt: Statement) -> Vec<PlaceholderSpan> {
83    let mut spans = Vec::new();
84    collect_placeholder_spans(&mut stmt, &mut spans);
85    spans
86}
87
88fn collect_placeholder_spans<V>(v: &mut V, spans: &mut Vec<PlaceholderSpan>)
89where
90    V: VisitMut,
91{
92    let _ = visit_expressions_mut(v, |expr| {
93        if let Expr::Value(ValueWithSpan {
94            value: ValueExpr::Placeholder(s),
95            span,
96        }) = expr
97            && let Some(index) = placeholder_index(s)
98        {
99            spans.push(PlaceholderSpan {
100                index,
101                start_line: span.start.line,
102                start_column: span.start.column,
103                end_line: span.end.line,
104                end_column: span.end.column,
105            });
106        }
107        ControlFlow::<()>::Continue(())
108    });
109}
110
/// Extracts the numeric index from a `$i` placeholder string.
///
/// Returns `None` when `s` does not start with `$`, when the remainder does
/// not parse as a `usize`, or when the parsed index is `0` (indices are
/// 1-based).
fn placeholder_index(s: &str) -> Option<usize> {
    let digits = s.strip_prefix('$')?;
    match digits.parse::<usize>() {
        Ok(index) if index > 0 => Some(index),
        _ => None,
    }
}
117
118fn visit_placeholders<V>(v: &mut V) -> usize
119where
120    V: VisitMut,
121{
122    let mut index = 1;
123    let _ = visit_expressions_mut(v, |expr| {
124        if let Expr::Value(ValueWithSpan {
125            value: ValueExpr::Placeholder(s),
126            ..
127        }) = expr
128            && s == "?"
129        {
130            *s = format_placeholder(index);
131            index += 1;
132        }
133        ControlFlow::<()>::Continue(())
134    });
135    index - 1
136}
137
138/// Convert [`ParamValue`] into [`Value`] according to param type.
139/// It will try it's best to do type conversions if possible
140pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
141    match param.value.into_inner() {
142        ValueInner::Int(i) => match t {
143            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
144            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
145            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
146            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
147            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
148            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
149            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
150            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
151            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
152            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
153            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
154            ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
155                .try_to_scalar_value(t)
156                .context(error::ConvertScalarValueSnafu),
157
158            _ => error::PreparedStmtTypeMismatchSnafu {
159                expected: t,
160                actual: param.coltype,
161            }
162            .fail(),
163        },
164        ValueInner::UInt(u) => match t {
165            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
166            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
167            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
168            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
169            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
170            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
171            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
172            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
173            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
174            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
175            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
176            ConcreteDataType::Timestamp(ts_type) => {
177                Value::Timestamp(ts_type.create_timestamp(u as i64))
178                    .try_to_scalar_value(t)
179                    .context(error::ConvertScalarValueSnafu)
180            }
181
182            _ => error::PreparedStmtTypeMismatchSnafu {
183                expected: t,
184                actual: param.coltype,
185            }
186            .fail(),
187        },
188        ValueInner::Double(f) => match t {
189            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
190            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
191            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
192            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
193            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
194            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
195            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
196            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
197            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
198            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
199
200            _ => error::PreparedStmtTypeMismatchSnafu {
201                expected: t,
202                actual: param.coltype,
203            }
204            .fail(),
205        },
206        ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
207        ValueInner::Bytes(b) => match t {
208            ConcreteDataType::String(t) => {
209                let s = String::from_utf8_lossy(b).to_string();
210                if t.is_large() {
211                    Ok(ScalarValue::LargeUtf8(Some(s)))
212                } else {
213                    Ok(ScalarValue::Utf8(Some(s)))
214                }
215            }
216            ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
217            ConcreteDataType::Timestamp(ts_type) => convert_bytes_to_timestamp(b, ts_type),
218            ConcreteDataType::Date(_) => convert_bytes_to_date(b),
219            _ => error::PreparedStmtTypeMismatchSnafu {
220                expected: t,
221                actual: param.coltype,
222            }
223            .fail(),
224        },
225        ValueInner::Date(_) => {
226            let date: common_time::Date = NaiveDate::from(param.value).into();
227            Ok(ScalarValue::Date32(Some(date.val())))
228        }
229        ValueInner::Datetime(_) => {
230            let timestamp_millis = to_naive_datetime(param.value)
231                .map_err(|e| {
232                    error::MysqlValueConversionSnafu {
233                        err_msg: e.to_string(),
234                    }
235                    .build()
236                })?
237                .and_utc()
238                .timestamp_millis();
239
240            match t {
241                ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
242                    Some(timestamp_millis),
243                    None,
244                )),
245                _ => error::PreparedStmtTypeMismatchSnafu {
246                    expected: t,
247                    actual: param.coltype,
248                }
249                .fail(),
250            }
251        }
252        ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
253            Duration::from(param.value).as_millis() as i64,
254        ))),
255    }
256}
257
258/// Convert an MySQL expression to a scalar value.
259/// It automatically handles the conversion of strings to numeric values.
260pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
261    let column_schema = ColumnSchema::new("", t.clone(), true);
262    match param {
263        Expr::Value(v) => {
264            let v = sql_value_to_value(&column_schema, &v.value, None, None, true);
265            match v {
266                Ok(v) => v
267                    .try_to_scalar_value(t)
268                    .context(error::ConvertScalarValueSnafu),
269                Err(e) => error::InvalidParameterSnafu {
270                    reason: e.to_string(),
271                }
272                .fail(),
273            }
274        }
275        Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
276            let v = sql_value_to_value(&column_schema, &v.value, None, Some(*op), true);
277            match v {
278                Ok(v) => v
279                    .try_to_scalar_value(t)
280                    .context(error::ConvertScalarValueSnafu),
281                Err(e) => error::InvalidParameterSnafu {
282                    reason: e.to_string(),
283                }
284                .fail(),
285            }
286        }
287        _ => error::InvalidParameterSnafu {
288            reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
289        }
290        .fail(),
291    }
292}
293
294fn convert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
295    let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
296        .map_err(|e| {
297            error::MysqlValueConversionSnafu {
298                err_msg: e.to_string(),
299            }
300            .build()
301        })?
302        .convert_to(ts_type.unit())
303        .ok_or_else(|| {
304            error::MysqlValueConversionSnafu {
305                err_msg: "Overflow when converting timestamp to target unit".to_string(),
306            }
307            .build()
308        })?;
309    match ts_type {
310        TimestampType::Nanosecond(_) => {
311            Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
312        }
313        TimestampType::Microsecond(_) => {
314            Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
315        }
316        TimestampType::Millisecond(_) => {
317            Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
318        }
319        TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
320    }
321}
322
323fn convert_bytes_to_date(bytes: &[u8]) -> Result<ScalarValue> {
324    let date = Date::from_str_utc(&String::from_utf8_lossy(bytes)).map_err(|e| {
325        error::MysqlValueConversionSnafu {
326            err_msg: e.to_string(),
327        }
328        .build()
329    })?;
330
331    Ok(ScalarValue::Date32(Some(date.val())))
332}
333
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    // The second tuple element is the index the next placeholder would get,
    // i.e. the number of replaced placeholders plus one.
    #[test]
    fn test_replace_placeholders() {
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!(
            "select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3",
            sql
        );
        assert_eq!(4, index);
    }

    /// Parses `sql` with the MySQL dialect and returns the first statement.
    /// Panics if parsing fails or yields no statement.
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_transform_placeholders() {
        let insert = parse_sql("insert into demo values(?,?,?)");
        let (stmt, count) = transform_placeholders_with_count(insert);
        let Statement::Insert(insert) = stmt else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );
        assert_eq!(3, count);

        let delete = parse_sql("delete from demo where host=? and idc=?");
        let (stmt, count) = transform_placeholders_with_count(delete);
        let Statement::Delete(delete) = stmt else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );
        assert_eq!(2, count);

        let select = parse_sql(
            "select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?",
        );
        let (stmt, count) = transform_placeholders_with_count(select);
        let Statement::Query(select) = stmt else {
            unreachable!()
        };
        assert_eq!(
            "SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3",
            select.inner.to_string()
        );
        assert_eq!(3, count);

        // A '?' inside a string literal is not a placeholder and must be kept.
        let select = parse_sql("select '?', ?");
        let (stmt, count) = transform_placeholders_with_count(select);
        let Statement::Query(select) = stmt else {
            unreachable!()
        };
        assert_eq!("SELECT '?', $1", select.inner.to_string());
        assert_eq!(1, count);

        let set = parse_sql("set time_zone = ?");
        let (stmt, count) = transform_placeholders_with_count(set);
        assert_eq!("SET time_zone = $1", stmt.to_string());
        assert_eq!(1, count);
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false).into());
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false).into());
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()).into());
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr =
            Expr::Value(ValueExpr::SingleQuotedString("2001-01-02 03:04:05".to_string()).into());
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()).into());
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        let expr = Expr::Value(ValueExpr::Null.into());
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        // Each case: (datetime string, target timestamp type, expected scalar).
        let test_cases = vec![
            // datetime string with second precision -> nanosecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            // datetime string with second precision -> microsecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            // datetime string with second precision -> millisecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            // datetime string with second precision -> second.
            (
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // datetime string with millisecond precision -> nanosecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            // datetime string with millisecond precision -> microsecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            // datetime string with millisecond precision -> millisecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // datetime string with millisecond precision -> second
            // (the fractional part is dropped).
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // datetime string with microsecond precision -> nanosecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            // datetime string with microsecond precision -> microsecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            // datetime string with microsecond precision -> millisecond
            // (sub-millisecond digits are dropped).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // datetime string with microsecond precision -> second
            // (the fractional part is dropped).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = convert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_convert_bytes_to_date() {
        // Expected values are day offsets from 1970-01-01.
        let test_cases = vec![
            // Standard date format: YYYY-MM-DD
            ("1970-01-01", ScalarValue::Date32(Some(0))),
            ("1969-12-31", ScalarValue::Date32(Some(-1))),
            ("2024-02-29", ScalarValue::Date32(Some(19782))),
            ("2024-01-01", ScalarValue::Date32(Some(19723))),
            ("2024-12-31", ScalarValue::Date32(Some(20088))),
            ("2001-01-02", ScalarValue::Date32(Some(11324))),
            ("2050-06-14", ScalarValue::Date32(Some(29384))),
            ("2020-03-15", ScalarValue::Date32(Some(18336))),
        ];

        for (input, expected) in test_cases {
            let result = convert_bytes_to_date(input.as_bytes()).unwrap();
            assert_eq!(result, expected, "Failed for input: {}", input);
        }
    }
}