servers/mysql/
helper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::ControlFlow;
16use std::time::Duration;
17
18use chrono::NaiveDate;
19use common_query::prelude::ScalarValue;
20use common_time::Timestamp;
21use datafusion_common::tree_node::{Transformed, TreeNode};
22use datafusion_expr::LogicalPlan;
23use datatypes::prelude::ConcreteDataType;
24use datatypes::types::TimestampType;
25use datatypes::value::{self, Value};
26use itertools::Itertools;
27use opensrv_mysql::{to_naive_datetime, ParamValue, ValueInner};
28use snafu::ResultExt;
29use sql::ast::{visit_expressions_mut, Expr, Value as ValueExpr, VisitMut};
30use sql::statements::sql_value_to_value;
31use sql::statements::statement::Statement;
32
33use crate::error::{self, DataFusionSnafu, Result};
34
/// Formats the placeholder index `i` as the positional placeholder string `$i`
/// (e.g. `3` becomes `"$3"`).
pub fn format_placeholder(i: usize) -> String {
    let mut placeholder = String::from("$");
    placeholder.push_str(&i.to_string());
    placeholder
}
39
40/// Replace all the "?" placeholder into "$i" in SQL,
41/// returns the new SQL and the last placeholder index.
42pub fn replace_placeholders(query: &str) -> (String, usize) {
43    let query_parts = query.split('?').collect::<Vec<_>>();
44    let parts_len = query_parts.len();
45    let mut index = 0;
46    let query = query_parts
47        .into_iter()
48        .enumerate()
49        .map(|(i, part)| {
50            if i == parts_len - 1 {
51                return part.to_string();
52            }
53
54            index += 1;
55            format!("{part}{}", format_placeholder(index))
56        })
57        .join("");
58
59    (query, index + 1)
60}
61
62/// Transform all the "?" placeholder into "$i".
63/// Only works for Insert,Query and Delete statements.
64pub fn transform_placeholders(stmt: Statement) -> Statement {
65    match stmt {
66        Statement::Query(mut query) => {
67            visit_placeholders(&mut query.inner);
68            Statement::Query(query)
69        }
70        Statement::Insert(mut insert) => {
71            visit_placeholders(&mut insert.inner);
72            Statement::Insert(insert)
73        }
74        Statement::Delete(mut delete) => {
75            visit_placeholders(&mut delete.inner);
76            Statement::Delete(delete)
77        }
78        stmt => stmt,
79    }
80}
81
82/// Give placeholder that cast to certain type `data_type` the same data type as is cast to
83///
84/// because it seems datafusion will not give data type to placeholder if it need to be cast to certain type, still unknown if this is a feature or a bug. And if a placeholder expr have no data type, datafusion will fail to extract it using `LogicalPlan::get_parameter_types`
85pub fn fix_placeholder_types(plan: &mut LogicalPlan) -> Result<()> {
86    let give_placeholder_types = |mut e: datafusion_expr::Expr| {
87        if let datafusion_expr::Expr::Cast(cast) = &mut e {
88            if let datafusion_expr::Expr::Placeholder(ph) = &mut *cast.expr {
89                if ph.data_type.is_none() {
90                    ph.data_type = Some(cast.data_type.clone());
91                    common_telemetry::debug!(
92                        "give placeholder type {:?} to {:?}",
93                        cast.data_type,
94                        ph
95                    );
96                    Ok(Transformed::yes(e))
97                } else {
98                    Ok(Transformed::no(e))
99                }
100            } else {
101                Ok(Transformed::no(e))
102            }
103        } else {
104            Ok(Transformed::no(e))
105        }
106    };
107    let give_placeholder_types_recursively =
108        |e: datafusion_expr::Expr| e.transform(give_placeholder_types);
109    *plan = std::mem::take(plan)
110        .transform(|p| p.map_expressions(give_placeholder_types_recursively))
111        .context(DataFusionSnafu)?
112        .data;
113    Ok(())
114}
115
116fn visit_placeholders<V>(v: &mut V)
117where
118    V: VisitMut,
119{
120    let mut index = 1;
121    let _ = visit_expressions_mut(v, |expr| {
122        if let Expr::Value(ValueExpr::Placeholder(s)) = expr {
123            *s = format_placeholder(index);
124            index += 1;
125        }
126        ControlFlow::<()>::Continue(())
127    });
128}
129
130/// Convert [`ParamValue`] into [`Value`] according to param type.
131/// It will try it's best to do type conversions if possible
132pub fn convert_value(param: &ParamValue, t: &ConcreteDataType) -> Result<ScalarValue> {
133    match param.value.into_inner() {
134        ValueInner::Int(i) => match t {
135            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(i as i8))),
136            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(i as i16))),
137            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(i as i32))),
138            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(i))),
139            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(i as u8))),
140            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(i as u16))),
141            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(i as u32))),
142            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(i as u64))),
143            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(i as f32))),
144            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(i as f64))),
145            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(i != 0))),
146            ConcreteDataType::Timestamp(ts_type) => Value::Timestamp(ts_type.create_timestamp(i))
147                .try_to_scalar_value(t)
148                .context(error::ConvertScalarValueSnafu),
149
150            _ => error::PreparedStmtTypeMismatchSnafu {
151                expected: t,
152                actual: param.coltype,
153            }
154            .fail(),
155        },
156        ValueInner::UInt(u) => match t {
157            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(u as i8))),
158            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(u as i16))),
159            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(u as i32))),
160            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(u as i64))),
161            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(u as u8))),
162            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(u as u16))),
163            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(u as u32))),
164            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(u))),
165            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(u as f32))),
166            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(u as f64))),
167            ConcreteDataType::Boolean(_) => Ok(ScalarValue::Boolean(Some(u != 0))),
168            ConcreteDataType::Timestamp(ts_type) => {
169                Value::Timestamp(ts_type.create_timestamp(u as i64))
170                    .try_to_scalar_value(t)
171                    .context(error::ConvertScalarValueSnafu)
172            }
173
174            _ => error::PreparedStmtTypeMismatchSnafu {
175                expected: t,
176                actual: param.coltype,
177            }
178            .fail(),
179        },
180        ValueInner::Double(f) => match t {
181            ConcreteDataType::Int8(_) => Ok(ScalarValue::Int8(Some(f as i8))),
182            ConcreteDataType::Int16(_) => Ok(ScalarValue::Int16(Some(f as i16))),
183            ConcreteDataType::Int32(_) => Ok(ScalarValue::Int32(Some(f as i32))),
184            ConcreteDataType::Int64(_) => Ok(ScalarValue::Int64(Some(f as i64))),
185            ConcreteDataType::UInt8(_) => Ok(ScalarValue::UInt8(Some(f as u8))),
186            ConcreteDataType::UInt16(_) => Ok(ScalarValue::UInt16(Some(f as u16))),
187            ConcreteDataType::UInt32(_) => Ok(ScalarValue::UInt32(Some(f as u32))),
188            ConcreteDataType::UInt64(_) => Ok(ScalarValue::UInt64(Some(f as u64))),
189            ConcreteDataType::Float32(_) => Ok(ScalarValue::Float32(Some(f as f32))),
190            ConcreteDataType::Float64(_) => Ok(ScalarValue::Float64(Some(f))),
191
192            _ => error::PreparedStmtTypeMismatchSnafu {
193                expected: t,
194                actual: param.coltype,
195            }
196            .fail(),
197        },
198        ValueInner::NULL => value::to_null_scalar_value(t).context(error::ConvertScalarValueSnafu),
199        ValueInner::Bytes(b) => match t {
200            ConcreteDataType::String(_) => Ok(ScalarValue::Utf8(Some(
201                String::from_utf8_lossy(b).to_string(),
202            ))),
203            ConcreteDataType::Binary(_) => Ok(ScalarValue::Binary(Some(b.to_vec()))),
204            ConcreteDataType::Timestamp(ts_type) => covert_bytes_to_timestamp(b, ts_type),
205            _ => error::PreparedStmtTypeMismatchSnafu {
206                expected: t,
207                actual: param.coltype,
208            }
209            .fail(),
210        },
211        ValueInner::Date(_) => {
212            let date: common_time::Date = NaiveDate::from(param.value).into();
213            Ok(ScalarValue::Date32(Some(date.val())))
214        }
215        ValueInner::Datetime(_) => {
216            let timestamp_millis = to_naive_datetime(param.value)
217                .map_err(|e| {
218                    error::MysqlValueConversionSnafu {
219                        err_msg: e.to_string(),
220                    }
221                    .build()
222                })?
223                .and_utc()
224                .timestamp_millis();
225
226            match t {
227                ConcreteDataType::Timestamp(_) => Ok(ScalarValue::TimestampMillisecond(
228                    Some(timestamp_millis),
229                    None,
230                )),
231                _ => error::PreparedStmtTypeMismatchSnafu {
232                    expected: t,
233                    actual: param.coltype,
234                }
235                .fail(),
236            }
237        }
238        ValueInner::Time(_) => Ok(ScalarValue::Time64Nanosecond(Some(
239            Duration::from(param.value).as_millis() as i64,
240        ))),
241    }
242}
243
244/// Convert an MySQL expression to a scalar value.
245/// It automatically handles the conversion of strings to numeric values.
246pub fn convert_expr_to_scalar_value(param: &Expr, t: &ConcreteDataType) -> Result<ScalarValue> {
247    match param {
248        Expr::Value(v) => {
249            let v = sql_value_to_value("", t, v, None, None, true);
250            match v {
251                Ok(v) => v
252                    .try_to_scalar_value(t)
253                    .context(error::ConvertScalarValueSnafu),
254                Err(e) => error::InvalidParameterSnafu {
255                    reason: e.to_string(),
256                }
257                .fail(),
258            }
259        }
260        Expr::UnaryOp { op, expr } if let Expr::Value(v) = &**expr => {
261            let v = sql_value_to_value("", t, v, None, Some(*op), true);
262            match v {
263                Ok(v) => v
264                    .try_to_scalar_value(t)
265                    .context(error::ConvertScalarValueSnafu),
266                Err(e) => error::InvalidParameterSnafu {
267                    reason: e.to_string(),
268                }
269                .fail(),
270            }
271        }
272        _ => error::InvalidParameterSnafu {
273            reason: format!("cannot convert {:?} to scalar value of type {}", param, t),
274        }
275        .fail(),
276    }
277}
278
279fn covert_bytes_to_timestamp(bytes: &[u8], ts_type: &TimestampType) -> Result<ScalarValue> {
280    let ts = Timestamp::from_str_utc(&String::from_utf8_lossy(bytes))
281        .map_err(|e| {
282            error::MysqlValueConversionSnafu {
283                err_msg: e.to_string(),
284            }
285            .build()
286        })?
287        .convert_to(ts_type.unit())
288        .ok_or_else(|| {
289            error::MysqlValueConversionSnafu {
290                err_msg: "Overflow when converting timestamp to target unit".to_string(),
291            }
292            .build()
293        })?;
294    match ts_type {
295        TimestampType::Nanosecond(_) => {
296            Ok(ScalarValue::TimestampNanosecond(Some(ts.value()), None))
297        }
298        TimestampType::Microsecond(_) => {
299            Ok(ScalarValue::TimestampMicrosecond(Some(ts.value()), None))
300        }
301        TimestampType::Millisecond(_) => {
302            Ok(ScalarValue::TimestampMillisecond(Some(ts.value()), None))
303        }
304        TimestampType::Second(_) => Ok(ScalarValue::TimestampSecond(Some(ts.value()), None)),
305    }
306}
307
#[cfg(test)]
mod tests {
    use datatypes::types::{
        TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
        TimestampSecondType,
    };
    use sql::dialect::MySqlDialect;
    use sql::parser::{ParseOptions, ParserContext};

    use super::*;

    #[test]
    fn test_format_placeholder() {
        assert_eq!("$1", format_placeholder(1));
        assert_eq!("$3", format_placeholder(3));
    }

    #[test]
    fn test_replace_placeholders() {
        let create = "create table demo(host string, ts timestamp time index)";
        let (sql, index) = replace_placeholders(create);
        assert_eq!(create, sql);
        assert_eq!(1, index);

        let insert = "insert into demo values(?,?,?)";
        let (sql, index) = replace_placeholders(insert);
        assert_eq!("insert into demo values($1,$2,$3)", sql);
        assert_eq!(4, index);

        let query = "select from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?";
        let (sql, index) = replace_placeholders(query);
        assert_eq!("select from demo where host=$1 and idc in (select idc from idcs where name=$2) and cpu>$3", sql);
        assert_eq!(4, index);
    }

    /// Parses `sql` with the MySQL dialect and returns the first statement.
    fn parse_sql(sql: &str) -> Statement {
        let mut stmts =
            ParserContext::create_with_dialect(sql, &MySqlDialect {}, ParseOptions::default())
                .unwrap();
        stmts.remove(0)
    }

    #[test]
    fn test_transform_placeholders() {
        let insert = parse_sql("insert into demo values(?,?,?)");
        let Statement::Insert(insert) = transform_placeholders(insert) else {
            unreachable!()
        };
        assert_eq!(
            "INSERT INTO demo VALUES ($1, $2, $3)",
            insert.inner.to_string()
        );

        let delete = parse_sql("delete from demo where host=? and idc=?");
        let Statement::Delete(delete) = transform_placeholders(delete) else {
            unreachable!()
        };
        assert_eq!(
            "DELETE FROM demo WHERE host = $1 AND idc = $2",
            delete.inner.to_string()
        );

        let select = parse_sql("select * from demo where host=? and idc in (select idc from idcs where name=?) and cpu>?");
        let Statement::Query(select) = transform_placeholders(select) else {
            unreachable!()
        };
        assert_eq!("SELECT * FROM demo WHERE host = $1 AND idc IN (SELECT idc FROM idcs WHERE name = $2) AND cpu > $3", select.inner.to_string());
    }

    #[test]
    fn test_convert_expr_to_scalar_value() {
        let expr = Expr::Value(ValueExpr::Number("123".to_string(), false));
        let t = ConcreteDataType::int32_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Int32(Some(123)), v);

        let expr = Expr::Value(ValueExpr::Number("123.456789".to_string(), false));
        let t = ConcreteDataType::float64_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Float64(Some(123.456789)), v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("2001-01-02".to_string()));
        let t = ConcreteDataType::date_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02".to_string()))
            .cast_to(&arrow_schema::DataType::Date32)
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString(
            "2001-01-02 03:04:05".to_string(),
        ));
        let t = ConcreteDataType::timestamp_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        let scalar_v = ScalarValue::Utf8(Some("2001-01-02 03:04:05".to_string()))
            .cast_to(&arrow_schema::DataType::Timestamp(
                arrow_schema::TimeUnit::Microsecond,
                None,
            ))
            .unwrap();
        assert_eq!(scalar_v, v);

        let expr = Expr::Value(ValueExpr::SingleQuotedString("hello".to_string()));
        let t = ConcreteDataType::string_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Utf8(Some("hello".to_string())), v);

        let expr = Expr::Value(ValueExpr::Null);
        let t = ConcreteDataType::time_microsecond_datatype();
        let v = convert_expr_to_scalar_value(&expr, &t).unwrap();
        assert_eq!(ScalarValue::Time64Microsecond(None), v);
    }

    #[test]
    fn test_convert_bytes_to_timestamp() {
        let test_cases = vec![
            // datetime string with second precision -> nanosecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400000000000), None),
            ),
            // datetime string with second precision -> microsecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400000000), None),
            ),
            // datetime string with second precision -> millisecond.
            (
                "2024-12-26 12:00:00",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400000), None),
            ),
            // datetime string with second precision -> second.
            (
                "2024-12-26 12:00:00",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // datetime string with millisecond precision -> nanosecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123000000), None),
            ),
            // datetime string with millisecond precision -> microsecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123000), None),
            ),
            // datetime string with millisecond precision -> millisecond.
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // datetime string with millisecond precision -> second (fraction dropped).
            (
                "2024-12-26 12:00:00.123",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
            // datetime string with microsecond precision -> nanosecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Nanosecond(TimestampNanosecondType),
                ScalarValue::TimestampNanosecond(Some(1735214400123456000), None),
            ),
            // datetime string with microsecond precision -> microsecond.
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Microsecond(TimestampMicrosecondType),
                ScalarValue::TimestampMicrosecond(Some(1735214400123456), None),
            ),
            // datetime string with microsecond precision -> millisecond (fraction truncated).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Millisecond(TimestampMillisecondType),
                ScalarValue::TimestampMillisecond(Some(1735214400123), None),
            ),
            // datetime string with microsecond precision -> second (fraction dropped).
            (
                "2024-12-26 12:00:00.123456",
                TimestampType::Second(TimestampSecondType),
                ScalarValue::TimestampSecond(Some(1735214400), None),
            ),
        ];

        for (input, ts_type, expected) in test_cases {
            let result = covert_bytes_to_timestamp(input.as_bytes(), &ts_type).unwrap();
            assert_eq!(result, expected, "input: {input}, target type: {ts_type:?}");
        }
    }
}