1use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU32, Ordering};
19use std::time::Duration;
20
21use ::auth::{Identity, Password, UserProviderRef};
22use async_trait::async_trait;
23use chrono::{NaiveDate, NaiveDateTime};
24use common_catalog::parse_optional_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_query::Output;
27use common_telemetry::{debug, error, tracing, warn};
28use datafusion_common::ParamValues;
29use datafusion_expr::LogicalPlan;
30use datatypes::prelude::ConcreteDataType;
31use datatypes::schema::Schema;
32use itertools::Itertools;
33use opensrv_mysql::{
34 AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
35 StatementMetaWriter, ValueInner,
36};
37use parking_lot::RwLock;
38use query::planner::DfLogicalPlanner;
39use query::query_engine::DescribeResult;
40use rand::RngCore;
41use session::context::{Channel, QueryContextRef};
42use session::{Session, SessionRef};
43use snafu::{ResultExt, ensure};
44use sql::dialect::MySqlDialect;
45use sql::parser::{ParseOptions, ParserContext};
46use sql::statements::statement::Statement;
47use tokio::io::AsyncWrite;
48
49use crate::SqlPlan;
50use crate::error::{
51 self, DataFrameSnafu, InferParameterTypesSnafu, InvalidPrepareStatementSnafu, Result,
52};
53use crate::metrics::METRIC_AUTH_FAILURE;
54use crate::mysql::helper::{
55 self, format_placeholder, replace_placeholders, transform_placeholders,
56};
57use crate::mysql::writer;
58use crate::mysql::writer::{create_mysql_column, handle_err};
59use crate::query_handler::sql::ServerSqlQueryHandlerRef;
60
/// Auth plugin whose auth data is verified against the per-connection salt.
const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
/// Auth plugin whose auth data carries the password as plain text (optionally
/// NUL-terminated); selected when the configured user provider is external.
const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
63
64enum Params<'a> {
66 ProtocolParams(Vec<ParamValue<'a>>),
68 CliParams(Vec<sql::ast::Expr>),
70}
71
72impl Params<'_> {
73 fn len(&self) -> usize {
74 match self {
75 Params::ProtocolParams(params) => params.len(),
76 Params::CliParams(params) => params.len(),
77 }
78 }
79}
80
81pub struct MysqlInstanceShim {
83 query_handler: ServerSqlQueryHandlerRef,
84 salt: [u8; 20],
85 session: SessionRef,
86 user_provider: Option<UserProviderRef>,
87 prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
88 prepared_stmts_counter: AtomicU32,
89 process_id: u32,
90 prepared_stmt_cache_size: usize,
91}
92
93impl MysqlInstanceShim {
94 pub fn create(
95 query_handler: ServerSqlQueryHandlerRef,
96 user_provider: Option<UserProviderRef>,
97 client_addr: SocketAddr,
98 process_id: u32,
99 prepared_stmt_cache_size: usize,
100 ) -> MysqlInstanceShim {
101 let mut bs = vec![0u8; 20];
103 let mut rng = rand::rng();
104 rng.fill_bytes(bs.as_mut());
105
106 let mut scramble: [u8; 20] = [0; 20];
107 for i in 0..20 {
108 scramble[i] = bs[i] & 0x7fu8;
109 if scramble[i] == b'\0' || scramble[i] == b'$' {
110 scramble[i] += 1;
111 }
112 }
113
114 MysqlInstanceShim {
115 query_handler,
116 salt: scramble,
117 session: Arc::new(Session::new(
118 Some(client_addr),
119 Channel::Mysql,
120 Default::default(),
121 process_id,
122 )),
123 user_provider,
124 prepared_stmts: Default::default(),
125 prepared_stmts_counter: AtomicU32::new(1),
126 process_id,
127 prepared_stmt_cache_size,
128 }
129 }
130
131 #[tracing::instrument(skip_all, name = "mysql::do_query")]
132 async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
133 if let Some(output) =
134 crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
135 {
136 vec![Ok(output)]
137 } else {
138 self.query_handler.do_query(query, query_ctx.clone()).await
139 }
140 }
141
142 async fn do_describe(
144 &self,
145 statement: Statement,
146 query_ctx: QueryContextRef,
147 ) -> Result<Option<DescribeResult>> {
148 self.query_handler.do_describe(statement, query_ctx).await
149 }
150
151 fn save_plan(&self, plan: SqlPlan, stmt_key: String) -> Result<()> {
153 let mut prepared_stmts = self.prepared_stmts.write();
154 let max_capacity = self.prepared_stmt_cache_size;
155
156 let is_update = prepared_stmts.contains_key(&stmt_key);
157
158 if !is_update && prepared_stmts.len() >= max_capacity {
159 return error::InternalSnafu {
160 err_msg: format!(
161 "Prepared statement cache is full, max capacity: {}",
162 max_capacity
163 ),
164 }
165 .fail();
166 }
167
168 let _ = prepared_stmts.insert(stmt_key, plan);
169 Ok(())
170 }
171
172 fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
174 let guard = self.prepared_stmts.read();
175 guard.get(stmt_key).cloned()
176 }
177
178 async fn do_prepare(
180 &mut self,
181 raw_query: &str,
182 query_ctx: QueryContextRef,
183 stmt_key: String,
184 ) -> Result<(Vec<Column>, Vec<Column>)> {
185 if crate::mysql::federated::check(raw_query, query_ctx.clone(), self.session.clone())
186 .is_some()
187 {
188 self.save_plan(SqlPlan::Shortcut(raw_query.to_string()), stmt_key)
189 .inspect_err(|e| {
190 error!(e; "Failed to save prepared statement");
191 })?;
192 return Ok((vec![], vec![]));
193 }
194
195 let (query, param_num) = replace_placeholders(raw_query);
196
197 let statement = validate_query(raw_query).await?;
198
199 let statement = transform_placeholders(statement);
202
203 let describe_result = self
204 .do_describe(statement.clone(), query_ctx.clone())
205 .await?;
206 let plan = describe_result.map(|DescribeResult { logical_plan }| logical_plan);
207
208 let params = if let Some(plan) = &plan {
209 let param_types = DfLogicalPlanner::get_inferred_parameter_types(plan)
210 .context(InferParameterTypesSnafu)?
211 .into_iter()
212 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
213 .collect();
214 prepared_params(¶m_types)?
215 } else {
216 dummy_params(param_num)?
217 };
218
219 let columns =
220 plan.as_ref()
221 .map(|plan| {
222 let schema: Schema = plan.schema().clone().try_into().map_err(
223 |e: datatypes::error::Error| {
224 error::InternalSnafu {
225 err_msg: e.to_string(),
226 }
227 .build()
228 },
229 )?;
230 schema
231 .column_schemas()
232 .iter()
233 .map(|column_schema| {
234 create_mysql_column(&column_schema.data_type, &column_schema.name)
235 })
236 .collect::<Result<Vec<_>>>()
237 })
238 .transpose()?
239 .unwrap_or_default();
240
241 match plan {
242 Some(plan) if params.len() == param_num - 1 => {
243 self.save_plan(SqlPlan::Plan(plan, query.clone()), stmt_key)
244 .inspect_err(|e| {
245 error!(e; "Failed to save prepared statement");
246 })?;
247 }
248 _ => {
249 self.save_plan(SqlPlan::Statement(statement, query), stmt_key)
250 .inspect_err(|e| {
251 error!(e; "Failed to save prepared statement");
252 })?;
253 }
254 }
255
256 Ok((params, columns))
257 }
258
259 async fn do_execute(
260 &mut self,
261 query_ctx: QueryContextRef,
262 stmt_key: String,
263 params: Params<'_>,
264 ) -> Result<Vec<std::result::Result<Output, error::Error>>> {
265 let sql_plan = match self.plan(&stmt_key) {
266 None => {
267 return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
268 }
269 Some(sql_plan) => sql_plan,
270 };
271
272 let outputs = match sql_plan {
273 SqlPlan::Plan(plan, query) => {
274 let param_types = DfLogicalPlanner::get_inferred_parameter_types(&plan)
275 .context(InferParameterTypesSnafu)?
276 .into_iter()
277 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
278 .collect::<HashMap<_, _>>();
279
280 if params.len() != param_types.len() {
281 return error::InternalSnafu {
282 err_msg: "Prepare statement params number mismatch".to_string(),
283 }
284 .fail();
285 }
286
287 let replaced_plan = match params {
288 Params::ProtocolParams(params) => {
289 replace_params_with_values(&plan, param_types, ¶ms)
290 }
291 Params::CliParams(params) => {
292 replace_params_with_exprs(&plan, param_types, ¶ms)
293 }
294 }?;
295
296 debug!(
297 "Mysql execute prepared plan: {}",
298 replaced_plan.display_indent()
299 );
300 vec![
301 self.query_handler
302 .do_exec_plan(replaced_plan, query, query_ctx.clone())
303 .await,
304 ]
305 }
306 SqlPlan::Shortcut(query) => {
307 if let Some(output) =
308 crate::mysql::federated::check(&query, query_ctx.clone(), self.session.clone())
309 {
310 vec![Ok(output)]
311 } else {
312 self.do_query(&query, query_ctx.clone()).await
313 }
314 }
315 SqlPlan::Statement(_stmt, query) => {
316 let param_strs = match params {
317 Params::ProtocolParams(params) => {
318 params.iter().map(convert_param_value_to_string).collect()
319 }
320 Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
321 };
322 debug!(
323 "do_execute Replacing with Params: {:?}, Original Query: {}",
324 param_strs, query
325 );
326 let query = replace_params(param_strs, query);
327 debug!("Mysql execute replaced query: {}", query);
328 self.do_query(&query, query_ctx.clone()).await
329 }
330 _ => {
331 return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
332 }
333 };
334
335 Ok(outputs)
336 }
337
338 fn do_close(&mut self, stmt_key: String) {
340 let mut guard = self.prepared_stmts.write();
341 let _ = guard.remove(&stmt_key);
342 }
343
344 fn auth_plugin(&self) -> &'static str {
345 if self
346 .user_provider
347 .as_ref()
348 .map(|x| x.external())
349 .unwrap_or(false)
350 {
351 MYSQL_CLEAR_PASSWORD
352 } else {
353 MYSQL_NATIVE_PASSWORD
354 }
355 }
356}
357
358#[async_trait]
359impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShim {
360 type Error = error::Error;
361
362 fn version(&self) -> String {
363 std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
364 }
365
366 fn connect_id(&self) -> u32 {
367 self.process_id
368 }
369
370 fn default_auth_plugin(&self) -> &str {
371 self.auth_plugin()
372 }
373
374 async fn auth_plugin_for_username(&self, _user: &[u8]) -> &'static str {
375 self.auth_plugin()
376 }
377
378 fn salt(&self) -> [u8; 20] {
379 self.salt
380 }
381
382 async fn authenticate(
383 &self,
384 auth_plugin: &str,
385 username: &[u8],
386 salt: &[u8],
387 auth_data: &[u8],
388 ) -> bool {
389 let username = String::from_utf8_lossy(username);
391
392 let mut user_info = None;
393 let addr = self
394 .session
395 .conn_info()
396 .client_addr
397 .map(|addr| addr.to_string());
398 if let Some(user_provider) = &self.user_provider {
399 let user_id = Identity::UserId(&username, addr.as_deref());
400
401 let password = match auth_plugin {
402 MYSQL_NATIVE_PASSWORD => Password::MysqlNativePassword(auth_data, salt),
403 MYSQL_CLEAR_PASSWORD => {
404 let password = if let &[password @ .., 0] = &auth_data {
407 password
408 } else {
409 auth_data
410 };
411 Password::PlainText(String::from_utf8_lossy(password).to_string().into())
412 }
413 other => {
414 error!("Unsupported mysql auth plugin: {}", other);
415 return false;
416 }
417 };
418 match user_provider.authenticate(user_id, password).await {
419 Ok(userinfo) => {
420 user_info = Some(userinfo);
421 }
422 Err(e) => {
423 METRIC_AUTH_FAILURE
424 .with_label_values(&[e.status_code().as_ref()])
425 .inc();
426 warn!(e; "Failed to auth");
427 return false;
428 }
429 };
430 }
431 let user_info =
432 user_info.unwrap_or_else(|| auth::userinfo_by_name(Some(username.to_string())));
433
434 self.session.set_user_info(user_info);
435
436 true
437 }
438
439 async fn on_prepare<'a>(
440 &'a mut self,
441 raw_query: &'a str,
442 w: StatementMetaWriter<'a, W>,
443 ) -> Result<()> {
444 let query_ctx = self.session.new_query_context();
445 let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
446 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
447 let (params, columns) = match self
448 .do_prepare(raw_query, query_ctx.clone(), stmt_key)
449 .await
450 {
451 Ok(x) => x,
452 Err(e) => {
453 let (kind, msg) = handle_err(e, query_ctx.clone());
454 w.error(kind, msg.as_bytes()).await?;
455 return Ok(());
456 }
457 };
458 debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
459 w.reply(stmt_id, ¶ms, &columns).await?;
460 crate::metrics::METRIC_MYSQL_PREPARED_COUNT
461 .with_label_values(&[query_ctx.get_db_string().as_str()])
462 .inc();
463 return Ok(());
464 }
465
466 async fn on_execute<'a>(
467 &'a mut self,
468 stmt_id: u32,
469 p: ParamParser<'a>,
470 w: QueryResultWriter<'a, W>,
471 ) -> Result<()> {
472 self.session.clear_warnings();
473
474 let query_ctx = self.session.new_query_context();
475 let db = query_ctx.get_db_string();
476 let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
477 .with_label_values(&[crate::metrics::METRIC_MYSQL_BINQUERY, db.as_str()])
478 .start_timer();
479
480 let params: Vec<ParamValue> = p.into_iter().collect();
481 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
482
483 let outputs = match self
484 .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
485 .await
486 {
487 Ok(outputs) => outputs,
488 Err(e) => {
489 let (kind, err) = handle_err(e, query_ctx);
490 debug!(
491 "Failed to execute prepared statement, kind: {:?}, err: {}",
492 kind, err
493 );
494 w.error(kind, err.as_bytes()).await?;
495 return Ok(());
496 }
497 };
498
499 writer::write_output(w, query_ctx, self.session.clone(), outputs).await?;
500
501 Ok(())
502 }
503
504 async fn on_close<'a>(&'a mut self, stmt_id: u32)
505 where
506 W: 'async_trait,
507 {
508 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
509 self.do_close(stmt_key);
510 }
511
512 #[tracing::instrument(skip_all, fields(protocol = "mysql"))]
513 async fn on_query<'a>(
514 &'a mut self,
515 query: &'a str,
516 writer: QueryResultWriter<'a, W>,
517 ) -> Result<()> {
518 let query_ctx = self.session.new_query_context();
519 let db = query_ctx.get_db_string();
520 let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
521 .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
522 .start_timer();
523
524 let query_upcase = query.to_uppercase();
526 if !query_upcase.starts_with("SHOW WARNINGS") {
527 self.session.clear_warnings();
528 }
529
530 if query_upcase.starts_with("PREPARE ") {
531 match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
532 Ok((stmt_name, stmt)) => {
533 let prepare_results =
534 self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
535 match prepare_results {
536 Ok(_) => {
537 let outputs = vec![Ok(Output::new_with_affected_rows(0))];
538 writer::write_output(writer, query_ctx, self.session.clone(), outputs)
539 .await?;
540 return Ok(());
541 }
542 Err(e) => {
543 writer
544 .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
545 .await?;
546 return Ok(());
547 }
548 }
549 }
550 Err(e) => {
551 writer
552 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
553 .await?;
554 return Ok(());
555 }
556 }
557 } else if query_upcase.starts_with("EXECUTE ") {
558 match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
559 Ok((stmt_name, params)) => {
560 let outputs = match self
561 .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
562 .await
563 {
564 Ok(outputs) => outputs,
565 Err(e) => {
566 let (kind, err) = handle_err(e, query_ctx);
567 debug!(
568 "Failed to execute prepared statement, kind: {:?}, err: {}",
569 kind, err
570 );
571 writer.error(kind, err.as_bytes()).await?;
572 return Ok(());
573 }
574 };
575 writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
576
577 return Ok(());
578 }
579 Err(e) => {
580 writer
581 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
582 .await?;
583 return Ok(());
584 }
585 }
586 } else if query_upcase.starts_with("DEALLOCATE ") {
587 match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
588 Ok(stmt_name) => {
589 self.do_close(stmt_name);
590 let outputs = vec![Ok(Output::new_with_affected_rows(0))];
591 writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
592 return Ok(());
593 }
594 Err(e) => {
595 writer
596 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
597 .await?;
598 return Ok(());
599 }
600 }
601 }
602
603 let outputs = self.do_query(query, query_ctx.clone()).await;
604 writer::write_output(writer, query_ctx, self.session.clone(), outputs).await?;
605
606 Ok(())
607 }
608
609 async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
610 let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
611 let catalog = if let Some(catalog) = &catalog_from_db {
612 catalog.clone()
613 } else {
614 self.session.catalog()
615 };
616
617 if !self
618 .query_handler
619 .is_valid_schema(&catalog, &schema)
620 .await?
621 {
622 return w
623 .error(
624 ErrorKind::ER_WRONG_DB_NAME,
625 format!("Unknown database '{}'", database).as_bytes(),
626 )
627 .await
628 .map_err(|e| e.into());
629 }
630
631 let user_info = &self.session.user_info();
632
633 if let Some(schema_validator) = &self.user_provider
634 && let Err(e) = schema_validator
635 .authorize(&catalog, &schema, user_info)
636 .await
637 {
638 METRIC_AUTH_FAILURE
639 .with_label_values(&[e.status_code().as_ref()])
640 .inc();
641 return w
642 .error(
643 ErrorKind::ER_DBACCESS_DENIED_ERROR,
644 e.output_msg().as_bytes(),
645 )
646 .await
647 .map_err(|e| e.into());
648 }
649
650 if catalog_from_db.is_some() {
651 self.session.set_catalog(catalog)
652 }
653 self.session.set_schema(schema);
654
655 w.ok().await.map_err(|e| e.into())
656 }
657}
658
659fn convert_param_value_to_string(param: &ParamValue) -> String {
660 match param.value.into_inner() {
661 ValueInner::Int(u) => u.to_string(),
662 ValueInner::UInt(u) => u.to_string(),
663 ValueInner::Double(u) => u.to_string(),
664 ValueInner::NULL => "NULL".to_string(),
665 ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
666 ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
667 ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
668 ValueInner::Time(_) => format_duration(Duration::from(param.value)),
669 }
670}
671
672fn replace_params(params: Vec<String>, query: String) -> String {
673 let mut query = query;
674 for (index, param) in (1..).zip(params) {
675 query = query.replace(&format_placeholder(index), ¶m);
676 }
677 query
678}
679
/// Formats `duration` as a quoted MySQL `TIME` literal, `'HH:MM:SS[.ffffff]'`.
///
/// Hours are not wrapped at 24. Minutes and seconds are zero-padded. A
/// fractional part with microsecond precision is appended only when non-zero,
/// so sub-second values are no longer silently dropped.
fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let hours = total_secs / 3600;
    let minutes = (total_secs / 60) % 60;
    let seconds = total_secs % 60;
    let micros = duration.subsec_micros();
    if micros == 0 {
        format!("'{:02}:{:02}:{:02}'", hours, minutes, seconds)
    } else {
        format!("'{:02}:{:02}:{:02}.{:06}'", hours, minutes, seconds, micros)
    }
}
686
687fn replace_params_with_values(
688 plan: &LogicalPlan,
689 param_types: HashMap<String, Option<ConcreteDataType>>,
690 params: &[ParamValue],
691) -> Result<LogicalPlan> {
692 debug_assert_eq!(param_types.len(), params.len());
693
694 debug!(
695 "replace_params_with_values(param_types: {:#?}, params: {:#?}, plan: {:#?})",
696 param_types,
697 params
698 .iter()
699 .map(|x| format!("({:?}, {:?})", x.value, x.coltype))
700 .join(", "),
701 plan
702 );
703
704 let mut values = Vec::with_capacity(params.len());
705
706 for (i, param) in params.iter().enumerate() {
707 if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
708 let value = helper::convert_value(param, t)?;
709
710 values.push(value.into());
711 }
712 }
713
714 plan.clone()
715 .replace_params_with_values(&ParamValues::List(values.clone()))
716 .context(DataFrameSnafu)
717}
718
719fn replace_params_with_exprs(
720 plan: &LogicalPlan,
721 param_types: HashMap<String, Option<ConcreteDataType>>,
722 params: &[sql::ast::Expr],
723) -> Result<LogicalPlan> {
724 debug_assert_eq!(param_types.len(), params.len());
725
726 debug!(
727 "replace_params_with_exprs(param_types: {:#?}, params: {:#?}, plan: {:#?})",
728 param_types,
729 params.iter().map(|x| format!("({:?})", x)).join(", "),
730 plan
731 );
732
733 let mut values = Vec::with_capacity(params.len());
734
735 for (i, param) in params.iter().enumerate() {
736 if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
737 let value = helper::convert_expr_to_scalar_value(param, t)?;
738
739 values.push(value.into());
740 }
741 }
742
743 plan.clone()
744 .replace_params_with_values(&ParamValues::List(values.clone()))
745 .context(DataFrameSnafu)
746}
747
748async fn validate_query(query: &str) -> Result<Statement> {
749 let statement =
750 ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
751 let mut statement = statement.map_err(|e| {
752 InvalidPrepareStatementSnafu {
753 err_msg: e.output_msg(),
754 }
755 .build()
756 })?;
757
758 ensure!(
759 statement.len() == 1,
760 InvalidPrepareStatementSnafu {
761 err_msg: "prepare statement only support single statement".to_string(),
762 }
763 );
764
765 let statement = statement.remove(0);
766
767 Ok(statement)
768}
769
770fn dummy_params(index: usize) -> Result<Vec<Column>> {
771 let mut params = Vec::with_capacity(index - 1);
772
773 for _ in 1..index {
774 params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
775 }
776
777 Ok(params)
778}
779
780fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> Result<Vec<Column>> {
782 let mut params = Vec::with_capacity(param_types.len());
783
784 for index in 1..=param_types.len() {
786 if let Some(Some(t)) = param_types.get(&format_placeholder(index)) {
787 let column = create_mysql_column(t, "")?;
788 params.push(column);
789 }
790 }
791
792 Ok(params)
793}
794
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use async_trait::async_trait;
    use common_query::Output;
    use datafusion_expr::LogicalPlan;
    use query::parser::PromQuery;
    use query::query_engine::DescribeResult;
    use session::context::QueryContext;
    use sql::statements::statement::Statement;

    use super::*;
    use crate::error::Result;
    use crate::query_handler::sql::SqlQueryHandler;

    /// Query handler whose execution paths all panic (only `is_valid_schema`
    /// is implemented), so any test that reaches it instead of the federated
    /// shortcut fails loudly.
    struct DummyQueryHandler;

    #[async_trait]
    impl SqlQueryHandler for DummyQueryHandler {
        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
            unimplemented!()
        }

        async fn do_exec_plan(
            &self,
            _: LogicalPlan,
            _: String,
            _: QueryContextRef,
        ) -> Result<Output> {
            unimplemented!()
        }

        async fn do_describe(
            &self,
            _: Statement,
            _: QueryContextRef,
        ) -> Result<Option<DescribeResult>> {
            unimplemented!()
        }

        async fn is_valid_schema(&self, _: &str, _: &str) -> Result<bool> {
            Ok(true)
        }
    }

    /// Shim backed by [`DummyQueryHandler`], no user provider and a cache of 1024 statements.
    fn create_shim() -> MysqlInstanceShim {
        MysqlInstanceShim::create(
            Arc::new(DummyQueryHandler),
            None,
            "127.0.0.1:3306".parse().unwrap(),
            1,
            1024,
        )
    }

    /// Preparing a federated query stores it verbatim and reports no metadata.
    #[tokio::test]
    async fn test_prepare_federated_query() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated".to_string();

        let (params, columns) = shim
            .do_prepare(
                "SELECT @@version_comment",
                query_ctx.clone(),
                stmt_key.clone(),
            )
            .await
            .unwrap();

        assert!(params.is_empty());
        assert!(columns.is_empty());

        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(q) if q == "SELECT @@version_comment"));
    }

    /// Executing a prepared federated query is answered without the query handler.
    #[tokio::test]
    async fn test_execute_federated_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_federated_exec".to_string();

        shim.do_prepare(
            "SELECT @@version_comment",
            query_ctx.clone(),
            stmt_key.clone(),
        )
        .await
        .unwrap();

        let outputs = shim
            .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();

        assert_eq!(outputs.len(), 1);
        let output = outputs.into_iter().next().unwrap().unwrap();
        let pretty = output.data.pretty_print().await;
        assert!(pretty.contains("GreptimeDB"));
    }

    /// `SET NAMES` is answered by the federated layer, so it is prepared as a shortcut.
    #[tokio::test]
    async fn test_prepare_set_names_is_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_non_federated".to_string();

        let result = shim
            .do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await;

        assert!(result.is_ok());
        let plan = shim.plan(&stmt_key).unwrap();
        assert!(matches!(plan, SqlPlan::Shortcut(_)));
    }

    /// Executing a prepared `SET NAMES` yields a single empty record-batch output.
    #[tokio::test]
    async fn test_execute_set_shortcut() {
        let mut shim = create_shim();
        let query_ctx = QueryContext::arc();
        let stmt_key = "test_set_shortcut".to_string();

        shim.do_prepare("SET NAMES utf8", query_ctx.clone(), stmt_key.clone())
            .await
            .unwrap();

        let outputs = shim
            .do_execute(query_ctx.clone(), stmt_key, Params::CliParams(vec![]))
            .await
            .unwrap();

        assert_eq!(outputs.len(), 1);
        let output = outputs.into_iter().next().unwrap().unwrap();
        match output.data {
            common_query::OutputData::RecordBatches(batches) => {
                let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
                assert_eq!(total_rows, 0);
            }
            other => panic!("Expected RecordBatches, got {:?}", other),
        }
    }
}