1use std::collections::HashMap;
16use std::net::SocketAddr;
17use std::sync::atomic::{AtomicU32, Ordering};
18use std::sync::Arc;
19use std::time::Duration;
20
21use ::auth::{Identity, Password, UserProviderRef};
22use async_trait::async_trait;
23use chrono::{NaiveDate, NaiveDateTime};
24use common_catalog::parse_optional_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_query::Output;
27use common_telemetry::{debug, error, tracing, warn};
28use datafusion_common::ParamValues;
29use datafusion_expr::LogicalPlan;
30use datatypes::prelude::ConcreteDataType;
31use itertools::Itertools;
32use opensrv_mysql::{
33 AsyncMysqlShim, Column, ErrorKind, InitWriter, ParamParser, ParamValue, QueryResultWriter,
34 StatementMetaWriter, ValueInner,
35};
36use parking_lot::RwLock;
37use query::query_engine::DescribeResult;
38use rand::RngCore;
39use session::context::{Channel, QueryContextRef};
40use session::{Session, SessionRef};
41use snafu::{ensure, ResultExt};
42use sql::dialect::MySqlDialect;
43use sql::parser::{ParseOptions, ParserContext};
44use sql::statements::statement::Statement;
45use tokio::io::AsyncWrite;
46
47use crate::error::{self, DataFrameSnafu, InvalidPrepareStatementSnafu, Result};
48use crate::metrics::METRIC_AUTH_FAILURE;
49use crate::mysql::helper::{
50 self, fix_placeholder_types, format_placeholder, replace_placeholders, transform_placeholders,
51};
52use crate::mysql::writer;
53use crate::mysql::writer::{create_mysql_column, handle_err};
54use crate::query_handler::sql::ServerSqlQueryHandlerRef;
55use crate::SqlPlan;
56
/// Auth plugin advertised by default; `authenticate` checks the client's
/// scrambled response against the salt (`Password::MysqlNativePassword`).
const MYSQL_NATIVE_PASSWORD: &str = "mysql_native_password";
/// Auth plugin advertised when the user provider is external; the client then
/// sends the password itself, which `authenticate` passes on as
/// `Password::PlainText`.
const MYSQL_CLEAR_PASSWORD: &str = "mysql_clear_password";
59
/// Parameter values supplied when a prepared statement is executed.
enum Params<'a> {
    /// Values decoded by the binary protocol (`on_execute`).
    ProtocolParams(Vec<ParamValue<'a>>),
    /// Expressions parsed from a text-protocol `EXECUTE` statement (`on_query`).
    CliParams(Vec<sql::ast::Expr>),
}
67
68impl Params<'_> {
69 fn len(&self) -> usize {
70 match self {
71 Params::ProtocolParams(params) => params.len(),
72 Params::CliParams(params) => params.len(),
73 }
74 }
75}
76
/// MySQL protocol handler for a single client session; queries are forwarded
/// to `query_handler`.
pub struct MysqlInstanceShim {
    /// Executes, plans and describes SQL on behalf of this connection.
    query_handler: ServerSqlQueryHandlerRef,
    /// Random scramble exposed through `AsyncMysqlShim::salt`.
    salt: [u8; 20],
    /// Session state (user, catalog, schema, client address) of this connection.
    session: SessionRef,
    /// Optional authenticator/authorizer; `None` accepts every login.
    user_provider: Option<UserProviderRef>,
    /// Prepared statements, keyed by the SQL-level statement name (text
    /// protocol) or by a UUID built from the statement id (binary protocol).
    prepared_stmts: Arc<RwLock<HashMap<String, SqlPlan>>>,
    /// Next binary-protocol statement id; starts at 1.
    prepared_stmts_counter: AtomicU32,
}
86
87impl MysqlInstanceShim {
88 pub fn create(
89 query_handler: ServerSqlQueryHandlerRef,
90 user_provider: Option<UserProviderRef>,
91 client_addr: SocketAddr,
92 ) -> MysqlInstanceShim {
93 let mut bs = vec![0u8; 20];
95 let mut rng = rand::rng();
96 rng.fill_bytes(bs.as_mut());
97
98 let mut scramble: [u8; 20] = [0; 20];
99 for i in 0..20 {
100 scramble[i] = bs[i] & 0x7fu8;
101 if scramble[i] == b'\0' || scramble[i] == b'$' {
102 scramble[i] += 1;
103 }
104 }
105
106 MysqlInstanceShim {
107 query_handler,
108 salt: scramble,
109 session: Arc::new(Session::new(
110 Some(client_addr),
111 Channel::Mysql,
112 Default::default(),
113 )),
114 user_provider,
115 prepared_stmts: Default::default(),
116 prepared_stmts_counter: AtomicU32::new(1),
117 }
118 }
119
120 #[tracing::instrument(skip_all, name = "mysql::do_query")]
121 async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
122 if let Some(output) =
123 crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
124 {
125 vec![Ok(output)]
126 } else {
127 self.query_handler.do_query(query, query_ctx.clone()).await
128 }
129 }
130
131 async fn do_exec_plan(
133 &self,
134 query: &str,
135 plan: LogicalPlan,
136 query_ctx: QueryContextRef,
137 ) -> Result<Output> {
138 if let Some(output) =
139 crate::mysql::federated::check(query, query_ctx.clone(), self.session.clone())
140 {
141 Ok(output)
142 } else {
143 self.query_handler.do_exec_plan(plan, query_ctx).await
144 }
145 }
146
147 async fn do_describe(
149 &self,
150 statement: Statement,
151 query_ctx: QueryContextRef,
152 ) -> Result<Option<DescribeResult>> {
153 self.query_handler.do_describe(statement, query_ctx).await
154 }
155
156 fn save_plan(&self, plan: SqlPlan, stmt_key: String) {
158 let mut prepared_stmts = self.prepared_stmts.write();
159 let _ = prepared_stmts.insert(stmt_key, plan);
160 }
161
162 fn plan(&self, stmt_key: &str) -> Option<SqlPlan> {
164 let guard = self.prepared_stmts.read();
165 guard.get(stmt_key).cloned()
166 }
167
168 async fn do_prepare(
170 &mut self,
171 raw_query: &str,
172 query_ctx: QueryContextRef,
173 stmt_key: String,
174 ) -> Result<(Vec<Column>, Vec<Column>)> {
175 let (query, param_num) = replace_placeholders(raw_query);
176
177 let statement = validate_query(raw_query).await?;
178
179 let statement = transform_placeholders(statement);
182
183 let describe_result = self
184 .do_describe(statement.clone(), query_ctx.clone())
185 .await?;
186 let (mut plan, schema) = if let Some(DescribeResult {
187 logical_plan,
188 schema,
189 }) = describe_result
190 {
191 (Some(logical_plan), Some(schema))
192 } else {
193 (None, None)
194 };
195
196 let params = if let Some(plan) = &mut plan {
197 fix_placeholder_types(plan)?;
198 debug!("Plan after fix placeholder types: {:#?}", plan);
199 prepared_params(
200 &plan
201 .get_parameter_types()
202 .context(DataFrameSnafu)?
203 .into_iter()
204 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
205 .collect(),
206 )?
207 } else {
208 dummy_params(param_num)?
209 };
210
211 let columns = schema
212 .as_ref()
213 .map(|schema| {
214 schema
215 .column_schemas()
216 .iter()
217 .map(|column_schema| {
218 create_mysql_column(&column_schema.data_type, &column_schema.name)
219 })
220 .collect::<Result<Vec<_>>>()
221 })
222 .transpose()?
223 .unwrap_or_default();
224
225 if params.len() != param_num - 1 {
227 self.save_plan(
228 SqlPlan {
229 query: query.to_string(),
230 plan: None,
231 schema: None,
232 },
233 stmt_key,
234 );
235 } else {
236 self.save_plan(
237 SqlPlan {
238 query: query.to_string(),
239 plan,
240 schema,
241 },
242 stmt_key,
243 );
244 }
245
246 Ok((params, columns))
247 }
248
249 async fn do_execute(
250 &mut self,
251 query_ctx: QueryContextRef,
252 stmt_key: String,
253 params: Params<'_>,
254 ) -> Result<Vec<std::result::Result<Output, error::Error>>> {
255 let sql_plan = match self.plan(&stmt_key) {
256 None => {
257 return error::PrepareStatementNotFoundSnafu { name: stmt_key }.fail();
258 }
259 Some(sql_plan) => sql_plan,
260 };
261
262 let outputs = match sql_plan.plan {
263 Some(mut plan) => {
264 fix_placeholder_types(&mut plan)?;
265 let param_types = plan
266 .get_parameter_types()
267 .context(DataFrameSnafu)?
268 .into_iter()
269 .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v))))
270 .collect::<HashMap<_, _>>();
271
272 if params.len() != param_types.len() {
273 return error::InternalSnafu {
274 err_msg: "Prepare statement params number mismatch".to_string(),
275 }
276 .fail();
277 }
278
279 let plan = match params {
280 Params::ProtocolParams(params) => {
281 replace_params_with_values(&plan, param_types, ¶ms)
282 }
283 Params::CliParams(params) => {
284 replace_params_with_exprs(&plan, param_types, ¶ms)
285 }
286 }?;
287
288 debug!("Mysql execute prepared plan: {}", plan.display_indent());
289 vec![
290 self.do_exec_plan(&sql_plan.query, plan, query_ctx.clone())
291 .await,
292 ]
293 }
294 None => {
295 let param_strs = match params {
296 Params::ProtocolParams(params) => {
297 params.iter().map(convert_param_value_to_string).collect()
298 }
299 Params::CliParams(params) => params.iter().map(|x| x.to_string()).collect(),
300 };
301 debug!(
302 "do_execute Replacing with Params: {:?}, Original Query: {}",
303 param_strs, sql_plan.query
304 );
305 let query = replace_params(param_strs, sql_plan.query);
306 debug!("Mysql execute replaced query: {}", query);
307 self.do_query(&query, query_ctx.clone()).await
308 }
309 };
310
311 Ok(outputs)
312 }
313
314 fn do_close(&mut self, stmt_key: String) {
316 let mut guard = self.prepared_stmts.write();
317 let _ = guard.remove(&stmt_key);
318 }
319
320 fn auth_plugin(&self) -> &str {
321 if self
322 .user_provider
323 .as_ref()
324 .map(|x| x.external())
325 .unwrap_or(false)
326 {
327 MYSQL_CLEAR_PASSWORD
328 } else {
329 MYSQL_NATIVE_PASSWORD
330 }
331 }
332}
333
334#[async_trait]
335impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShim {
336 type Error = error::Error;
337
338 fn version(&self) -> String {
339 std::env::var("GREPTIMEDB_MYSQL_SERVER_VERSION").unwrap_or_else(|_| "8.4.2".to_string())
340 }
341
342 fn default_auth_plugin(&self) -> &str {
343 self.auth_plugin()
344 }
345
346 async fn auth_plugin_for_username<'a, 'user>(&'a self, _user: &'user [u8]) -> &'a str {
347 self.auth_plugin()
348 }
349
350 fn salt(&self) -> [u8; 20] {
351 self.salt
352 }
353
354 async fn authenticate(
355 &self,
356 auth_plugin: &str,
357 username: &[u8],
358 salt: &[u8],
359 auth_data: &[u8],
360 ) -> bool {
361 let username = String::from_utf8_lossy(username);
363
364 let mut user_info = None;
365 let addr = self
366 .session
367 .conn_info()
368 .client_addr
369 .map(|addr| addr.to_string());
370 if let Some(user_provider) = &self.user_provider {
371 let user_id = Identity::UserId(&username, addr.as_deref());
372
373 let password = match auth_plugin {
374 MYSQL_NATIVE_PASSWORD => Password::MysqlNativePassword(auth_data, salt),
375 MYSQL_CLEAR_PASSWORD => {
376 let password = if let &[password @ .., 0] = &auth_data {
379 password
380 } else {
381 auth_data
382 };
383 Password::PlainText(String::from_utf8_lossy(password).to_string().into())
384 }
385 other => {
386 error!("Unsupported mysql auth plugin: {}", other);
387 return false;
388 }
389 };
390 match user_provider.authenticate(user_id, password).await {
391 Ok(userinfo) => {
392 user_info = Some(userinfo);
393 }
394 Err(e) => {
395 METRIC_AUTH_FAILURE
396 .with_label_values(&[e.status_code().as_ref()])
397 .inc();
398 warn!(e; "Failed to auth");
399 return false;
400 }
401 };
402 }
403 let user_info =
404 user_info.unwrap_or_else(|| auth::userinfo_by_name(Some(username.to_string())));
405
406 self.session.set_user_info(user_info);
407
408 true
409 }
410
411 async fn on_prepare<'a>(
412 &'a mut self,
413 raw_query: &'a str,
414 w: StatementMetaWriter<'a, W>,
415 ) -> Result<()> {
416 let query_ctx = self.session.new_query_context();
417 let stmt_id = self.prepared_stmts_counter.fetch_add(1, Ordering::Relaxed);
418 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
419 let (params, columns) = self
420 .do_prepare(raw_query, query_ctx.clone(), stmt_key)
421 .await?;
422 debug!("on_prepare: Params: {:?}, Columns: {:?}", params, columns);
423 w.reply(stmt_id, ¶ms, &columns).await?;
424 crate::metrics::METRIC_MYSQL_PREPARED_COUNT
425 .with_label_values(&[query_ctx.get_db_string().as_str()])
426 .inc();
427 return Ok(());
428 }
429
430 async fn on_execute<'a>(
431 &'a mut self,
432 stmt_id: u32,
433 p: ParamParser<'a>,
434 w: QueryResultWriter<'a, W>,
435 ) -> Result<()> {
436 let query_ctx = self.session.new_query_context();
437 let db = query_ctx.get_db_string();
438 let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
439 .with_label_values(&[crate::metrics::METRIC_MYSQL_BINQUERY, db.as_str()])
440 .start_timer();
441
442 let params: Vec<ParamValue> = p.into_iter().collect();
443 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
444
445 let outputs = match self
446 .do_execute(query_ctx.clone(), stmt_key, Params::ProtocolParams(params))
447 .await
448 {
449 Ok(outputs) => outputs,
450 Err(e) => {
451 let (kind, err) = handle_err(e, query_ctx);
452 debug!(
453 "Failed to execute prepared statement, kind: {:?}, err: {}",
454 kind, err
455 );
456 w.error(kind, err.as_bytes()).await?;
457 return Ok(());
458 }
459 };
460
461 writer::write_output(w, query_ctx, outputs).await?;
462
463 Ok(())
464 }
465
466 async fn on_close<'a>(&'a mut self, stmt_id: u32)
467 where
468 W: 'async_trait,
469 {
470 let stmt_key = uuid::Uuid::from_u128(stmt_id as u128).to_string();
471 self.do_close(stmt_key);
472 }
473
474 #[tracing::instrument(skip_all, fields(protocol = "mysql"))]
475 async fn on_query<'a>(
476 &'a mut self,
477 query: &'a str,
478 writer: QueryResultWriter<'a, W>,
479 ) -> Result<()> {
480 let query_ctx = self.session.new_query_context();
481 let db = query_ctx.get_db_string();
482 let _timer = crate::metrics::METRIC_MYSQL_QUERY_TIMER
483 .with_label_values(&[crate::metrics::METRIC_MYSQL_TEXTQUERY, db.as_str()])
484 .start_timer();
485
486 let query_upcase = query.to_uppercase();
487 if query_upcase.starts_with("PREPARE ") {
488 match ParserContext::parse_mysql_prepare_stmt(query, query_ctx.sql_dialect()) {
489 Ok((stmt_name, stmt)) => {
490 let prepare_results =
491 self.do_prepare(&stmt, query_ctx.clone(), stmt_name).await;
492 match prepare_results {
493 Ok(_) => {
494 let outputs = vec![Ok(Output::new_with_affected_rows(0))];
495 writer::write_output(writer, query_ctx, outputs).await?;
496 return Ok(());
497 }
498 Err(e) => {
499 writer
500 .error(ErrorKind::ER_SP_BADSTATEMENT, e.output_msg().as_bytes())
501 .await?;
502 return Ok(());
503 }
504 }
505 }
506 Err(e) => {
507 writer
508 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
509 .await?;
510 return Ok(());
511 }
512 }
513 } else if query_upcase.starts_with("EXECUTE ") {
514 match ParserContext::parse_mysql_execute_stmt(query, query_ctx.sql_dialect()) {
515 Ok((stmt_name, params)) => {
516 let outputs = match self
517 .do_execute(query_ctx.clone(), stmt_name, Params::CliParams(params))
518 .await
519 {
520 Ok(outputs) => outputs,
521 Err(e) => {
522 let (kind, err) = handle_err(e, query_ctx);
523 debug!(
524 "Failed to execute prepared statement, kind: {:?}, err: {}",
525 kind, err
526 );
527 writer.error(kind, err.as_bytes()).await?;
528 return Ok(());
529 }
530 };
531 writer::write_output(writer, query_ctx, outputs).await?;
532 return Ok(());
533 }
534 Err(e) => {
535 writer
536 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
537 .await?;
538 return Ok(());
539 }
540 }
541 } else if query_upcase.starts_with("DEALLOCATE ") {
542 match ParserContext::parse_mysql_deallocate_stmt(query, query_ctx.sql_dialect()) {
543 Ok(stmt_name) => {
544 self.do_close(stmt_name);
545 let outputs = vec![Ok(Output::new_with_affected_rows(0))];
546 writer::write_output(writer, query_ctx, outputs).await?;
547 return Ok(());
548 }
549 Err(e) => {
550 writer
551 .error(ErrorKind::ER_PARSE_ERROR, e.output_msg().as_bytes())
552 .await?;
553 return Ok(());
554 }
555 }
556 }
557
558 let outputs = self.do_query(query, query_ctx.clone()).await;
559 writer::write_output(writer, query_ctx, outputs).await?;
560 Ok(())
561 }
562
563 async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
564 let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
565 let catalog = if let Some(catalog) = &catalog_from_db {
566 catalog.to_string()
567 } else {
568 self.session.catalog()
569 };
570
571 if !self
572 .query_handler
573 .is_valid_schema(&catalog, &schema)
574 .await?
575 {
576 return w
577 .error(
578 ErrorKind::ER_WRONG_DB_NAME,
579 format!("Unknown database '{}'", database).as_bytes(),
580 )
581 .await
582 .map_err(|e| e.into());
583 }
584
585 let user_info = &self.session.user_info();
586
587 if let Some(schema_validator) = &self.user_provider {
588 if let Err(e) = schema_validator
589 .authorize(&catalog, &schema, user_info)
590 .await
591 {
592 METRIC_AUTH_FAILURE
593 .with_label_values(&[e.status_code().as_ref()])
594 .inc();
595 return w
596 .error(
597 ErrorKind::ER_DBACCESS_DENIED_ERROR,
598 e.output_msg().as_bytes(),
599 )
600 .await
601 .map_err(|e| e.into());
602 }
603 }
604
605 if catalog_from_db.is_some() {
606 self.session.set_catalog(catalog)
607 }
608 self.session.set_schema(schema);
609
610 w.ok().await.map_err(|e| e.into())
611 }
612}
613
614fn convert_param_value_to_string(param: &ParamValue) -> String {
615 match param.value.into_inner() {
616 ValueInner::Int(u) => u.to_string(),
617 ValueInner::UInt(u) => u.to_string(),
618 ValueInner::Double(u) => u.to_string(),
619 ValueInner::NULL => "NULL".to_string(),
620 ValueInner::Bytes(b) => format!("'{}'", &String::from_utf8_lossy(b)),
621 ValueInner::Date(_) => format!("'{}'", NaiveDate::from(param.value)),
622 ValueInner::Datetime(_) => format!("'{}'", NaiveDateTime::from(param.value)),
623 ValueInner::Time(_) => format_duration(Duration::from(param.value)),
624 }
625}
626
627fn replace_params(params: Vec<String>, query: String) -> String {
628 let mut query = query;
629 let mut index = 1;
630 for param in params {
631 query = query.replace(&format_placeholder(index), ¶m);
632 index += 1;
633 }
634 query
635}
636
/// Formats `duration` as a quoted SQL time literal `'HH:MM:SS[.ffffff]'`.
///
/// Minutes and seconds are zero-padded to two digits. Hours are padded to at
/// least two digits and may grow beyond that. A microsecond fraction is
/// appended only when the duration has sub-second precision.
fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let hours = total_secs / 3600;
    let minutes = (total_secs / 60) % 60;
    let seconds = total_secs % 60;
    let micros = duration.subsec_micros();
    if micros == 0 {
        format!("'{:02}:{:02}:{:02}'", hours, minutes, seconds)
    } else {
        format!(
            "'{:02}:{:02}:{:02}.{:06}'",
            hours, minutes, seconds, micros
        )
    }
}
643
644fn replace_params_with_values(
645 plan: &LogicalPlan,
646 param_types: HashMap<String, Option<ConcreteDataType>>,
647 params: &[ParamValue],
648) -> Result<LogicalPlan> {
649 debug_assert_eq!(param_types.len(), params.len());
650
651 debug!(
652 "replace_params_with_values(param_types: {:#?}, params: {:#?}, plan: {:#?})",
653 param_types,
654 params
655 .iter()
656 .map(|x| format!("({:?}, {:?})", x.value, x.coltype))
657 .join(", "),
658 plan
659 );
660
661 let mut values = Vec::with_capacity(params.len());
662
663 for (i, param) in params.iter().enumerate() {
664 if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
665 let value = helper::convert_value(param, t)?;
666
667 values.push(value);
668 }
669 }
670
671 plan.clone()
672 .replace_params_with_values(&ParamValues::List(values.clone()))
673 .context(DataFrameSnafu)
674}
675
676fn replace_params_with_exprs(
677 plan: &LogicalPlan,
678 param_types: HashMap<String, Option<ConcreteDataType>>,
679 params: &[sql::ast::Expr],
680) -> Result<LogicalPlan> {
681 debug_assert_eq!(param_types.len(), params.len());
682
683 debug!(
684 "replace_params_with_exprs(param_types: {:#?}, params: {:#?}, plan: {:#?})",
685 param_types,
686 params.iter().map(|x| format!("({:?})", x)).join(", "),
687 plan
688 );
689
690 let mut values = Vec::with_capacity(params.len());
691
692 for (i, param) in params.iter().enumerate() {
693 if let Some(Some(t)) = param_types.get(&format_placeholder(i + 1)) {
694 let value = helper::convert_expr_to_scalar_value(param, t)?;
695
696 values.push(value);
697 }
698 }
699
700 plan.clone()
701 .replace_params_with_values(&ParamValues::List(values.clone()))
702 .context(DataFrameSnafu)
703}
704
705async fn validate_query(query: &str) -> Result<Statement> {
706 let statement =
707 ParserContext::create_with_dialect(query, &MySqlDialect {}, ParseOptions::default());
708 let mut statement = statement.map_err(|e| {
709 InvalidPrepareStatementSnafu {
710 err_msg: e.output_msg(),
711 }
712 .build()
713 })?;
714
715 ensure!(
716 statement.len() == 1,
717 InvalidPrepareStatementSnafu {
718 err_msg: "prepare statement only support single statement".to_string(),
719 }
720 );
721
722 let statement = statement.remove(0);
723
724 Ok(statement)
725}
726
727fn dummy_params(index: usize) -> Result<Vec<Column>> {
728 let mut params = Vec::with_capacity(index - 1);
729
730 for _ in 1..index {
731 params.push(create_mysql_column(&ConcreteDataType::null_datatype(), "")?);
732 }
733
734 Ok(params)
735}
736
737fn prepared_params(param_types: &HashMap<String, Option<ConcreteDataType>>) -> Result<Vec<Column>> {
739 let mut params = Vec::with_capacity(param_types.len());
740
741 for index in 1..=param_types.len() {
743 if let Some(Some(t)) = param_types.get(&format_placeholder(index)) {
744 let column = create_mysql_column(t, "")?;
745 params.push(column);
746 }
747 }
748
749 Ok(params)
750}