servers/mysql/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::time::Duration;
20
21use ::auth::{Identity, Password, UserProviderRef};
22use async_trait::async_trait;
23use chrono::{NaiveDate, NaiveDateTime};
24use common_catalog::parse_optional_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_query::Output;
27use common_telemetry::{debug, error, tracing, warn};
28use datafusion_common::ParamValues;
29use datafusion_expr::LogicalPlan;
30use datatypes::prelude::ConcreteDataType;
31use itertools::Itertools;
32use opensrv_mysql::{
33    AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
34    StatementMetaWriter, ValueInner,
35};
36use parking_lot::RwLock;
37use query::query_engine::DescribeResult;
38use rand::RngCore;
39use session::context::{Channel, QueryContextRef};
40use session::{Session, SessionRef};
41use snafu::{ResultExt, ensure};
42use sql::dialect::MySqlDialect;
43use sql::parser::{ParseOptions, ParserContext};
44use sql::statements::statement::Statement;
45use tokio::io::AsyncWrite;
46
47use crate::SqlPlan;
48use crate::error::{self, DataFrameSnafu, InvalidPrepareStatementSnafu, Result};
49use crate::metrics::METRIC_AUTH_FAILURE;
50use crate::mysql::helper::{
51    self, fix_placeholder_types, format_placeholder, replace_placeholders, transform_placeholders,
52};
53use crate::mysql::writer;
54use crate::mysql::writer::{create_mysql_column, handle_err};
55use crate::query_handler::sql::ServerSqlQueryHandlerRef;
56
57const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
58const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
59
60/// Parameters for the prepared statement
61enum Params<'a> {
62    /// Parameters passed through protocol
63    ProtocolParams(Vec<ParamValue<'a>>),
64    /// Parameters passed through cli
65    CliParams(Vec<sql::ast::Expr>),
66}
67
68impl Params<'_> {
69    fn len(&self) -> usize {
70        match self {
71            Params::ProtocolParams(params) => params.len(),
72            Params::CliParams(params) => params.len(),
73        }
74    }
75}
76
77// An intermediate shim for executing MySQL queries.
78pub struct MysqlInstanceShim {
79    query_handler: ServerSqlQueryHandlerRef,
80    salt: [u8; 20],
81    session: SessionRef,
82    user_provider: Option<UserProviderRef>,
83    prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
84    prepared_stmts_counter: AtomicU32,
85    process_id: u32,
86    prepared_stmt_cache_size: usize,
87}
88
89impl MysqlInstanceShim {
90    pub fn create(
91        query_handler: ServerSqlQueryHandlerRef,
92        user_provider: Option<UserProviderRef>,
93        client_addr: SocketAddr,
94        process_id: u32,
95        prepared_stmt_cache_size: usize,
96    ) -> MysqlInstanceShim {
97        // init a random salt
98        let mut bs = vec![0u8; 20];
99        let mut rng = rand::rng();
100        rng.fill_bytes(bs.as_mut());
101
102        let mut scramble: [u8; 20] = [0; 20];
103        for i in 0..20 {
104            scramble[i] = bs[i] & 0x7fu8;
105            if scramble[i] == b'\0' || scramble[i] == b'$' {
106                scramble[i] += 1;
107            }
108        }
109
110        MysqlInstanceShim {
111            query_handler,
112            salt: scramble,
113            session: Arc::new(Session::new(
114                Some(client_addr),
115                Channel::Mysql,
116                Default::default(),
117                process_id,
118            )),
119            user_provider,
120            prepared_stmts: Default::default(),
121            prepared_stmts_counter: AtomicU32::new(1),
122            process_id,
123            prepared_stmt_cache_size,
124        }
125    }
126
127    #[tracing::instrument(skip_all, name = "mysql::do_query")]
128    async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
129        if let Some(output) =
130            crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
131        {
132            vec![Ok(output)]
133        } else {
134            self.query_handler.do_query(query, query_ctx.clone()).await
135        }
136    }
137
138    /// Execute the logical plan and return the output
139    async fn do_exec_plan(
140        &self,
141        query: &str,
142        stmt: Option<Statement>,
143        plan: LogicalPlan,
144        query_ctx: QueryContextRef,
145    ) -> Result<Output> {
146        if let Some(output) =
147            crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
148        {
149            Ok(output)
150        } else {
151            self.query_handler.do_exec_plan(stmt, plan, query_ctx).await
152        }
153    }
154
155    /// Describe the statement
156    async fn do_describe(
157        &self,
158        statement: Statement,
159        query_ctx: QueryContextRef,
160    ) -> Result<Option<DescribeResult>> {
161        self.query_handler.do_describe(statement, query_ctx).await
162    }
163
164    /// Save query and logical plan with a given statement key
165    fn save_plan(&self, plan: SqlPlan, stmt_key: String) -> Result<()> {
166        let mut prepared_stmts = self.prepared_stmts.write();
167        let max_capacity = self.prepared_stmt_cache_size;
168
169        let is_update = prepared_stmts.contains_key(&stmt_key);
170
171        if !is_update && prepared_stmts.len() >= max_capacity {
172            return error::InternalSnafu {
173                err_msg: format!(
174                    "Prepared statement cache is full, max capacity: {}",
175                    max_capacity
176                ),
177            }
178            .fail();
179        }
180
181        let _ = prepared_stmts.insert(stmt_key, plan);
182        Ok(())
183    }
184
185    /// Retrieve the query and logical plan by a given statement key
186    fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
187        let guard = self.prepared_stmts.read();
188        guard.get(stmt_key).cloned()
189    }
190
191    /// Save the prepared statement and return the parameters and result columns
192    async fn do_prepare(
193        &mut self,
194        raw_query: &str,
195        query_ctx: QueryContextRef,
196        stmt_key: String,
197    ) -> Result<(Vec<Column>, Vec<Column>)> {
198        let (query, param_num) = replace_placeholders(raw_query);
199
200        let statement = validate_query(raw_query).await?;
201
202        // We have to transform the placeholder, because DataFusion only parses placeholders
203        // in the form of "$i", it can't process "?" right now.
204        let statement = transform_placeholders(statement);
205
206        let describe_result = self
207            .do_describe(statement.clone(), query_ctx.clone())
208            .await?;
209        let (mut plan, schema) = if let Some(DescribeResult {
210            logical_plan,
211            schema,
212        }) = describe_result
213        {
214            (Some(logical_plan), Some(schema))
215        } else {
216            (None, None)
217        };
218
219        let params = if let Some(plan) = &mut plan {
220            fix_placeholder_types(plan)?;
221            debug!("Plan after fix placeholder types: {:#?}", plan);
222            prepared_params(
223                &plan
224                    .get_parameter_types()
225                    .context(DataFrameSnafu)?
226                    .into_iter()
227                    .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
228                    .collect(),
229            )?
230        } else {
231            dummy_params(param_num)?
232        };
233
234        let columns = schema
235            .as_ref()
236            .map(|schema| {
237                schema
238                    .column_schemas()
239                    .iter()
240                    .map(|column_schema| {
241                        create_mysql_column(&column_schema.data_type, &column_schema.name)
242                    })
243                    .collect::<Result<Vec<_>>>()
244            })
245            .transpose()?
246            .unwrap_or_default();
247
248        // DataFusion may optimize the plan so that some parameters are not used.
249        if params.len() != param_num - 1 {
250            self.save_plan(
251                SqlPlan {
252                    query: query.clone(),
253                    statement: Some(statement),
254                    plan: None,
255                    schema: None,
256                },
257                stmt_key,
258            )
259            .map_err(|e| {
260                error!(e; "Failed to save prepared statement");
261                e
262            })?;
263        } else {
264            self.save_plan(
265                SqlPlan {
266                    query: query.clone(),
267                    statement: Some(statement),
268                    plan,
269                    schema,
270                },
271                stmt_key,
272            )
273            .map_err(|e| {
274                error!(e; "Failed to save prepared statement");
275                e
276            })?;
277        }
278
279        Ok((params, columns))
280    }
281
282    async fn do_execute(
283        &mut self,
284        query_ctx: QueryContextRef,
285        stmt_key: String,
286        params: Params<'_>,
287    ) -> Result<Vec<std::result::Result<Output, error::Error>>> {
288        let sql_plan = match self.plan(&stmt_key) {
289            None => {
290                return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
291            }
292            Some(sql_plan) => sql_plan,
293        };
294
295        let outputs = match sql_plan.plan {
296            Some(mut plan) => {
297                fix_placeholder_types(&mut plan)?;
298                let param_types = plan
299                    .get_parameter_types()
300                    .context(DataFrameSnafu)?
301                    .into_iter()
302                    .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
303                    .collect::<HashMap<_, _>>();
304
305                if params.len() != param_types.len() {
306                    return error::InternalSnafu {
307                        err_msg: "Prepare statement params number mismatch".to_string(),
308                    }
309                    .fail();
310                }
311
312                let plan = match params {
313                    Params::ProtocolParams(params) => {
314                        replace_params_with_values(&plan, param_types, &params)
315                    }
316                    Params::CliParams(params) => {
317                        replace_params_with_exprs(&plan, param_types, &params)
318                    }
319                }?;
320
321                debug!("Mysql execute prepared plan: {}", plan.display_indent());
322                vec![
323                    self.do_exec_plan(
324                        &sql_plan.query,
325                        sql_plan.statement.clone(),
326                        plan,
327                        query_ctx.clone(),
328                    )
329                    .await,
330                ]
331            }
332            None => {
333                let param_strs = match params {
334                    Params::ProtocolParams(params) => {
335                        params.iter().map(convert_param_value_to_string).collect()
336                    }
337                    Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
338                };
339                debug!(
340                    "do_execute Replacing with Params: {:?}, Original Query: {}",
341                    param_strs, sql_plan.query
342                );
343                let query = replace_params(param_strs, sql_plan.query);
344                debug!("Mysql execute replaced query: {}", query);
345                self.do_query(&query, query_ctx.clone()).await
346            }
347        };
348
349        Ok(outputs)
350    }
351
352    /// Remove the prepared statement by a given statement key
353    fn do_close(&mut self, stmt_key: String) {
354        let mut guard = self.prepared_stmts.write();
355        let _ = guard.remove(&stmt_key);
356    }
357
358    fn auth_plugin(&self) -> &'static str {
359        if self
360            .user_provider
361            .as_ref()
362            .map(|x| x.external())
363            .unwrap_or(false)
364        {
365            MYSQL_CLEAR_PASSWORD
366        } else {
367            MYSQL_NATIVE_PASSWORD
368        }
369    }
370}
371
372#[async_trait]
373impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShim {
374    type Error = error::Error;
375
376    fn version(&self) -> String {
377        std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
378    }
379
380    fn connect_id(&self) -> u32 {
381        self.process_id
382    }
383
384    fn default_auth_plugin(&self) -> &str {
385        self.auth_plugin()
386    }
387
388    async fn auth_plugin_for_username(&self, _user: &[u8]) -> &'static str {
389        self.auth_plugin()
390    }
391
392    fn salt(&self) -> [u8; 20] {
393        self.salt
394    }
395
396    async fn authenticate(
397        &self,
398        auth_plugin: &str,
399        username: &[u8],
400        salt: &[u8],
401        auth_data: &[u8],
402    ) -> bool {
403        // if not specified then **greptime** will be used
404        let username = String::from_utf8_lossy(username);
405
406        let mut user_info = None;
407        let addr = self
408            .session
409            .conn_info()
410            .client_addr
411            .map(|addr| addr.to_string());
412        if let Some(user_provider) = &self.user_provider {
413            let user_id = Identity::UserId(&username, addr.as_deref());
414
415            let password = match auth_plugin {
416                MYSQL_NATIVE_PASSWORD => Password::MysqlNativePassword(auth_data, salt),
417                MYSQL_CLEAR_PASSWORD => {
418                    // The raw bytes received could be represented in C-like string, ended in '\0'.
419                    // We must "trim" it to get the real password string.
420                    let password = if let &[password @ .., 0] = &auth_data {
421                        password
422                    } else {
423                        auth_data
424                    };
425                    Password::PlainText(String::from_utf8_lossy(password).to_string().into())
426                }
427                other => {
428                    error!("Unsupported mysql auth plugin: {}", other);
429                    return false;
430                }
431            };
432            match user_provider.authenticate(user_id, password).await {
433                Ok(userinfo) => {
434                    user_info = Some(userinfo);
435                }
436                Err(e) => {
437                    METRIC_AUTH_FAILURE
438                        .with_label_values(&[e.status_code().as_ref()])
439                        .inc();
440                    warn!(e; "Failed to auth");
441                    return false;
442                }
443            };
444        }
445        let user_info =
446            user_info.unwrap_or_else(|| auth::userinfo_by_name(Some(username.to_string())));
447
448        self.session.set_user_info(user_info);
449
450        true
451    }
452
453    async fn on_prepare<'a>(
454        &'a mut self,
455        raw_query: &'a str,
456        w: StatementMetaWriter<'a, W>,
457    ) -> Result<()> {
458        let query_ctx = self.session.new_query_context();
459        let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
460        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
461        let (params, columns) = self
462            .do_prepare(raw_query, query_ctx.clone(), stmt_key)
463            .await?;
464        debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
465        w.reply(stmt_id, &params, &columns).await?;
466        crate::metrics::METRIC_MYSQL_PREPARED_COUNT
467            .with_label_values(&[query_ctx.get_db_string().as_str()])
468            .inc();
469        return Ok(());
470    }
471
472    async fn on_execute<'a>(
473        &'a mut self,
474        stmt_id: u32,
475        p: ParamParser<'a>,
476        w: QueryResultWriter<'a, W>,
477    ) -> Result<()> {
478        self.session.clear_warnings();
479
480        let query_ctx = self.session.new_query_context();
481        let db = query_ctx.get_db_string();
482        let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
483            .with_label_values(&[crate::metrics::METRIC_MYSQL_BINQUERY, db.as_str()])
484            .start_timer();
485
486        let params: Vec<ParamValue> = p.into_iter().collect();
487        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
488
489        let outputs = match self
490            .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
491            .await
492        {
493            Ok(outputs) => outputs,
494            Err(e) => {
495                let (kind, err) = handle_err(e, query_ctx);
496                debug!(
497                    "Failed to execute prepared statement, kind: {:?}, err: {}",
498                    kind, err
499                );
500                w.error(kind, err.as_bytes()).await?;
501                return Ok(());
502            }
503        };
504
505        writer::write_output(w, query_ctx, self.session.clone(), outputs).await?;
506
507        Ok(())
508    }
509
510    async fn on_close<'a>(&'a mut self, stmt_id: u32)
511    where
512        W: 'async_trait,
513    {
514        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
515        self.do_close(stmt_key);
516    }
517
518    #[tracing::instrument(skip_all, fields(protocol = "mysql"))]
519    async fn on_query<'a>(
520        &'a mut self,
521        query: &'a str,
522        writer: QueryResultWriter<'a, W>,
523    ) -> Result<()> {
524        let query_ctx = self.session.new_query_context();
525        let db = query_ctx.get_db_string();
526        let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
527            .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
528            .start_timer();
529
530        // Clear warnings for non SHOW WARNINGS queries
531        let query_upcase = query.to_uppercase();
532        if !query_upcase.starts_with("SHOW WARNINGS") {
533            self.session.clear_warnings();
534        }
535
536        if query_upcase.starts_with("PREPARE ") {
537            match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
538                Ok((stmt_name, stmt)) => {
539                    let prepare_results =
540                        self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
541                    match prepare_results {
542                        Ok(_) => {
543                            let outputs = vec![Ok(Output::new_with_affected_rows(0))];
544                            writer::write_output(writer, query_ctx, self.session.clone(), outputs)
545                                .await?;
546                            return Ok(());
547                        }
548                        Err(e) => {
549                            writer
550                                .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
551                                .await?;
552                            return Ok(());
553                        }
554                    }
555                }
556                Err(e) => {
557                    writer
558                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
559                        .await?;
560                    return Ok(());
561                }
562            }
563        } else if query_upcase.starts_with("EXECUTE ") {
564            match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
565                Ok((stmt_name, params)) => {
566                    let outputs = match self
567                        .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
568                        .await
569                    {
570                        Ok(outputs) => outputs,
571                        Err(e) => {
572                            let (kind, err) = handle_err(e, query_ctx);
573                            debug!(
574                                "Failed to execute prepared statement, kind: {:?}, err: {}",
575                                kind, err
576                            );
577                            writer.error(kind, err.as_bytes()).await?;
578                            return Ok(());
579                        }
580                    };
581                    writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
582
583                    return Ok(());
584                }
585                Err(e) => {
586                    writer
587                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
588                        .await?;
589                    return Ok(());
590                }
591            }
592        } else if query_upcase.starts_with("DEALLOCATE ") {
593            match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
594                Ok(stmt_name) => {
595                    self.do_close(stmt_name);
596                    let outputs = vec![Ok(Output::new_with_affected_rows(0))];
597                    writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
598                    return Ok(());
599                }
600                Err(e) => {
601                    writer
602                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
603                        .await?;
604                    return Ok(());
605                }
606            }
607        }
608
609        let outputs = self.do_query(query, query_ctx.clone()).await;
610        writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
611
612        Ok(())
613    }
614
615    async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
616        let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
617        let catalog = if let Some(catalog) = &catalog_from_db {
618            catalog.clone()
619        } else {
620            self.session.catalog()
621        };
622
623        if !self
624            .query_handler
625            .is_valid_schema(&catalog, &schema)
626            .await?
627        {
628            return w
629                .error(
630                    ErrorKind::ER_WRONG_DB_NAME,
631                    format!("Unknown database '{}'", database).as_bytes(),
632                )
633                .await
634                .map_err(|e| e.into());
635        }
636
637        let user_info = &self.session.user_info();
638
639        if let Some(schema_validator) = &self.user_provider
640            && let Err(e) = schema_validator
641                .authorize(&catalog, &schema, user_info)
642                .await
643        {
644            METRIC_AUTH_FAILURE
645                .with_label_values(&[e.status_code().as_ref()])
646                .inc();
647            return w
648                .error(
649                    ErrorKind::ER_DBACCESS_DENIED_ERROR,
650                    e.output_msg().as_bytes(),
651                )
652                .await
653                .map_err(|e| e.into());
654        }
655
656        if catalog_from_db.is_some() {
657            self.session.set_catalog(catalog)
658        }
659        self.session.set_schema(schema);
660
661        w.ok().await.map_err(|e| e.into())
662    }
663}
664
665fn convert_param_value_to_string(param: &ParamValue) -> String {
666    match param.value.into_inner() {
667        ValueInner::Int(u) => u.to_string(),
668        ValueInner::UInt(u) => u.to_string(),
669        ValueInner::Double(u) => u.to_string(),
670        ValueInner::NULL => "NULL".to_string(),
671        ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
672        ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
673        ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
674        ValueInner::Time(_) => format_duration(Duration::from(param.value)),
675    }
676}
677
678fn replace_params(params: Vec<String>, query: String) -> String {
679    let mut query = query;
680    let mut index = 1;
681    for param in params {
682        query = query.replace(&format_placeholder(index), &param);
683        index += 1;
684    }
685    query
686}
687
/// Renders a `TIME` parameter as a quoted `'HH:MM:SS[.ffffff]'` literal.
///
/// Hours are not wrapped at 24, so longer durations keep their full hour count.
/// Minutes and seconds (and hours below 10) are zero-padded. Sub-second
/// precision is kept to the microsecond instead of being silently dropped; the
/// fractional part is only emitted when it is non-zero.
fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let hours = total_secs / 3600;
    let minutes = (total_secs / 60) % 60;
    let seconds = total_secs % 60;
    let micros = duration.subsec_micros();
    if micros == 0 {
        format!("'{:02}:{:02}:{:02}'", hours, minutes, seconds)
    } else {
        format!(
            "'{:02}:{:02}:{:02}.{:06}'",
            hours, minutes, seconds, micros
        )
    }
}
694
695fn replace_params_with_values(
696    plan: &LogicalPlan,
697    param_types: HashMap<String, Option<ConcreteDataType>>,
698    params: &[ParamValue],
699) -> Result<LogicalPlan> {
700    debug_assert_eq!(param_types.len(), params.len());
701
702    debug!(
703        "replace_params_with_values(param_types: {:#?}, params: {:#?}, plan: {:#?})",
704        param_types,
705        params
706            .iter()
707            .map(|x| format!("({:?}, {:?})", x.value, x.coltype))
708            .join(", "),
709        plan
710    );
711
712    let mut values = Vec::with_capacity(params.len());
713
714    for (i, param) in params.iter().enumerate() {
715        if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
716            let value = helper::convert_value(param, t)?;
717
718            values.push(value);
719        }
720    }
721
722    plan.clone()
723        .replace_params_with_values(&ParamValues::List(values.clone()))
724        .context(DataFrameSnafu)
725}
726
727fn replace_params_with_exprs(
728    plan: &LogicalPlan,
729    param_types: HashMap<String, Option<ConcreteDataType>>,
730    params: &[sql::ast::Expr],
731) -> Result<LogicalPlan> {
732    debug_assert_eq!(param_types.len(), params.len());
733
734    debug!(
735        "replace_params_with_exprs(param_types: {:#?}, params: {:#?}, plan: {:#?})",
736        param_types,
737        params.iter().map(|x| format!("({:?})", x)).join(", "),
738        plan
739    );
740
741    let mut values = Vec::with_capacity(params.len());
742
743    for (i, param) in params.iter().enumerate() {
744        if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
745            let value = helper::convert_expr_to_scalar_value(param, t)?;
746
747            values.push(value);
748        }
749    }
750
751    plan.clone()
752        .replace_params_with_values(&ParamValues::List(values.clone()))
753        .context(DataFrameSnafu)
754}
755
756async fn validate_query(query: &str) -> Result<Statement> {
757    let statement =
758        ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
759    let mut statement = statement.map_err(|e| {
760        InvalidPrepareStatementSnafu {
761            err_msg: e.output_msg(),
762        }
763        .build()
764    })?;
765
766    ensure!(
767        statement.len() == 1,
768        InvalidPrepareStatementSnafu {
769            err_msg: "prepare statement only support single statement".to_string(),
770        }
771    );
772
773    let statement = statement.remove(0);
774
775    Ok(statement)
776}
777
778fn dummy_params(index: usize) -> Result<Vec<Column>> {
779    let mut params = Vec::with_capacity(index - 1);
780
781    for _ in 1..index {
782        params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
783    }
784
785    Ok(params)
786}
787
788/// Parameters that the client must provide when executing the prepared statement.
789fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> Result<Vec<Column>> {
790    let mut params = Vec::with_capacity(param_types.len());
791
792    // Placeholder index starts from 1
793    for index in 1..=param_types.len() {
794        if let Some(Some(t)) = param_types.get(&format_placeholder(index)) {
795            let column = create_mysql_column(t, "")?;
796            params.push(column);
797        }
798    }
799
800    Ok(params)
801}