1use std::any::Any;
16use std::collections::HashMap;
17use std::marker::PhantomData;
18#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
19use std::sync::OnceLock;
20use std::time::Duration;
21
22use backon::{BackoffBuilder, ExponentialBuilder};
23use common_telemetry::{debug, info};
24
25use crate::error::{Error, RdsTransactionRetryFailedSnafu, Result, UnexpectedSnafu};
26use crate::kv_backend::txn::{
27 Compare, Txn as KvTxn, TxnOp, TxnOpResponse, TxnResponse as KvTxnResponse,
28};
29use crate::kv_backend::{KvBackend, TxnService};
30use crate::metrics::METRIC_META_TXN_REQUEST;
31use crate::rpc::KeyValue;
32use crate::rpc::store::{
33 BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
34 BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, PutRequest, PutResponse,
35 RangeRequest, RangeResponse,
36};
37
// Operation labels for SQL executions issued by the RDS store.
// NOTE(review): none of these is referenced in this module; presumably the backend
// submodules pass them as `$label_op` to `record_rds_sql_execute_elapsed!` — confirm.
/// Label for batch-get statements.
const RDS_STORE_OP_BATCH_GET: &str = "batch_get";
/// Label for batch-put statements.
const RDS_STORE_OP_BATCH_PUT: &str = "batch_put";
/// Label for range-query statements.
const RDS_STORE_OP_RANGE_QUERY: &str = "range_query";
/// Label for range-delete statements.
const RDS_STORE_OP_RANGE_DELETE: &str = "range_delete";
/// Label for batch-delete statements.
const RDS_STORE_OP_BATCH_DELETE: &str = "batch_delete";
43
44#[cfg(feature = "pg_kvbackend")]
45pub mod postgres;
46#[cfg(feature = "pg_kvbackend")]
47pub use postgres::PgStore;
48
49#[cfg(feature = "mysql_kvbackend")]
50mod mysql;
51#[cfg(feature = "mysql_kvbackend")]
52pub use mysql::MySqlStore;
53
/// Transaction retry count for serialization failures.
// NOTE(review): not referenced in this chunk; presumably the default value for
// `RdsStore::txn_retry_count` set by the backend constructors — confirm.
const RDS_STORE_TXN_RETRY_COUNT: usize = 3;

/// Cached outcome of the one-time rustls crypto-provider installation performed by
/// `ensure_rustls_crypto_provider_installed`; the `Err` variant holds the failure message.
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
static RUSTLS_CRYPTO_PROVIDER_INIT: OnceLock<std::result::Result<(), String>> = OnceLock::new();
58
/// Makes sure a process-wide default rustls `CryptoProvider` exists, installing the
/// `aws-lc-rs` provider if none is set yet.
///
/// The installation is attempted only once per process: its outcome is cached in a
/// `OnceLock`, and every later call replays that cached outcome.
///
/// # Errors
///
/// Returns an `Unexpected` error when no default provider existed and installing the
/// `aws-lc-rs` provider failed.
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
pub(crate) fn ensure_rustls_crypto_provider_installed() -> Result<()> {
    let outcome = RUSTLS_CRYPTO_PROVIDER_INIT.get_or_init(|| {
        // A default is already in place (set by us or by another component): nothing to do.
        if rustls::crypto::CryptoProvider::get_default().is_some() {
            return Ok(());
        }

        let candidate = rustls::crypto::aws_lc_rs::default_provider();
        rustls::crypto::CryptoProvider::install_default(candidate).or_else(|rejected| {
            // Installation is rejected when another caller won the race; a default provider
            // existing afterwards still counts as success.
            if rustls::crypto::CryptoProvider::get_default().is_some() {
                Ok(())
            } else {
                Err(format!(
                    "Failed to install rustls CryptoProvider, existing default: {:?}, attempted provider: {:?}",
                    rustls::crypto::CryptoProvider::get_default(),
                    rejected
                ))
            }
        })
    });

    match outcome {
        Ok(()) => Ok(()),
        Err(cached_msg) => {
            let err_msg = cached_msg.clone();
            info!("Failed to install rustls crypto provider: {err_msg}");
            Err(UnexpectedSnafu { err_msg }.build())
        }
    }
}
87
/// A database connection that can run parameterized SQL and open transactions.
#[async_trait::async_trait]
pub trait Executor: Send + Sync {
    /// Transaction handle borrowing this executor for the lifetime `'a`.
    type Transaction<'a>: 'a + Transaction<'a>
    where
        Self: 'a;

    /// Backend name; used as the `KvBackend` name and as a label of the
    /// `METRIC_META_TXN_REQUEST` metric.
    fn name() -> &'static str;

    /// Runs `query` with `params` and returns the produced rows as key-value pairs.
    async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>>;

    /// Runs `query` with `params`, discarding any rows it returns.
    async fn execute(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<()> {
        self.query(query, params).await?;
        Ok(())
    }

    /// Begins a transaction that borrows this executor until it is committed or dropped.
    async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>>;
}
107
/// An open database transaction produced by [`Executor::txn_executor`].
#[async_trait::async_trait]
pub trait Transaction<'a>: Send + Sync {
    /// Runs `query` with `params` inside the transaction and returns the rows as key-value
    /// pairs.
    async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>>;

    /// Runs `query` with `params` inside the transaction, discarding any rows it returns.
    async fn execute(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<()> {
        self.query(query, params).await?;
        Ok(())
    }

    /// Commits the transaction, consuming the handle.
    async fn commit(self) -> Result<()>;
}
120
/// Creates the executors an [`RdsStore`] runs its queries on.
#[async_trait::async_trait]
pub trait ExecutorFactory<T: Executor>: Send + Sync {
    /// Obtains an executor for non-transactional requests; `RdsStore` calls this once per
    /// request and also uses it as the base for transactions.
    async fn default_executor(&self) -> Result<T>;

    /// Begins a transaction on `default_executor`, borrowing it for the transaction's lifetime.
    async fn txn_executor<'a>(&self, default_executor: &'a mut T) -> Result<T::Transaction<'a>>;
}
128
/// Key-value store backed by a SQL database (PostgreSQL / MySQL per the feature-gated
/// submodules); implements `KvBackend` and `TxnService`.
pub struct RdsStore<T, S, R>
where
    T: Executor + Send + Sync,
    S: ExecutorFactory<T> + Send + Sync,
{
    /// Maximum number of operations per transaction, reported by `TxnService::max_txn_ops`.
    max_txn_ops: usize,
    /// Maximum number of backoff retries when a transaction fails with a serialization error.
    txn_retry_count: usize,
    /// Source of default executors and transactions.
    executor_factory: S,
    /// NOTE(review): not read in this module; presumably the backend-specific SQL templates
    /// used by the `KvQueryExecutor` impls — confirm.
    sql_template_set: R,
    /// Records the executor type `T`, which is not stored directly.
    _phantom: PhantomData<T>,
}
141
/// The executor a key-value operation runs on: either a plain executor or an open
/// transaction borrowed from one.
pub enum ExecutorImpl<'a, T: Executor + 'a> {
    /// Plain executor; committing it is a no-op.
    Default(T),
    /// Open transaction; committing it commits the transaction.
    Txn(T::Transaction<'a>),
}
146
147impl<T: Executor> ExecutorImpl<'_, T> {
148 async fn query(&mut self, query: &str, params: &Vec<&Vec<u8>>) -> Result<Vec<KeyValue>> {
149 match self {
150 Self::Default(executor) => executor.query(query, params).await,
151 Self::Txn(executor) => executor.query(query, params).await,
152 }
153 }
154
155 #[allow(dead_code)] async fn execute(&mut self, query: &str, params: &Vec<&Vec<u8>>) -> Result<()> {
157 match self {
158 Self::Default(executor) => executor.execute(query, params).await,
159 Self::Txn(executor) => executor.execute(query, params).await,
160 }
161 }
162
163 async fn commit(self) -> Result<()> {
164 match self {
165 Self::Txn(executor) => executor.commit().await,
166 _ => Ok(()),
167 }
168 }
169}
170
171#[async_trait::async_trait]
172pub trait KvQueryExecutor<T: Executor> {
173 async fn range_with_query_executor(
174 &self,
175 query_executor: &mut ExecutorImpl<'_, T>,
176 req: RangeRequest,
177 ) -> Result<RangeResponse>;
178
179 async fn put_with_query_executor(
180 &self,
181 query_executor: &mut ExecutorImpl<'_, T>,
182 req: PutRequest,
183 ) -> Result<PutResponse> {
184 let kv = KeyValue {
185 key: req.key,
186 value: req.value,
187 };
188 let mut res = self
189 .batch_put_with_query_executor(
190 query_executor,
191 BatchPutRequest {
192 kvs: vec![kv],
193 prev_kv: req.prev_kv,
194 },
195 )
196 .await?;
197
198 if !res.prev_kvs.is_empty() {
199 debug_assert!(req.prev_kv);
200 return Ok(PutResponse {
201 prev_kv: Some(res.prev_kvs.remove(0)),
202 });
203 }
204 Ok(PutResponse::default())
205 }
206
207 async fn batch_put_with_query_executor(
208 &self,
209 query_executor: &mut ExecutorImpl<'_, T>,
210 req: BatchPutRequest,
211 ) -> Result<BatchPutResponse>;
212
213 async fn batch_get_with_query_executor(
215 &self,
216 query_executor: &mut ExecutorImpl<'_, T>,
217 req: BatchGetRequest,
218 ) -> Result<BatchGetResponse>;
219
220 async fn delete_range_with_query_executor(
221 &self,
222 query_executor: &mut ExecutorImpl<'_, T>,
223 req: DeleteRangeRequest,
224 ) -> Result<DeleteRangeResponse>;
225
226 async fn batch_delete_with_query_executor(
227 &self,
228 query_executor: &mut ExecutorImpl<'_, T>,
229 req: BatchDeleteRequest,
230 ) -> Result<BatchDeleteResponse>;
231}
232
233impl<T, S, R> RdsStore<T, S, R>
234where
235 Self: KvQueryExecutor<T> + Send + Sync,
236 T: Executor + Send + Sync,
237 S: ExecutorFactory<T> + Send + Sync,
238{
239 async fn execute_txn_cmp(
240 &self,
241 query_executor: &mut ExecutorImpl<'_, T>,
242 cmp: &[Compare],
243 ) -> Result<bool> {
244 let batch_get_req = BatchGetRequest {
245 keys: cmp.iter().map(|c| c.key.clone()).collect(),
246 };
247 let res = self
248 .batch_get_with_query_executor(query_executor, batch_get_req)
249 .await?;
250 debug!("batch get res: {:?}", res);
251 let res_map = res
252 .kvs
253 .into_iter()
254 .map(|kv| (kv.key, kv.value))
255 .collect::<HashMap<Vec<u8>, Vec<u8>>>();
256 for c in cmp {
257 let value = res_map.get(&c.key);
258 if !c.compare_value(value) {
259 return Ok(false);
260 }
261 }
262 Ok(true)
263 }
264
265 async fn try_batch_txn(
267 &self,
268 query_executor: &mut ExecutorImpl<'_, T>,
269 txn_ops: &[TxnOp],
270 ) -> Result<Option<Vec<TxnOpResponse>>> {
271 if !check_txn_ops(txn_ops)? {
272 return Ok(None);
273 }
274 match txn_ops.first().unwrap() {
276 TxnOp::Delete(_) => self.handle_batch_delete(query_executor, txn_ops).await,
277 TxnOp::Put(_, _) => self.handle_batch_put(query_executor, txn_ops).await,
278 TxnOp::Get(_) => self.handle_batch_get(query_executor, txn_ops).await,
279 }
280 }
281
282 async fn handle_batch_delete(
283 &self,
284 query_executor: &mut ExecutorImpl<'_, T>,
285 txn_ops: &[TxnOp],
286 ) -> Result<Option<Vec<TxnOpResponse>>> {
287 let mut batch_del_req = BatchDeleteRequest {
288 keys: vec![],
289 prev_kv: true,
290 };
291 for op in txn_ops {
292 if let TxnOp::Delete(key) = op {
293 batch_del_req.keys.push(key.clone());
294 }
295 }
296 let res = self
297 .batch_delete_with_query_executor(query_executor, batch_del_req)
298 .await?;
299 let res_map = res
300 .prev_kvs
301 .into_iter()
302 .map(|kv| (kv.key, kv.value))
303 .collect::<HashMap<Vec<u8>, Vec<u8>>>();
304 let mut resps = Vec::with_capacity(txn_ops.len());
305 for op in txn_ops {
306 if let TxnOp::Delete(key) = op {
307 let value = res_map.get(key);
308 resps.push(TxnOpResponse::ResponseDelete(DeleteRangeResponse {
309 deleted: if value.is_some() { 1 } else { 0 },
310 prev_kvs: vec![],
311 }));
312 }
313 }
314 Ok(Some(resps))
315 }
316
317 async fn handle_batch_put(
318 &self,
319 query_executor: &mut ExecutorImpl<'_, T>,
320 txn_ops: &[TxnOp],
321 ) -> Result<Option<Vec<TxnOpResponse>>> {
322 let mut batch_put_req = BatchPutRequest {
323 kvs: vec![],
324 prev_kv: false,
325 };
326 for op in txn_ops {
327 if let TxnOp::Put(key, value) = op {
328 batch_put_req.kvs.push(KeyValue {
329 key: key.clone(),
330 value: value.clone(),
331 });
332 }
333 }
334 let _ = self
335 .batch_put_with_query_executor(query_executor, batch_put_req)
336 .await?;
337 let mut resps = Vec::with_capacity(txn_ops.len());
338 for op in txn_ops {
339 if let TxnOp::Put(_, _) = op {
340 resps.push(TxnOpResponse::ResponsePut(PutResponse { prev_kv: None }));
341 }
342 }
343 Ok(Some(resps))
344 }
345
346 async fn handle_batch_get(
347 &self,
348 query_executor: &mut ExecutorImpl<'_, T>,
349 txn_ops: &[TxnOp],
350 ) -> Result<Option<Vec<TxnOpResponse>>> {
351 let mut batch_get_req = BatchGetRequest { keys: vec![] };
352 for op in txn_ops {
353 if let TxnOp::Get(key) = op {
354 batch_get_req.keys.push(key.clone());
355 }
356 }
357 let res = self
358 .batch_get_with_query_executor(query_executor, batch_get_req)
359 .await?;
360 let res_map = res
361 .kvs
362 .into_iter()
363 .map(|kv| (kv.key, kv.value))
364 .collect::<HashMap<Vec<u8>, Vec<u8>>>();
365 let mut resps = Vec::with_capacity(txn_ops.len());
366 for op in txn_ops {
367 if let TxnOp::Get(key) = op {
368 let value = res_map.get(key);
369 resps.push(TxnOpResponse::ResponseGet(RangeResponse {
370 kvs: value
371 .map(|v| {
372 vec![KeyValue {
373 key: key.clone(),
374 value: v.clone(),
375 }]
376 })
377 .unwrap_or_default(),
378 more: false,
379 }));
380 }
381 }
382 Ok(Some(resps))
383 }
384
385 async fn execute_txn_op(
386 &self,
387 query_executor: &mut ExecutorImpl<'_, T>,
388 op: &TxnOp,
389 ) -> Result<TxnOpResponse> {
390 match op {
391 TxnOp::Put(key, value) => {
392 let res = self
393 .put_with_query_executor(
394 query_executor,
395 PutRequest {
396 key: key.clone(),
397 value: value.clone(),
398 prev_kv: false,
399 },
400 )
401 .await?;
402 Ok(TxnOpResponse::ResponsePut(res))
403 }
404 TxnOp::Get(key) => {
405 let res = self
406 .range_with_query_executor(
407 query_executor,
408 RangeRequest {
409 key: key.clone(),
410 range_end: vec![],
411 limit: 1,
412 keys_only: false,
413 },
414 )
415 .await?;
416 Ok(TxnOpResponse::ResponseGet(res))
417 }
418 TxnOp::Delete(key) => {
419 let res = self
420 .delete_range_with_query_executor(
421 query_executor,
422 DeleteRangeRequest {
423 key: key.clone(),
424 range_end: vec![],
425 prev_kv: false,
426 },
427 )
428 .await?;
429 Ok(TxnOpResponse::ResponseDelete(res))
430 }
431 }
432 }
433
434 async fn txn_inner(&self, txn: &KvTxn) -> Result<KvTxnResponse> {
435 let mut default_executor = self.executor_factory.default_executor().await?;
436 let mut txn_executor = ExecutorImpl::Txn(
437 self.executor_factory
438 .txn_executor(&mut default_executor)
439 .await?,
440 );
441 let mut success = true;
442 if txn.c_when {
443 success = self
444 .execute_txn_cmp(&mut txn_executor, &txn.req.compare)
445 .await?;
446 }
447 let mut responses = vec![];
448 if success && txn.c_then {
449 match self
450 .try_batch_txn(&mut txn_executor, &txn.req.success)
451 .await?
452 {
453 Some(res) => responses.extend(res),
454 None => {
455 for txnop in &txn.req.success {
456 let res = self.execute_txn_op(&mut txn_executor, txnop).await?;
457 responses.push(res);
458 }
459 }
460 }
461 } else if !success && txn.c_else {
462 match self
463 .try_batch_txn(&mut txn_executor, &txn.req.failure)
464 .await?
465 {
466 Some(res) => responses.extend(res),
467 None => {
468 for txnop in &txn.req.failure {
469 let res = self.execute_txn_op(&mut txn_executor, txnop).await?;
470 responses.push(res);
471 }
472 }
473 }
474 }
475
476 txn_executor.commit().await?;
477 Ok(KvTxnResponse {
478 responses,
479 succeeded: success,
480 })
481 }
482}
483
484#[async_trait::async_trait]
485impl<T, S, R> KvBackend for RdsStore<T, S, R>
486where
487 R: 'static,
488 Self: KvQueryExecutor<T> + Send + Sync,
489 T: Executor + 'static,
490 S: ExecutorFactory<T> + 'static,
491{
492 fn name(&self) -> &str {
493 T::name()
494 }
495
496 fn as_any(&self) -> &dyn Any {
497 self
498 }
499
500 async fn range(&self, req: RangeRequest) -> Result<RangeResponse> {
501 let client = self.executor_factory.default_executor().await?;
502 let mut query_executor = ExecutorImpl::Default(client);
503 self.range_with_query_executor(&mut query_executor, req)
504 .await
505 }
506
507 async fn put(&self, req: PutRequest) -> Result<PutResponse> {
508 let client = self.executor_factory.default_executor().await?;
509 let mut query_executor = ExecutorImpl::Default(client);
510 self.put_with_query_executor(&mut query_executor, req).await
511 }
512
513 async fn batch_put(&self, req: BatchPutRequest) -> Result<BatchPutResponse> {
514 let client = self.executor_factory.default_executor().await?;
515 let mut query_executor = ExecutorImpl::Default(client);
516 self.batch_put_with_query_executor(&mut query_executor, req)
517 .await
518 }
519
520 async fn batch_get(&self, req: BatchGetRequest) -> Result<BatchGetResponse> {
521 let client = self.executor_factory.default_executor().await?;
522 let mut query_executor = ExecutorImpl::Default(client);
523 self.batch_get_with_query_executor(&mut query_executor, req)
524 .await
525 }
526
527 async fn delete_range(&self, req: DeleteRangeRequest) -> Result<DeleteRangeResponse> {
528 let client = self.executor_factory.default_executor().await?;
529 let mut query_executor = ExecutorImpl::Default(client);
530 self.delete_range_with_query_executor(&mut query_executor, req)
531 .await
532 }
533
534 async fn batch_delete(&self, req: BatchDeleteRequest) -> Result<BatchDeleteResponse> {
535 let client = self.executor_factory.default_executor().await?;
536 let mut query_executor = ExecutorImpl::Default(client);
537 self.batch_delete_with_query_executor(&mut query_executor, req)
538 .await
539 }
540}
541
542#[async_trait::async_trait]
543impl<T, S, R> TxnService for RdsStore<T, S, R>
544where
545 Self: KvQueryExecutor<T> + Send + Sync,
546 T: Executor + 'static,
547 S: ExecutorFactory<T> + 'static,
548{
549 type Error = Error;
550
551 async fn txn(&self, txn: KvTxn) -> Result<KvTxnResponse> {
552 let _timer = METRIC_META_TXN_REQUEST
553 .with_label_values(&[T::name(), "txn"])
554 .start_timer();
555
556 let mut backoff = ExponentialBuilder::default()
557 .with_min_delay(Duration::from_millis(10))
558 .with_max_delay(Duration::from_millis(200))
559 .with_max_times(self.txn_retry_count)
560 .build();
561
562 loop {
563 match self.txn_inner(&txn).await {
564 Ok(res) => return Ok(res),
565 Err(e) => {
566 if e.is_serialization_error() {
567 let d = backoff.next();
568 if let Some(d) = d {
569 tokio::time::sleep(d).await;
570 continue;
571 }
572 break;
573 } else {
574 return Err(e);
575 }
576 }
577 }
578 }
579
580 RdsTransactionRetryFailedSnafu {}.fail()
581 }
582
583 fn max_txn_ops(&self) -> usize {
584 self.max_txn_ops
585 }
586}
587
588fn check_txn_ops(txn_ops: &[TxnOp]) -> Result<bool> {
590 if txn_ops.is_empty() {
591 return Ok(false);
592 }
593 let same = txn_ops.windows(2).all(|a| {
594 matches!(
595 (&a[0], &a[1]),
596 (TxnOp::Put(_, _), TxnOp::Put(_, _))
597 | (TxnOp::Get(_), TxnOp::Get(_))
598 | (TxnOp::Delete(_), TxnOp::Delete(_))
599 )
600 });
601 Ok(same)
602}
603
/// Evaluates `$result` (a `Result`-valued expression) and records its execution time in the
/// `RDS_SQL_EXECUTE_ELAPSED` metric, then yields the result unchanged.
///
/// The metric labels are `$label_store`, the outcome (`"success"` or `"error"`),
/// `$label_op` and `$label_type`. The timer starts before `$result` is evaluated, so pass
/// the whole expression (including any `.await`) to time its execution.
///
/// Elapsed time is reported in milliseconds as a fractional value. Using `as_millis()` would
/// truncate sub-millisecond statements to 0.
#[macro_export]
macro_rules! record_rds_sql_execute_elapsed {
    ($result:expr, $label_store:expr, $label_op:expr, $label_type:expr) => {{
        let timer = std::time::Instant::now();
        $result
            .inspect(|_| {
                $crate::metrics::RDS_SQL_EXECUTE_ELAPSED
                    .with_label_values(&[$label_store, "success", $label_op, $label_type])
                    .observe(timer.elapsed().as_secs_f64() * 1000.0);
            })
            .inspect_err(|_| {
                $crate::metrics::RDS_SQL_EXECUTE_ELAPSED
                    .with_label_values(&[$label_store, "error", $label_op, $label_type])
                    .observe(timer.elapsed().as_secs_f64() * 1000.0);
            })
    }};
}