1use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::atomic::{AtomicU32, Ordering};
18use std::sync::Arc;
19use std::time::Duration;
20
21use ::auth::{Identity, Password, UserProviderRef};
22use async_trait::async_trait;
23use chrono::{NaiveDate, NaiveDateTime};
24use common_catalog::parse_optional_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_query::Output;
27use common_telemetry::{debug, error, tracing, warn};
28use datafusion_common::ParamValues;
29use datafusion_expr::LogicalPlan;
30use datatypes::prelude::ConcreteDataType;
31use itertools::Itertools;
32use opensrv_mysql::{
33 AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
34 StatementMetaWriter, ValueInner,
35};
36use parking_lot::RwLock;
37use query::query_engine::DescribeResult;
38use rand::RngCore;
39use session::context::{Channel, QueryContextRef};
40use session::{Session, SessionRef};
41use snafu::{ensure, ResultExt};
42use sql::dialect::MySqlDialect;
43use sql::parser::{ParseOptions, ParserContext};
44use sql::statements::statement::Statement;
45use tokio::io::AsyncWrite;
46
47use crate::error::{self, DataFrameSnafu, InvalidPrepareStatementSnafu, Result};
48use crate::metrics::METRIC_AUTH_FAILURE;
49use crate::mysql::helper::{
50 self, fix_placeholder_types, format_placeholder, replace_placeholders, transform_placeholders,
51};
52use crate::mysql::writer;
53use crate::mysql::writer::{create_mysql_column, handle_err};
54use crate::query_handler::sql::ServerSqlQueryHandlerRef;
55use crate::SqlPlan;
56
/// Auth plugin advertised when the user provider is not external; the client's
/// auth data is verified together with this connection's salt.
const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
/// Auth plugin advertised when the user provider reports itself as external; the
/// client sends the password as plain text.
const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
59
/// Parameter values supplied when executing a prepared statement.
enum Params<'a> {
    /// Values decoded from a binary-protocol execute request (see `on_execute`).
    ProtocolParams(Vec<ParamValue<'a>>),
    /// Expressions taken from a text-protocol `EXECUTE ... USING ...` statement
    /// (see `on_query`).
    CliParams(Vec<sql::ast::Expr>),
}
67
68impl Params<'_> {
69 fn len(&self) -> usize {
70 match self {
71 Params::ProtocolParams(params) => params.len(),
72 Params::CliParams(params) => params.len(),
73 }
74 }
75}
76
/// Per-connection state for the MySQL wire-protocol server.
pub struct MysqlInstanceShim {
    /// Executes SQL text, logical plans and describe requests for this connection.
    query_handler: ServerSqlQueryHandlerRef,
    /// 20-byte scramble sent to the client and used when verifying its password.
    salt: [u8; 20],
    /// Session holding the connection info, current catalog/schema and user.
    session: SessionRef,
    /// Authentication/authorization backend; when `None` every login is accepted.
    user_provider: Option<UserProviderRef>,
    /// Prepared statements of this connection, keyed by the SQL `PREPARE` name or
    /// by the UUID string derived from a binary-protocol statement id.
    prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
    /// Source of binary-protocol statement ids.
    prepared_stmts_counter: AtomicU32,
    /// Id reported to the client as this connection's id.
    process_id: u32,
}
87
88impl MysqlInstanceShim {
89 pub fn create(
90 query_handler: ServerSqlQueryHandlerRef,
91 user_provider: Option<UserProviderRef>,
92 client_addr: SocketAddr,
93 process_id: u32,
94 ) -> MysqlInstanceShim {
95 let mut bs = vec![0u8; 20];
97 let mut rng = rand::rng();
98 rng.fill_bytes(bs.as_mut());
99
100 let mut scramble: [u8; 20] = [0; 20];
101 for i in 0..20 {
102 scramble[i] = bs[i] & 0x7fu8;
103 if scramble[i] == b'\0' || scramble[i] == b'$' {
104 scramble[i] += 1;
105 }
106 }
107
108 MysqlInstanceShim {
109 query_handler,
110 salt: scramble,
111 session: Arc::new(Session::new(
112 Some(client_addr),
113 Channel::Mysql,
114 Default::default(),
115 process_id,
116 )),
117 user_provider,
118 prepared_stmts: Default::default(),
119 prepared_stmts_counter: AtomicU32::new(1),
120 process_id,
121 }
122 }
123
124 #[tracing::instrument(skip_all, name = "mysql::do_query")]
125 async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
126 if let Some(output) =
127 crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
128 {
129 vec![Ok(output)]
130 } else {
131 self.query_handler.do_query(query, query_ctx.clone()).await
132 }
133 }
134
135 async fn do_exec_plan(
137 &self,
138 query: &str,
139 plan: LogicalPlan,
140 query_ctx: QueryContextRef,
141 ) -> Result<Output> {
142 if let Some(output) =
143 crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
144 {
145 Ok(output)
146 } else {
147 self.query_handler.do_exec_plan(plan, query_ctx).await
148 }
149 }
150
151 async fn do_describe(
153 &self,
154 statement: Statement,
155 query_ctx: QueryContextRef,
156 ) -> Result<Option<DescribeResult>> {
157 self.query_handler.do_describe(statement, query_ctx).await
158 }
159
160 fn save_plan(&self, plan: SqlPlan, stmt_key: String) {
162 let mut prepared_stmts = self.prepared_stmts.write();
163 let _ = prepared_stmts.insert(stmt_key, plan);
164 }
165
166 fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
168 let guard = self.prepared_stmts.read();
169 guard.get(stmt_key).cloned()
170 }
171
172 async fn do_prepare(
174 &mut self,
175 raw_query: &str,
176 query_ctx: QueryContextRef,
177 stmt_key: String,
178 ) -> Result<(Vec<Column>, Vec<Column>)> {
179 let (query, param_num) = replace_placeholders(raw_query);
180
181 let statement = validate_query(raw_query).await?;
182
183 let statement = transform_placeholders(statement);
186
187 let describe_result = self
188 .do_describe(statement.clone(), query_ctx.clone())
189 .await?;
190 let (mut plan, schema) = if let Some(DescribeResult {
191 logical_plan,
192 schema,
193 }) = describe_result
194 {
195 (Some(logical_plan), Some(schema))
196 } else {
197 (None, None)
198 };
199
200 let params = if let Some(plan) = &mut plan {
201 fix_placeholder_types(plan)?;
202 debug!("Plan after fix placeholder types: {:#?}", plan);
203 prepared_params(
204 &plan
205 .get_parameter_types()
206 .context(DataFrameSnafu)?
207 .into_iter()
208 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
209 .collect(),
210 )?
211 } else {
212 dummy_params(param_num)?
213 };
214
215 let columns = schema
216 .as_ref()
217 .map(|schema| {
218 schema
219 .column_schemas()
220 .iter()
221 .map(|column_schema| {
222 create_mysql_column(&column_schema.data_type, &column_schema.name)
223 })
224 .collect::<Result<Vec<_>>>()
225 })
226 .transpose()?
227 .unwrap_or_default();
228
229 if params.len() != param_num - 1 {
231 self.save_plan(
232 SqlPlan {
233 query: query.to_string(),
234 plan: None,
235 schema: None,
236 },
237 stmt_key,
238 );
239 } else {
240 self.save_plan(
241 SqlPlan {
242 query: query.to_string(),
243 plan,
244 schema,
245 },
246 stmt_key,
247 );
248 }
249
250 Ok((params, columns))
251 }
252
253 async fn do_execute(
254 &mut self,
255 query_ctx: QueryContextRef,
256 stmt_key: String,
257 params: Params<'_>,
258 ) -> Result<Vec<std::result::Result<Output, error::Error>>> {
259 let sql_plan = match self.plan(&stmt_key) {
260 None => {
261 return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
262 }
263 Some(sql_plan) => sql_plan,
264 };
265
266 let outputs = match sql_plan.plan {
267 Some(mut plan) => {
268 fix_placeholder_types(&mut plan)?;
269 let param_types = plan
270 .get_parameter_types()
271 .context(DataFrameSnafu)?
272 .into_iter()
273 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
274 .collect::<HashMap<_, _>>();
275
276 if params.len() != param_types.len() {
277 return error::InternalSnafu {
278 err_msg: "Prepare statement params number mismatch".to_string(),
279 }
280 .fail();
281 }
282
283 let plan = match params {
284 Params::ProtocolParams(params) => {
285 replace_params_with_values(&plan, param_types, ¶ms)
286 }
287 Params::CliParams(params) => {
288 replace_params_with_exprs(&plan, param_types, ¶ms)
289 }
290 }?;
291
292 debug!("Mysql execute prepared plan: {}", plan.display_indent());
293 vec![
294 self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone())
295 .await,
296 ]
297 }
298 None => {
299 let param_strs = match params {
300 Params::ProtocolParams(params) => {
301 params.iter().map(convert_param_value_to_string).collect()
302 }
303 Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
304 };
305 debug!(
306 "do_execute Replacing with Params: {:?}, Original Query: {}",
307 param_strs, sql_plan.query
308 );
309 let query = replace_params(param_strs, sql_plan.query);
310 debug!("Mysql execute replaced query: {}", query);
311 self.do_query(&query, query_ctx.clone()).await
312 }
313 };
314
315 Ok(outputs)
316 }
317
318 fn do_close(&mut self, stmt_key: String) {
320 let mut guard = self.prepared_stmts.write();
321 let _ = guard.remove(&stmt_key);
322 }
323
324 fn auth_plugin(&self) -> &str {
325 if self
326 .user_provider
327 .as_ref()
328 .map(|x| x.external())
329 .unwrap_or(false)
330 {
331 MYSQL_CLEAR_PASSWORD
332 } else {
333 MYSQL_NATIVE_PASSWORD
334 }
335 }
336}
337
338#[async_trait]
339impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShim {
340 type Error = error::Error;
341
342 fn version(&self) -> String {
343 std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
344 }
345
346 fn connect_id(&self) -> u32 {
347 self.process_id
348 }
349
350 fn default_auth_plugin(&self) -> &str {
351 self.auth_plugin()
352 }
353
354 async fn auth_plugin_for_username<'a, 'user>(&'a self, _user: &'user [u8]) -> &'a str {
355 self.auth_plugin()
356 }
357
358 fn salt(&self) -> [u8; 20] {
359 self.salt
360 }
361
362 async fn authenticate(
363 &self,
364 auth_plugin: &str,
365 username: &[u8],
366 salt: &[u8],
367 auth_data: &[u8],
368 ) -> bool {
369 let username = String::from_utf8_lossy(username);
371
372 let mut user_info = None;
373 let addr = self
374 .session
375 .conn_info()
376 .client_addr
377 .map(|addr| addr.to_string());
378 if let Some(user_provider) = &self.user_provider {
379 let user_id = Identity::UserId(&username, addr.as_deref());
380
381 let password = match auth_plugin {
382 MYSQL_NATIVE_PASSWORD => Password::MysqlNativePassword(auth_data, salt),
383 MYSQL_CLEAR_PASSWORD => {
384 let password = if let &[password @ .., 0] = &auth_data {
387 password
388 } else {
389 auth_data
390 };
391 Password::PlainText(String::from_utf8_lossy(password).to_string().into())
392 }
393 other => {
394 error!("Unsupported mysql auth plugin: {}", other);
395 return false;
396 }
397 };
398 match user_provider.authenticate(user_id, password).await {
399 Ok(userinfo) => {
400 user_info = Some(userinfo);
401 }
402 Err(e) => {
403 METRIC_AUTH_FAILURE
404 .with_label_values(&[e.status_code().as_ref()])
405 .inc();
406 warn!(e; "Failed to auth");
407 return false;
408 }
409 };
410 }
411 let user_info =
412 user_info.unwrap_or_else(|| auth::userinfo_by_name(Some(username.to_string())));
413
414 self.session.set_user_info(user_info);
415
416 true
417 }
418
419 async fn on_prepare<'a>(
420 &'a mut self,
421 raw_query: &'a str,
422 w: StatementMetaWriter<'a, W>,
423 ) -> Result<()> {
424 let query_ctx = self.session.new_query_context();
425 let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
426 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
427 let (params, columns) = self
428 .do_prepare(raw_query, query_ctx.clone(), stmt_key)
429 .await?;
430 debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
431 w.reply(stmt_id, ¶ms, &columns).await?;
432 crate::metrics::METRIC_MYSQL_PREPARED_COUNT
433 .with_label_values(&[query_ctx.get_db_string().as_str()])
434 .inc();
435 return Ok(());
436 }
437
438 async fn on_execute<'a>(
439 &'a mut self,
440 stmt_id: u32,
441 p: ParamParser<'a>,
442 w: QueryResultWriter<'a, W>,
443 ) -> Result<()> {
444 let query_ctx = self.session.new_query_context();
445 let db = query_ctx.get_db_string();
446 let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
447 .with_label_values(&[crate::metrics::METRIC_MYSQL_BINQUERY, db.as_str()])
448 .start_timer();
449
450 let params: Vec<ParamValue> = p.into_iter().collect();
451 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
452
453 let outputs = match self
454 .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
455 .await
456 {
457 Ok(outputs) => outputs,
458 Err(e) => {
459 let (kind, err) = handle_err(e, query_ctx);
460 debug!(
461 "Failed to execute prepared statement, kind: {:?}, err: {}",
462 kind, err
463 );
464 w.error(kind, err.as_bytes()).await?;
465 return Ok(());
466 }
467 };
468
469 writer::write_output(w, query_ctx, outputs).await?;
470
471 Ok(())
472 }
473
474 async fn on_close<'a>(&'a mut self, stmt_id: u32)
475 where
476 W: 'async_trait,
477 {
478 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
479 self.do_close(stmt_key);
480 }
481
482 #[tracing::instrument(skip_all, fields(protocol = "mysql"))]
483 async fn on_query<'a>(
484 &'a mut self,
485 query: &'a str,
486 writer: QueryResultWriter<'a, W>,
487 ) -> Result<()> {
488 let query_ctx = self.session.new_query_context();
489 let db = query_ctx.get_db_string();
490 let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
491 .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
492 .start_timer();
493
494 let query_upcase = query.to_uppercase();
495 if query_upcase.starts_with("PREPARE ") {
496 match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
497 Ok((stmt_name, stmt)) => {
498 let prepare_results =
499 self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
500 match prepare_results {
501 Ok(_) => {
502 let outputs = vec![Ok(Output::new_with_affected_rows(0))];
503 writer::write_output(writer, query_ctx, outputs).await?;
504 return Ok(());
505 }
506 Err(e) => {
507 writer
508 .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
509 .await?;
510 return Ok(());
511 }
512 }
513 }
514 Err(e) => {
515 writer
516 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
517 .await?;
518 return Ok(());
519 }
520 }
521 } else if query_upcase.starts_with("EXECUTE ") {
522 match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
523 Ok((stmt_name, params)) => {
524 let outputs = match self
525 .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
526 .await
527 {
528 Ok(outputs) => outputs,
529 Err(e) => {
530 let (kind, err) = handle_err(e, query_ctx);
531 debug!(
532 "Failed to execute prepared statement, kind: {:?}, err: {}",
533 kind, err
534 );
535 writer.error(kind, err.as_bytes()).await?;
536 return Ok(());
537 }
538 };
539 writer::write_output(writer, query_ctx, outputs).await?;
540 return Ok(());
541 }
542 Err(e) => {
543 writer
544 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
545 .await?;
546 return Ok(());
547 }
548 }
549 } else if query_upcase.starts_with("DEALLOCATE ") {
550 match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
551 Ok(stmt_name) => {
552 self.do_close(stmt_name);
553 let outputs = vec![Ok(Output::new_with_affected_rows(0))];
554 writer::write_output(writer, query_ctx, outputs).await?;
555 return Ok(());
556 }
557 Err(e) => {
558 writer
559 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
560 .await?;
561 return Ok(());
562 }
563 }
564 }
565
566 let outputs = self.do_query(query, query_ctx.clone()).await;
567 writer::write_output(writer, query_ctx, outputs).await?;
568 Ok(())
569 }
570
571 async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
572 let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
573 let catalog = if let Some(catalog) = &catalog_from_db {
574 catalog.to_string()
575 } else {
576 self.session.catalog()
577 };
578
579 if !self
580 .query_handler
581 .is_valid_schema(&catalog, &schema)
582 .await?
583 {
584 return w
585 .error(
586 ErrorKind::ER_WRONG_DB_NAME,
587 format!("Unknown database '{}'", database).as_bytes(),
588 )
589 .await
590 .map_err(|e| e.into());
591 }
592
593 let user_info = &self.session.user_info();
594
595 if let Some(schema_validator) = &self.user_provider {
596 if let Err(e) = schema_validator
597 .authorize(&catalog, &schema, user_info)
598 .await
599 {
600 METRIC_AUTH_FAILURE
601 .with_label_values(&[e.status_code().as_ref()])
602 .inc();
603 return w
604 .error(
605 ErrorKind::ER_DBACCESS_DENIED_ERROR,
606 e.output_msg().as_bytes(),
607 )
608 .await
609 .map_err(|e| e.into());
610 }
611 }
612
613 if catalog_from_db.is_some() {
614 self.session.set_catalog(catalog)
615 }
616 self.session.set_schema(schema);
617
618 w.ok().await.map_err(|e| e.into())
619 }
620}
621
622fn convert_param_value_to_string(param: &ParamValue) -> String {
623 match param.value.into_inner() {
624 ValueInner::Int(u) => u.to_string(),
625 ValueInner::UInt(u) => u.to_string(),
626 ValueInner::Double(u) => u.to_string(),
627 ValueInner::NULL => "NULL".to_string(),
628 ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
629 ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
630 ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
631 ValueInner::Time(_) => format_duration(Duration::from(param.value)),
632 }
633}
634
635fn replace_params(params: Vec<String>, query: String) -> String {
636 let mut query = query;
637 let mut index = 1;
638 for param in params {
639 query = query.replace(&format_placeholder(index), ¶m);
640 index += 1;
641 }
642 query
643}
644
/// Renders `duration` as a quoted MySQL TIME literal: `'HH:MM:SS'`, or
/// `'HH:MM:SS.ffffff'` when there is a sub-second part.
///
/// Minutes and seconds are zero-padded to two digits. Hours are zero-padded to at
/// least two digits and not wrapped at 24, since a TIME value may exceed one day.
/// Microsecond precision is kept instead of being silently dropped.
fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let hours = total_secs / 3600;
    let minutes = (total_secs / 60) % 60;
    let seconds = total_secs % 60;
    let micros = duration.subsec_micros();
    if micros == 0 {
        format!("'{:02}:{:02}:{:02}'", hours, minutes, seconds)
    } else {
        format!(
            "'{:02}:{:02}:{:02}.{:06}'",
            hours, minutes, seconds, micros
        )
    }
}
651
652fn replace_params_with_values(
653 plan: &LogicalPlan,
654 param_types: HashMap<String, Option<ConcreteDataType>>,
655 params: &[ParamValue],
656) -> Result<LogicalPlan> {
657 debug_assert_eq!(param_types.len(), params.len());
658
659 debug!(
660 "replace_params_with_values(param_types: {:#?}, params: {:#?}, plan: {:#?})",
661 param_types,
662 params
663 .iter()
664 .map(|x| format!("({:?}, {:?})", x.value, x.coltype))
665 .join(", "),
666 plan
667 );
668
669 let mut values = Vec::with_capacity(params.len());
670
671 for (i, param) in params.iter().enumerate() {
672 if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
673 let value = helper::convert_value(param, t)?;
674
675 values.push(value);
676 }
677 }
678
679 plan.clone()
680 .replace_params_with_values(&ParamValues::List(values.clone()))
681 .context(DataFrameSnafu)
682}
683
684fn replace_params_with_exprs(
685 plan: &LogicalPlan,
686 param_types: HashMap<String, Option<ConcreteDataType>>,
687 params: &[sql::ast::Expr],
688) -> Result<LogicalPlan> {
689 debug_assert_eq!(param_types.len(), params.len());
690
691 debug!(
692 "replace_params_with_exprs(param_types: {:#?}, params: {:#?}, plan: {:#?})",
693 param_types,
694 params.iter().map(|x| format!("({:?})", x)).join(", "),
695 plan
696 );
697
698 let mut values = Vec::with_capacity(params.len());
699
700 for (i, param) in params.iter().enumerate() {
701 if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
702 let value = helper::convert_expr_to_scalar_value(param, t)?;
703
704 values.push(value);
705 }
706 }
707
708 plan.clone()
709 .replace_params_with_values(&ParamValues::List(values.clone()))
710 .context(DataFrameSnafu)
711}
712
713async fn validate_query(query: &str) -> Result<Statement> {
714 let statement =
715 ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
716 let mut statement = statement.map_err(|e| {
717 InvalidPrepareStatementSnafu {
718 err_msg: e.output_msg(),
719 }
720 .build()
721 })?;
722
723 ensure!(
724 statement.len() == 1,
725 InvalidPrepareStatementSnafu {
726 err_msg: "prepare statement only support single statement".to_string(),
727 }
728 );
729
730 let statement = statement.remove(0);
731
732 Ok(statement)
733}
734
735fn dummy_params(index: usize) -> Result<Vec<Column>> {
736 let mut params = Vec::with_capacity(index - 1);
737
738 for _ in 1..index {
739 params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
740 }
741
742 Ok(params)
743}
744
745fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> Result<Vec<Column>> {
747 let mut params = Vec::with_capacity(param_types.len());
748
749 for index in 1..=param_types.len() {
751 if let Some(Some(t)) = param_types.get(&format_placeholder(index)) {
752 let column = create_mysql_column(t, "")?;
753 params.push(column);
754 }
755 }
756
757 Ok(params)
758}