Skip to main content

servers/mysql/
handler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::time::Duration;
20
21use ::auth::{Identity, Password, UserProviderRef};
22use async_trait::async_trait;
23use chrono::{NaiveDate, NaiveDateTime};
24use common_catalog::parse_optional_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_query::Output;
27use common_telemetry::{debug, error, tracing, warn};
28use datafusion_common::ParamValues;
29use datafusion_expr::LogicalPlan;
30use datatypes::prelude::ConcreteDataType;
31use datatypes::schema::Schema;
32use itertools::Itertools;
33use opensrv_mysql::{
34    AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
35    StatementMetaWriter, ValueInner,
36};
37use parking_lot::RwLock;
38use query::planner::DfLogicalPlanner;
39use query::query_engine::DescribeResult;
40use rand::RngCore;
41use session::context::{Channel, QueryContextRef};
42use session::{Session, SessionRef};
43use snafu::{ResultExt, ensure};
44use sql::dialect::MySqlDialect;
45use sql::parser::{ParseOptions, ParserContext};
46use sql::statements::statement::Statement;
47use tokio::io::AsyncWrite;
48
49use crate::SqlPlan;
50use crate::error::{
51    self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result,
52};
53use crate::metrics::METRIC_AUTH_FAILURE;
54use crate::mysql::helper::{
55    self, format_placeholder, replace_placeholders, transform_placeholders,
56};
57use crate::mysql::writer;
58use crate::mysql::writer::{create_mysql_column, handle_err};
59use crate::query_handler::sql::ServerSqlQueryHandlerRef;
60
/// Auth plugin announced when the user provider is not external; the client's
/// response is checked together with the connection salt.
const MYSQL_NATIVE_PASSWORD: &'static str = "mysql_native_password";
/// Auth plugin announced when the user provider is external; the password is
/// received in clear text.
const MYSQL_CLEAR_PASSWORD: &'static str = "mysql_clear_password";
63
64/// Parameters for the prepared statement
65enum Params<'a> {
66    /// Parameters passed through protocol
67    ProtocolParams(Vec<ParamValue<'a>>),
68    /// Parameters passed through cli
69    CliParams(Vec<sql::ast::Expr>),
70}
71
72impl Params<'_> {
73    fn len(&self) -> usize {
74        match self {
75            Params::ProtocolParams(params) => params.len(),
76            Params::CliParams(params) => params.len(),
77        }
78    }
79}
80
81// An intermediate shim for executing MySQL queries.
82pub struct MysqlInstanceShim {
83    query_handler: ServerSqlQueryHandlerRef,
84    salt: [u8; 20],
85    session: SessionRef,
86    user_provider: Option<UserProviderRef>,
87    prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
88    prepared_stmts_counter: AtomicU32,
89    process_id: u32,
90    prepared_stmt_cache_size: usize,
91}
92
93impl MysqlInstanceShim {
94    pub fn create(
95        query_handler: ServerSqlQueryHandlerRef,
96        user_provider: Option<UserProviderRef>,
97        client_addr: SocketAddr,
98        process_id: u32,
99        prepared_stmt_cache_size: usize,
100    ) -> MysqlInstanceShim {
101        // init a random salt
102        let mut bs = vec![0u8; 20];
103        let mut rng = rand::rng();
104        rng.fill_bytes(bs.as_mut());
105
106        let mut scramble: [u8; 20] = [0; 20];
107        for i in 0..20 {
108            scramble[i] = bs[i] & 0x7fu8;
109            if scramble[i] == b'\0' || scramble[i] == b'$' {
110                scramble[i] += 1;
111            }
112        }
113
114        MysqlInstanceShim {
115            query_handler,
116            salt: scramble,
117            session: Arc::new(Session::new(
118                Some(client_addr),
119                Channel::Mysql,
120                Default::default(),
121                process_id,
122            )),
123            user_provider,
124            prepared_stmts: Default::default(),
125            prepared_stmts_counter: AtomicU32::new(1),
126            process_id,
127            prepared_stmt_cache_size,
128        }
129    }
130
131    #[tracing::instrument(skip_all, name = "mysql::do_query")]
132    async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
133        if let Some(output) =
134            crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
135        {
136            vec![Ok(output)]
137        } else {
138            self.query_handler.do_query(query, query_ctx.clone()).await
139        }
140    }
141
142    /// Describe the statement
143    async fn do_describe(
144        &self,
145        statement: Statement,
146        query_ctx: QueryContextRef,
147    ) -> Result<Option<DescribeResult>> {
148        self.query_handler.do_describe(statement, query_ctx).await
149    }
150
151    /// Save query and logical plan with a given statement key
152    fn save_plan(&self, plan: SqlPlan, stmt_key: String) -> Result<()> {
153        let mut prepared_stmts = self.prepared_stmts.write();
154        let max_capacity = self.prepared_stmt_cache_size;
155
156        let is_update = prepared_stmts.contains_key(&stmt_key);
157
158        if !is_update && prepared_stmts.len() >= max_capacity {
159            return error::InternalSnafu {
160                err_msg: format!(
161                    "Prepared statement cache is full, max capacity: {}",
162                    max_capacity
163                ),
164            }
165            .fail();
166        }
167
168        let _ = prepared_stmts.insert(stmt_key, plan);
169        Ok(())
170    }
171
172    /// Retrieve the query and logical plan by a given statement key
173    fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
174        let guard = self.prepared_stmts.read();
175        guard.get(stmt_key).cloned()
176    }
177
178    /// Save the prepared statement and return the parameters and result columns
179    async fn do_prepare(
180        &mut self,
181        raw_query: &str,
182        query_ctx: QueryContextRef,
183        stmt_key: String,
184    ) -> Result<(Vec<Column>, Vec<Column>)> {
185        if crate::mysql::federated::check(raw_query, query_ctx.clone(), self.session.clone())
186            .is_some()
187        {
188            self.save_plan(SqlPlan::Shortcut(raw_query.to_string()), stmt_key)
189                .inspect_err(|e| {
190                    error!(e; "Failed to save prepared statement");
191                })?;
192            return Ok((vec![], vec![]));
193        }
194
195        let (query, param_num) = replace_placeholders(raw_query);
196
197        let statement = validate_query(raw_query).await?;
198
199        // We have to transform the placeholder, because DataFusion only parses placeholders
200        // in the form of "$i", it can't process "?" right now.
201        let statement = transform_placeholders(statement);
202
203        let describe_result = self
204            .do_describe(statement.clone(), query_ctx.clone())
205            .await?;
206        let plan = describe_result.map(|DescribeResult { logical_plan }| logical_plan);
207
208        let params = if let Some(plan) = &plan {
209            let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
210                .context(InferParameterTypesSnafu)?
211                .into_iter()
212                .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
213                .collect();
214            prepared_params(&param_types)?
215        } else {
216            dummy_params(param_num)?
217        };
218
219        let columns =
220            plan.as_ref()
221                .map(|plan| {
222                    let schema: Schema = plan.schema().clone().try_into().map_err(
223                        |e: datatypes::error::Error| {
224                            error::InternalSnafu {
225                                err_msg: e.to_string(),
226                            }
227                            .build()
228                        },
229                    )?;
230                    schema
231                        .column_schemas()
232                        .iter()
233                        .map(|column_schema| {
234                            create_mysql_column(&column_schema.data_type, &column_schema.name)
235                        })
236                        .collect::<Result<Vec<_>>>()
237                })
238                .transpose()?
239                .unwrap_or_default();
240
241        match plan {
242            Some(plan) if params.len() == param_num - 1 => {
243                self.save_plan(SqlPlan::Plan(plan, query.clone()), stmt_key)
244                    .inspect_err(|e| {
245                        error!(e; "Failed to save prepared statement");
246                    })?;
247            }
248            _ => {
249                self.save_plan(SqlPlan::Statement(statement, query), stmt_key)
250                    .inspect_err(|e| {
251                        error!(e; "Failed to save prepared statement");
252                    })?;
253            }
254        }
255
256        Ok((params, columns))
257    }
258
259    async fn do_execute(
260        &mut self,
261        query_ctx: QueryContextRef,
262        stmt_key: String,
263        params: Params<'_>,
264    ) -> Result<Vec<std::result::Result<Output, error::Error>>> {
265        let sql_plan = match self.plan(&stmt_key) {
266            None => {
267                return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
268            }
269            Some(sql_plan) => sql_plan,
270        };
271
272        let outputs = match sql_plan {
273            SqlPlan::Plan(plan, query) => {
274                let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan)
275                    .context(InferParameterTypesSnafu)?
276                    .into_iter()
277                    .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
278                    .collect::<HashMap<_, _>>();
279
280                if params.len() != param_types.len() {
281                    return error::InternalSnafu {
282                        err_msg: "Prepare statement params number mismatch".to_string(),
283                    }
284                    .fail();
285                }
286
287                let replaced_plan = match params {
288                    Params::ProtocolParams(params) => {
289                        replace_params_with_values(&plan, param_types, &params)
290                    }
291                    Params::CliParams(params) => {
292                        replace_params_with_exprs(&plan, param_types, &params)
293                    }
294                }?;
295
296                debug!(
297                    "Mysql execute prepared plan: {}",
298                    replaced_plan.display_indent()
299                );
300                vec![
301                    self.query_handler
302                        .do_exec_plan(replaced_plan, query, query_ctx.clone())
303                        .await,
304                ]
305            }
306            SqlPlan::Shortcut(query) => {
307                if let Some(output) =
308                    crate::mysql::federated::check(&query, query_ctx.clone(), self.session.clone())
309                {
310                    vec![Ok(output)]
311                } else {
312                    self.do_query(&query, query_ctx.clone()).await
313                }
314            }
315            SqlPlan::Statement(_stmt, query) => {
316                let param_strs = match params {
317                    Params::ProtocolParams(params) => {
318                        params.iter().map(convert_param_value_to_string).collect()
319                    }
320                    Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
321                };
322                debug!(
323                    "do_execute Replacing with Params: {:?}, Original Query: {}",
324                    param_strs, query
325                );
326                let query = replace_params(param_strs, query);
327                debug!("Mysql execute replaced query: {}", query);
328                self.do_query(&query, query_ctx.clone()).await
329            }
330            _ => {
331                return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
332            }
333        };
334
335        Ok(outputs)
336    }
337
338    /// Remove the prepared statement by a given statement key
339    fn do_close(&mut self, stmt_key: String) {
340        let mut guard = self.prepared_stmts.write();
341        let _ = guard.remove(&stmt_key);
342    }
343
344    fn auth_plugin(&self) -> &'static str {
345        if self
346            .user_provider
347            .as_ref()
348            .map(|x| x.external())
349            .unwrap_or(false)
350        {
351            MYSQL_CLEAR_PASSWORD
352        } else {
353            MYSQL_NATIVE_PASSWORD
354        }
355    }
356}
357
358#[async_trait]
359impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShim {
360    type Error = error::Error;
361
362    fn version(&self) -> String {
363        std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
364    }
365
366    fn connect_id(&self) -> u32 {
367        self.process_id
368    }
369
370    fn default_auth_plugin(&self) -> &str {
371        self.auth_plugin()
372    }
373
374    async fn auth_plugin_for_username(&self, _user: &[u8]) -> &'static str {
375        self.auth_plugin()
376    }
377
378    fn salt(&self) -> [u8; 20] {
379        self.salt
380    }
381
382    async fn authenticate(
383        &self,
384        auth_plugin: &str,
385        username: &[u8],
386        salt: &[u8],
387        auth_data: &[u8],
388    ) -> bool {
389        // if not specified then **greptime** will be used
390        let username = String::from_utf8_lossy(username);
391
392        let mut user_info = None;
393        let addr = self
394            .session
395            .conn_info()
396            .client_addr
397            .map(|addr| addr.to_string());
398        if let Some(user_provider) = &self.user_provider {
399            let user_id = Identity::UserId(&username, addr.as_deref());
400
401            let password = match auth_plugin {
402                MYSQL_NATIVE_PASSWORD => Password::MysqlNativePassword(auth_data, salt),
403                MYSQL_CLEAR_PASSWORD => {
404                    // The raw bytes received could be represented in C-like string, ended in '\0'.
405                    // We must "trim" it to get the real password string.
406                    let password = if let &[password @ .., 0] = &auth_data {
407                        password
408                    } else {
409                        auth_data
410                    };
411                    Password::PlainText(String::from_utf8_lossy(password).to_string().into())
412                }
413                other => {
414                    error!("Unsupported mysql auth plugin: {}", other);
415                    return false;
416                }
417            };
418            match user_provider.authenticate(user_id, password).await {
419                Ok(userinfo) => {
420                    user_info = Some(userinfo);
421                }
422                Err(e) => {
423                    METRIC_AUTH_FAILURE
424                        .with_label_values(&[e.status_code().as_ref()])
425                        .inc();
426                    warn!(e; "Failed to auth");
427                    return false;
428                }
429            };
430        }
431        let user_info =
432            user_info.unwrap_or_else(|| auth::userinfo_by_name(Some(username.to_string())));
433
434        self.session.set_user_info(user_info);
435
436        true
437    }
438
439    async fn on_prepare<'a>(
440        &'a mut self,
441        raw_query: &'a str,
442        w: StatementMetaWriter<'a, W>,
443    ) -> Result<()> {
444        let query_ctx = self.session.new_query_context();
445        let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
446        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
447        let (params, columns) = match self
448            .do_prepare(raw_query, query_ctx.clone(), stmt_key)
449            .await
450        {
451            Ok(x) => x,
452            Err(e) => {
453                let (kind, msg) = handle_err(e, query_ctx.clone());
454                w.error(kind, msg.as_bytes()).await?;
455                return Ok(());
456            }
457        };
458        debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
459        w.reply(stmt_id, &params, &columns).await?;
460        crate::metrics::METRIC_MYSQL_PREPARED_COUNT
461            .with_label_values(&[query_ctx.get_db_string().as_str()])
462            .inc();
463        return Ok(());
464    }
465
466    async fn on_execute<'a>(
467        &'a mut self,
468        stmt_id: u32,
469        p: ParamParser<'a>,
470        w: QueryResultWriter<'a, W>,
471    ) -> Result<()> {
472        self.session.clear_warnings();
473
474        let query_ctx = self.session.new_query_context();
475        let db = query_ctx.get_db_string();
476        let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
477            .with_label_values(&[crate::metrics::METRIC_MYSQL_BINQUERY, db.as_str()])
478            .start_timer();
479
480        let params: Vec<ParamValue> = p.into_iter().collect();
481        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
482
483        let outputs = match self
484            .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
485            .await
486        {
487            Ok(outputs) => outputs,
488            Err(e) => {
489                let (kind, err) = handle_err(e, query_ctx);
490                debug!(
491                    "Failed to execute prepared statement, kind: {:?}, err: {}",
492                    kind, err
493                );
494                w.error(kind, err.as_bytes()).await?;
495                return Ok(());
496            }
497        };
498
499        writer::write_output(w, query_ctx, self.session.clone(), outputs).await?;
500
501        Ok(())
502    }
503
504    async fn on_close<'a>(&'a mut self, stmt_id: u32)
505    where
506        W: 'async_trait,
507    {
508        let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
509        self.do_close(stmt_key);
510    }
511
512    #[tracing::instrument(skip_all, fields(protocol = "mysql"))]
513    async fn on_query<'a>(
514        &'a mut self,
515        query: &'a str,
516        writer: QueryResultWriter<'a, W>,
517    ) -> Result<()> {
518        let query_ctx = self.session.new_query_context();
519        let db = query_ctx.get_db_string();
520        let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
521            .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
522            .start_timer();
523
524        // Clear warnings for non SHOW WARNINGS queries
525        let query_upcase = query.to_uppercase();
526        if !query_upcase.starts_with("SHOW WARNINGS") {
527            self.session.clear_warnings();
528        }
529
530        if query_upcase.starts_with("PREPARE ") {
531            match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
532                Ok((stmt_name, stmt)) => {
533                    let prepare_results =
534                        self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
535                    match prepare_results {
536                        Ok(_) => {
537                            let outputs = vec![Ok(Output::new_with_affected_rows(0))];
538                            writer::write_output(writer, query_ctx, self.session.clone(), outputs)
539                                .await?;
540                            return Ok(());
541                        }
542                        Err(e) => {
543                            writer
544                                .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
545                                .await?;
546                            return Ok(());
547                        }
548                    }
549                }
550                Err(e) => {
551                    writer
552                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
553                        .await?;
554                    return Ok(());
555                }
556            }
557        } else if query_upcase.starts_with("EXECUTE ") {
558            match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
559                Ok((stmt_name, params)) => {
560                    let outputs = match self
561                        .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
562                        .await
563                    {
564                        Ok(outputs) => outputs,
565                        Err(e) => {
566                            let (kind, err) = handle_err(e, query_ctx);
567                            debug!(
568                                "Failed to execute prepared statement, kind: {:?}, err: {}",
569                                kind, err
570                            );
571                            writer.error(kind, err.as_bytes()).await?;
572                            return Ok(());
573                        }
574                    };
575                    writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
576
577                    return Ok(());
578                }
579                Err(e) => {
580                    writer
581                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
582                        .await?;
583                    return Ok(());
584                }
585            }
586        } else if query_upcase.starts_with("DEALLOCATE ") {
587            match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
588                Ok(stmt_name) => {
589                    self.do_close(stmt_name);
590                    let outputs = vec![Ok(Output::new_with_affected_rows(0))];
591                    writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
592                    return Ok(());
593                }
594                Err(e) => {
595                    writer
596                        .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
597                        .await?;
598                    return Ok(());
599                }
600            }
601        }
602
603        let outputs = self.do_query(query, query_ctx.clone()).await;
604        writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
605
606        Ok(())
607    }
608
609    async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
610        let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
611        let catalog = if let Some(catalog) = &catalog_from_db {
612            catalog.clone()
613        } else {
614            self.session.catalog()
615        };
616
617        if !self
618            .query_handler
619            .is_valid_schema(&catalog, &schema)
620            .await?
621        {
622            return w
623                .error(
624                    ErrorKind::ER_WRONG_DB_NAME,
625                    format!("Unknown database '{}'", database).as_bytes(),
626                )
627                .await
628                .map_err(|e| e.into());
629        }
630
631        let user_info = &self.session.user_info();
632
633        if let Some(schema_validator) = &self.user_provider
634            && let Err(e) = schema_validator
635                .authorize(&catalog, &schema, user_info)
636                .await
637        {
638            METRIC_AUTH_FAILURE
639                .with_label_values(&[e.status_code().as_ref()])
640                .inc();
641            return w
642                .error(
643                    ErrorKind::ER_DBACCESS_DENIED_ERROR,
644                    e.output_msg().as_bytes(),
645                )
646                .await
647                .map_err(|e| e.into());
648        }
649
650        if catalog_from_db.is_some() {
651            self.session.set_catalog(catalog)
652        }
653        self.session.set_schema(schema);
654
655        w.ok().await.map_err(|e| e.into())
656    }
657}
658
659fn convert_param_value_to_string(param: &ParamValue) -> String {
660    match param.value.into_inner() {
661        ValueInner::Int(u) => u.to_string(),
662        ValueInner::UInt(u) => u.to_string(),
663        ValueInner::Double(u) => u.to_string(),
664        ValueInner::NULL => "NULL".to_string(),
665        ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
666        ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
667        ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
668        ValueInner::Time(_) => format_duration(Duration::from(param.value)),
669    }
670}
671
672fn replace_params(params: Vec<String>, query: String) -> String {
673    let mut query = query;
674    for (index, param) in (1..).zip(params) {
675        query = query.replace(&format_placeholder(index), &param);
676    }
677    query
678}
679
/// Format `duration` as a quoted SQL `TIME` literal, `'HH:MM:SS[.ffffff]'`.
///
/// Hours, minutes and seconds are zero-padded to two digits (hours may grow
/// beyond two digits). A non-zero sub-second part is kept with microsecond
/// precision, the finest precision of a MySQL `TIME` value. The previous
/// unpadded form (e.g. `'1:2:3'`) is rejected by parsers that require two-digit
/// minutes, and it silently dropped fractional seconds.
fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let hours = total_secs / 3600;
    let minutes = (total_secs / 60) % 60;
    let seconds = total_secs % 60;
    let micros = duration.subsec_micros();

    if micros == 0 {
        format!("'{:02}:{:02}:{:02}'", hours, minutes, seconds)
    } else {
        format!("'{:02}:{:02}:{:02}.{:06}'", hours, minutes, seconds, micros)
    }
}
686
687fn replace_params_with_values(
688    plan: &LogicalPlan,
689    param_types: HashMap<String, Option<ConcreteDataType>>,
690    params: &[ParamValue],
691) -> Result<LogicalPlan> {
692    debug_assert_eq!(param_types.len(), params.len());
693
694    debug!(
695        "replace_params_with_values(param_types: {:#?}, params: {:#?}, plan: {:#?})",
696        param_types,
697        params
698            .iter()
699            .map(|x| format!("({:?}, {:?})", x.value, x.coltype))
700            .join(", "),
701        plan
702    );
703
704    let mut values = Vec::with_capacity(params.len());
705
706    for (i, param) in params.iter().enumerate() {
707        if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
708            let value = helper::convert_value(param, t)?;
709
710            values.push(value.into());
711        }
712    }
713
714    plan.clone()
715        .replace_params_with_values(&ParamValues::List(values.clone()))
716        .context(DataFrameSnafu)
717}
718
719fn replace_params_with_exprs(
720    plan: &LogicalPlan,
721    param_types: HashMap<String, Option<ConcreteDataType>>,
722    params: &[sql::ast::Expr],
723) -> Result<LogicalPlan> {
724    debug_assert_eq!(param_types.len(), params.len());
725
726    debug!(
727        "replace_params_with_exprs(param_types: {:#?}, params: {:#?}, plan: {:#?})",
728        param_types,
729        params.iter().map(|x| format!("({:?})", x)).join(", "),
730        plan
731    );
732
733    let mut values = Vec::with_capacity(params.len());
734
735    for (i, param) in params.iter().enumerate() {
736        if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
737            let value = helper::convert_expr_to_scalar_value(param, t)?;
738
739            values.push(value.into());
740        }
741    }
742
743    plan.clone()
744        .replace_params_with_values(&ParamValues::List(values.clone()))
745        .context(DataFrameSnafu)
746}
747
748async fn validate_query(query: &str) -> Result<Statement> {
749    let statement =
750        ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
751    let mut statement = statement.map_err(|e| {
752        InvalidPrepareStatementSnafu {
753            err_msg: e.output_msg(),
754        }
755        .build()
756    })?;
757
758    ensure!(
759        statement.len() == 1,
760        InvalidPrepareStatementSnafu {
761            err_msg: "prepare statement only support single statement".to_string(),
762        }
763    );
764
765    let statement = statement.remove(0);
766
767    Ok(statement)
768}
769
770fn dummy_params(index: usize) -> Result<Vec<Column>> {
771    let mut params = Vec::with_capacity(index - 1);
772
773    for _ in 1..index {
774        params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
775    }
776
777    Ok(params)
778}
779
780/// Parameters that the client must provide when executing the prepared statement.
781fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> Result<Vec<Column>> {
782    let mut params = Vec::with_capacity(param_types.len());
783
784    // Placeholder index starts from 1
785    for index in 1..=param_types.len() {
786        if let Some(Some(t)) = param_types.get(&format_placeholder(index)) {
787            let column = create_mysql_column(t, "")?;
788            params.push(column);
789        }
790    }
791
792    Ok(params)
793}
794
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use common_query::Output;
    use datafusion_expr::LogicalPlan;
    use query::parser::PromQuery;
    use query::query_engine::DescribeResult;
    use session::context::QueryContext;
    use sql::statements::statement::Statement;

    use super::*;
    use crate::error::Result;
    use crate::query_handler::sql::SqlQueryHandler;

    /// Query handler stub. Everything except `is_valid_schema` is
    /// `unimplemented!()`, so a test that reaches the real handler panics. The
    /// federated-shortcut tests below rely on never getting that far.
    struct DummyQueryHandler;

    #[async_trait]
    impl SqlQueryHandler for DummyQueryHandler {
        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_exec_plan(
            &self,
            _: LogicalPlan,
            _: String,
            _: QueryContextRef,
        ) -> Result<Output> {
            unimplemented!()
        }

        async fn do_describe(
            &self,
            _: Statement,
            _: QueryContextRef,
        ) -> Result<Option<DescribeResult>> {
            unimplemented!()
        }

        async fn is_valid_schema(&self, _: &str, _: &str) -> Result<bool> {
            Ok(true)
        }
    }

    /// Shim with no user provider and room for 1024 prepared statements.
    fn create_shim() -> MysqlInstanceShim {
        MysqlInstanceShim::create(
            Arc::new(DummyQueryHandler),
            None,
            "127.0.0.1:3306".parse().unwrap(),
            1,
            1024,
        )
    }

    #[tokio::test]
    async fn test_prepare_federated_query() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated".to_string();

        let (params, columns) = shim
            .do_prepare(
                "SELECT @@version_comment",
                query_ctx.clone(),
                stmt_key.clone(),
            )
            .await
            .unwrap();

        assert!(params.is_empty());
        assert!(columns.is_empty());

        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(q) if q == "SELECT @@version_comment"));
    }

    #[tokio::test]
    async fn test_execute_federated_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated_exec".to_string();

        shim.do_prepare(
            "SELECT @@version_comment",
            query_ctx.clone(),
            stmt_key.clone(),
        )
        .await
        .unwrap();

        let outputs = shim
            .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();

        assert_eq!(outputs.len(), 1);
        let output = outputs.into_iter().next().unwrap().unwrap();
        let pretty = output.data.pretty_print().await;
        assert!(pretty.contains("GreptimeDB"));
    }

    /// `SET NAMES` is answered by the federated layer, so it is stored as a shortcut.
    #[tokio::test]
    async fn test_prepare_set_names_is_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_non_federated".to_string();

        let result = shim
            .do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await;

        assert!(result.is_ok());
        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(_)));
    }

    #[tokio::test]
    async fn test_execute_set_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_set_shortcut".to_string();

        shim.do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await
            .unwrap();

        let outputs = shim
            .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();

        assert_eq!(outputs.len(), 1);
        let output = outputs.into_iter().next().unwrap().unwrap();
        match output.data {
            common_query::OutputData::RecordBatches(batches) => {
                let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
                assert_eq!(total_rows, 0);
            }
            other => panic!("Expected RecordBatches, got {:?}", other),
        }
    }
}