1use std::fs::File;
16use std::io::BufReader;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
20use common_telemetry::debug;
21use deadpool_postgres::{Config, Pool, Runtime};
22use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
23use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
24use rustls::server::ParsedCertificate;
25use rustls::ClientConfig;
27use rustls::{DigitallySignedStruct, Error as TlsError, SignatureScheme};
28use rustls_pemfile::{certs, private_key};
29use snafu::ResultExt;
30use strum::AsRefStr;
31use tokio_postgres::types::ToSql;
32use tokio_postgres::{IsolationLevel, NoTls, Row};
33use tokio_postgres_rustls::MakeRustlsConnect;
34
35use crate::error::{
36 CreatePostgresPoolSnafu, GetPostgresConnectionSnafu, LoadTlsCertificateSnafu,
37 PostgresExecutionSnafu, PostgresTlsConfigSnafu, PostgresTransactionSnafu, Result,
38};
39use crate::kv_backend::rds::{
40 Executor, ExecutorFactory, ExecutorImpl, KvQueryExecutor, RdsStore, Transaction,
41 RDS_STORE_OP_BATCH_DELETE, RDS_STORE_OP_BATCH_GET, RDS_STORE_OP_BATCH_PUT,
42 RDS_STORE_OP_RANGE_DELETE, RDS_STORE_OP_RANGE_QUERY, RDS_STORE_TXN_RETRY_COUNT,
43};
44use crate::kv_backend::KvBackendRef;
45use crate::rpc::store::{
46 BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
47 BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, RangeRequest, RangeResponse,
48};
49use crate::rpc::KeyValue;
50
/// TLS policy for connections to the PostgreSQL metadata store.
///
/// How each mode is realised is decided by `create_postgres_tls_connector` and
/// `PgStore::with_url_and_tls` in this module.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum TlsMode {
    /// Plaintext connections only; no TLS connector is built.
    Disable,
    /// Use TLS without certificate verification; fall back to plaintext if the TLS
    /// connector cannot be built.
    #[default]
    Prefer,
    /// Use TLS without certificate verification.
    Require,
    /// Use TLS and verify the server certificate chain, but not its hostname.
    VerifyCa,
    /// Use TLS and verify both the certificate chain and the hostname.
    VerifyFull,
}

/// TLS settings for the PostgreSQL connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsOption {
    /// Which TLS policy to apply.
    pub mode: TlsMode,
    /// Path to the PEM client certificate chain (mutual TLS); empty when unused.
    pub cert_path: String,
    /// Path to the PEM client private key (mutual TLS); empty when unused.
    pub key_path: String,
    /// Path to a PEM CA bundle used by the verifying modes; empty when unused.
    pub ca_cert_path: String,
    /// NOTE(review): not read anywhere in this file; presumably toggles reloading of the
    /// certificate files on change — confirm with its consumer.
    pub watch: bool,
}

impl Default for TlsOption {
    /// `Prefer` mode (the `TlsMode` default) with no certificate paths and watching off.
    fn default() -> Self {
        Self {
            mode: TlsMode::default(),
            cert_path: Default::default(),
            key_path: Default::default(),
            ca_cert_path: Default::default(),
            watch: bool::default(),
        }
    }
}
85
/// Store label passed to `record_rds_sql_execute_elapsed!` for every Postgres query.
const PG_STORE_NAME: &str = "pg_store";

/// A connection checked out of the deadpool Postgres pool; the non-transactional executor.
pub struct PgClient(deadpool::managed::Object<deadpool_postgres::Manager>);
/// An open transaction borrowed mutably from a [`PgClient`].
pub struct PgTxnClient<'a>(deadpool_postgres::Transaction<'a>);
90
91fn key_value_from_row(r: Row) -> KeyValue {
93 KeyValue {
94 key: r.get(0),
95 value: r.get(1),
96 }
97}
98
/// A single zero byte used as a sentinel in range requests.
///
/// `range_template` maps `range_end == EMPTY` to "every key `>= key`", and
/// `key == range_end == EMPTY` to "every key".
const EMPTY: &[u8] = &[0];
100
101#[derive(Debug, Clone, Copy, AsRefStr)]
103enum RangeTemplateType {
104 Point,
105 Range,
106 Full,
107 LeftBounded,
108 Prefix,
109}
110
111impl RangeTemplateType {
113 fn build_params(&self, mut key: Vec<u8>, range_end: Vec<u8>) -> Vec<Vec<u8>> {
116 match self {
117 RangeTemplateType::Point => vec![key],
118 RangeTemplateType::Range => vec![key, range_end],
119 RangeTemplateType::Full => vec![],
120 RangeTemplateType::LeftBounded => vec![key],
121 RangeTemplateType::Prefix => {
122 key.push(b'%');
123 vec![key]
124 }
125 }
126 }
127}
128
129#[derive(Debug, Clone)]
131struct RangeTemplate {
132 point: String,
133 range: String,
134 full: String,
135 left_bounded: String,
136 prefix: String,
137}
138
139impl RangeTemplate {
140 fn get(&self, typ: RangeTemplateType) -> &str {
142 match typ {
143 RangeTemplateType::Point => &self.point,
144 RangeTemplateType::Range => &self.range,
145 RangeTemplateType::Full => &self.full,
146 RangeTemplateType::LeftBounded => &self.left_bounded,
147 RangeTemplateType::Prefix => &self.prefix,
148 }
149 }
150
151 fn with_limit(template: &str, limit: i64) -> String {
153 if limit == 0 {
154 return format!("{};", template);
155 }
156 format!("{} LIMIT {};", template, limit)
157 }
158}
159
/// Returns `true` when `[start, end)` is exactly the set of keys prefixed by `start`.
///
/// That holds when both slices have the same length, share every byte but the last, and the
/// last byte of `end` is the last byte of `start` plus one. Uses `checked_add`, so a trailing
/// `0xFF` in `start` never overflows: it panics in debug builds and wraps in release. Empty
/// slices return `false` instead of underflowing.
fn is_prefix_range(start: &[u8], end: &[u8]) -> bool {
    match (start.split_last(), end.split_last()) {
        (Some((&start_last, start_head)), Some((&end_last, end_head))) => {
            start_head == end_head && start_last.checked_add(1) == Some(end_last)
        }
        _ => false,
    }
}
171
172fn range_template(key: &[u8], range_end: &[u8]) -> RangeTemplateType {
174 match (key, range_end) {
175 (_, &[]) => RangeTemplateType::Point,
176 (EMPTY, EMPTY) => RangeTemplateType::Full,
177 (_, EMPTY) => RangeTemplateType::LeftBounded,
178 (start, end) => {
179 if is_prefix_range(start, end) {
180 RangeTemplateType::Prefix
181 } else {
182 RangeTemplateType::Range
183 }
184 }
185 }
186}
187
/// Produces the PostgreSQL positional placeholders `$from`, ..., `$to` (inclusive).
///
/// Returns an empty vector when `from > to`.
fn pg_generate_in_placeholders(from: usize, to: usize) -> Vec<String> {
    let mut placeholders = Vec::new();
    for index in from..=to {
        placeholders.push(format!("${index}"));
    }
    placeholders
}
192
193struct PgSqlTemplateFactory<'a> {
195 table_name: &'a str,
196}
197
198impl<'a> PgSqlTemplateFactory<'a> {
199 fn new(table_name: &'a str) -> Self {
201 Self { table_name }
202 }
203
204 fn build(&self) -> PgSqlTemplateSet {
206 let table_name = self.table_name;
207 PgSqlTemplateSet {
209 table_name: table_name.to_string(),
210 create_table_statement: format!(
211 "CREATE TABLE IF NOT EXISTS \"{table_name}\"(k bytea PRIMARY KEY, v bytea)",
212 ),
213 range_template: RangeTemplate {
214 point: format!("SELECT k, v FROM \"{table_name}\" WHERE k = $1"),
215 range: format!(
216 "SELECT k, v FROM \"{table_name}\" WHERE k >= $1 AND k < $2 ORDER BY k"
217 ),
218 full: format!("SELECT k, v FROM \"{table_name}\" ORDER BY k"),
219 left_bounded: format!("SELECT k, v FROM \"{table_name}\" WHERE k >= $1 ORDER BY k"),
220 prefix: format!("SELECT k, v FROM \"{table_name}\" WHERE k LIKE $1 ORDER BY k"),
221 },
222 delete_template: RangeTemplate {
223 point: format!("DELETE FROM \"{table_name}\" WHERE k = $1 RETURNING k,v;"),
224 range: format!(
225 "DELETE FROM \"{table_name}\" WHERE k >= $1 AND k < $2 RETURNING k,v;"
226 ),
227 full: format!("DELETE FROM \"{table_name}\" RETURNING k,v"),
228 left_bounded: format!("DELETE FROM \"{table_name}\" WHERE k >= $1 RETURNING k,v;"),
229 prefix: format!("DELETE FROM \"{table_name}\" WHERE k LIKE $1 RETURNING k,v;"),
230 },
231 }
232 }
233}
234
235#[derive(Debug, Clone)]
237pub struct PgSqlTemplateSet {
238 table_name: String,
239 create_table_statement: String,
240 range_template: RangeTemplate,
241 delete_template: RangeTemplate,
242}
243
244impl PgSqlTemplateSet {
245 fn generate_batch_get_query(&self, key_len: usize) -> String {
247 let table_name = &self.table_name;
248 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
249 format!(
250 "SELECT k, v FROM \"{table_name}\" WHERE k in ({});",
251 in_clause
252 )
253 }
254
255 fn generate_batch_delete_query(&self, key_len: usize) -> String {
257 let table_name = &self.table_name;
258 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
259 format!(
260 "DELETE FROM \"{table_name}\" WHERE k in ({}) RETURNING k,v;",
261 in_clause
262 )
263 }
264
265 fn generate_batch_upsert_query(&self, kv_len: usize) -> String {
267 let table_name = &self.table_name;
268 let in_placeholders: Vec<String> = (1..=kv_len).map(|i| format!("${}", i)).collect();
269 let in_clause = in_placeholders.join(", ");
270 let mut param_index = kv_len + 1;
271 let mut values_placeholders = Vec::new();
272 for _ in 0..kv_len {
273 values_placeholders.push(format!("(${0}, ${1})", param_index, param_index + 1));
274 param_index += 2;
275 }
276 let values_clause = values_placeholders.join(", ");
277
278 format!(
279 r#"
280 WITH prev AS (
281 SELECT k,v FROM "{table_name}" WHERE k IN ({in_clause})
282 ), update AS (
283 INSERT INTO "{table_name}" (k, v) VALUES
284 {values_clause}
285 ON CONFLICT (
286 k
287 ) DO UPDATE SET
288 v = excluded.v
289 )
290
291 SELECT k, v FROM prev;
292 "#
293 )
294 }
295}
296
297#[async_trait::async_trait]
298impl Executor for PgClient {
299 type Transaction<'a>
300 = PgTxnClient<'a>
301 where
302 Self: 'a;
303
304 fn name() -> &'static str {
305 "Postgres"
306 }
307
308 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
309 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
310 let stmt = self
311 .0
312 .prepare_cached(query)
313 .await
314 .context(PostgresExecutionSnafu { sql: query })?;
315 let rows = self
316 .0
317 .query(&stmt, ¶ms)
318 .await
319 .context(PostgresExecutionSnafu { sql: query })?;
320 Ok(rows.into_iter().map(key_value_from_row).collect())
321 }
322
323 async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>> {
324 let txn = self
325 .0
326 .build_transaction()
327 .isolation_level(IsolationLevel::Serializable)
328 .start()
329 .await
330 .context(PostgresTransactionSnafu {
331 operation: "begin".to_string(),
332 })?;
333 Ok(PgTxnClient(txn))
334 }
335}
336
337#[async_trait::async_trait]
338impl<'a> Transaction<'a> for PgTxnClient<'a> {
339 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
340 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
341 let stmt = self
342 .0
343 .prepare_cached(query)
344 .await
345 .context(PostgresExecutionSnafu { sql: query })?;
346 let rows = self
347 .0
348 .query(&stmt, ¶ms)
349 .await
350 .context(PostgresExecutionSnafu { sql: query })?;
351 Ok(rows.into_iter().map(key_value_from_row).collect())
352 }
353
354 async fn commit(self) -> Result<()> {
355 self.0.commit().await.context(PostgresTransactionSnafu {
356 operation: "commit",
357 })?;
358 Ok(())
359 }
360}
361
362pub struct PgExecutorFactory {
363 pool: Pool,
364}
365
366impl PgExecutorFactory {
367 async fn client(&self) -> Result<PgClient> {
368 match self.pool.get().await {
369 Ok(client) => Ok(PgClient(client)),
370 Err(e) => GetPostgresConnectionSnafu {
371 reason: e.to_string(),
372 }
373 .fail(),
374 }
375 }
376}
377
378#[async_trait::async_trait]
379impl ExecutorFactory<PgClient> for PgExecutorFactory {
380 async fn default_executor(&self) -> Result<PgClient> {
381 self.client().await
382 }
383
384 async fn txn_executor<'a>(
385 &self,
386 default_executor: &'a mut PgClient,
387 ) -> Result<PgTxnClient<'a>> {
388 default_executor.txn_executor().await
389 }
390}
391
/// Postgres-backed kv store: the generic [`RdsStore`] instantiated with this module's client,
/// executor factory and SQL template set.
pub type PgStore = RdsStore<PgClient, PgExecutorFactory, PgSqlTemplateSet>;
395
396pub fn create_postgres_tls_connector(tls_config: &TlsOption) -> Result<MakeRustlsConnect> {
410 common_telemetry::info!(
411 "Creating PostgreSQL TLS connector with mode: {:?}",
412 tls_config.mode
413 );
414
415 let config_builder = match tls_config.mode {
416 TlsMode::Disable => {
417 return PostgresTlsConfigSnafu {
418 reason: "Cannot create TLS connector for Disable mode".to_string(),
419 }
420 .fail();
421 }
422 TlsMode::Prefer | TlsMode::Require => {
423 let verifier = Arc::new(AcceptAnyVerifier);
425 ClientConfig::builder()
426 .dangerous()
427 .with_custom_certificate_verifier(verifier)
428 }
429 TlsMode::VerifyCa => {
430 let ca_store = load_ca(&tls_config.ca_cert_path)?;
432 let verifier = Arc::new(NoHostnameVerification { roots: ca_store });
433 ClientConfig::builder()
434 .dangerous()
435 .with_custom_certificate_verifier(verifier)
436 }
437 TlsMode::VerifyFull => {
438 let ca_store = load_ca(&tls_config.ca_cert_path)?;
439 ClientConfig::builder().with_root_certificates(ca_store)
440 }
441 };
442
443 let client_config = if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
445 common_telemetry::info!("Loading client certificate for mutual TLS");
447 let cert_chain = load_certs(&tls_config.cert_path)?;
448 let private_key = load_private_key(&tls_config.key_path)?;
449
450 config_builder
451 .with_client_auth_cert(cert_chain, private_key)
452 .map_err(|e| {
453 PostgresTlsConfigSnafu {
454 reason: format!("Failed to configure client authentication: {}", e),
455 }
456 .build()
457 })?
458 } else {
459 common_telemetry::info!("No client certificate provided, skip client authentication");
460 config_builder.with_no_client_auth()
461 };
462
463 common_telemetry::info!("Successfully created PostgreSQL TLS connector");
464 Ok(MakeRustlsConnect::new(client_config))
465}
466
467#[derive(Debug)]
469struct AcceptAnyVerifier;
470
471impl ServerCertVerifier for AcceptAnyVerifier {
472 fn verify_server_cert(
473 &self,
474 _end_entity: &CertificateDer<'_>,
475 _intermediates: &[CertificateDer<'_>],
476 _server_name: &ServerName<'_>,
477 _ocsp_response: &[u8],
478 _now: UnixTime,
479 ) -> std::result::Result<ServerCertVerified, TlsError> {
480 common_telemetry::debug!(
481 "Accepting server certificate without verification (Prefer/Require mode)"
482 );
483 Ok(ServerCertVerified::assertion())
484 }
485
486 fn verify_tls12_signature(
487 &self,
488 _message: &[u8],
489 _cert: &CertificateDer<'_>,
490 _dss: &DigitallySignedStruct,
491 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
492 Ok(HandshakeSignatureValid::assertion())
494 }
495
496 fn verify_tls13_signature(
497 &self,
498 _message: &[u8],
499 _cert: &CertificateDer<'_>,
500 _dss: &DigitallySignedStruct,
501 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
502 Ok(HandshakeSignatureValid::assertion())
504 }
505
506 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
507 rustls::crypto::ring::default_provider()
509 .signature_verification_algorithms
510 .supported_schemes()
511 }
512}
513
514#[derive(Debug)]
517struct NoHostnameVerification {
518 roots: Arc<rustls::RootCertStore>,
519}
520
521impl ServerCertVerifier for NoHostnameVerification {
522 fn verify_server_cert(
523 &self,
524 end_entity: &CertificateDer<'_>,
525 intermediates: &[CertificateDer<'_>],
526 _server_name: &ServerName<'_>,
527 _ocsp_response: &[u8],
528 now: UnixTime,
529 ) -> std::result::Result<ServerCertVerified, TlsError> {
530 let cert = ParsedCertificate::try_from(end_entity)?;
531 rustls::client::verify_server_cert_signed_by_trust_anchor(
532 &cert,
533 &self.roots,
534 intermediates,
535 now,
536 rustls::crypto::ring::default_provider()
537 .signature_verification_algorithms
538 .all,
539 )?;
540
541 Ok(ServerCertVerified::assertion())
542 }
543
544 fn verify_tls12_signature(
545 &self,
546 message: &[u8],
547 cert: &CertificateDer<'_>,
548 dss: &DigitallySignedStruct,
549 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
550 rustls::crypto::verify_tls12_signature(
551 message,
552 cert,
553 dss,
554 &rustls::crypto::ring::default_provider().signature_verification_algorithms,
555 )
556 }
557
558 fn verify_tls13_signature(
559 &self,
560 message: &[u8],
561 cert: &CertificateDer<'_>,
562 dss: &DigitallySignedStruct,
563 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
564 rustls::crypto::verify_tls13_signature(
565 message,
566 cert,
567 dss,
568 &rustls::crypto::ring::default_provider().signature_verification_algorithms,
569 )
570 }
571
572 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
573 rustls::crypto::ring::default_provider()
575 .signature_verification_algorithms
576 .supported_schemes()
577 }
578}
579
580fn load_certs(path: &str) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
581 let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
582 let mut reader = BufReader::new(file);
583 let certs = certs(&mut reader)
584 .collect::<std::result::Result<Vec<_>, _>>()
585 .map_err(|e| {
586 PostgresTlsConfigSnafu {
587 reason: format!("Failed to parse certificates from {}: {}", path, e),
588 }
589 .build()
590 })?;
591 Ok(certs)
592}
593
594fn load_private_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
595 let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
596 let mut reader = BufReader::new(file);
597 let key = private_key(&mut reader)
598 .map_err(|e| {
599 PostgresTlsConfigSnafu {
600 reason: format!("Failed to parse private key from {}: {}", path, e),
601 }
602 .build()
603 })?
604 .ok_or_else(|| {
605 PostgresTlsConfigSnafu {
606 reason: format!("No private key found in {}", path),
607 }
608 .build()
609 })?;
610 Ok(key)
611}
612
613fn load_ca(path: &str) -> Result<Arc<rustls::RootCertStore>> {
614 let mut root_store = rustls::RootCertStore::empty();
615
616 match rustls_native_certs::load_native_certs() {
618 Ok(certs) => {
619 let num_certs = certs.len();
620 for cert in certs {
621 if let Err(e) = root_store.add(cert) {
622 return PostgresTlsConfigSnafu {
623 reason: format!("Failed to add root certificate: {}", e),
624 }
625 .fail();
626 }
627 }
628 common_telemetry::info!("Loaded {num_certs} system root certificates successfully");
629 }
630 Err(e) => {
631 return PostgresTlsConfigSnafu {
632 reason: format!("Failed to load system root certificates: {}", e),
633 }
634 .fail();
635 }
636 }
637
638 if !path.is_empty() {
640 let ca_certs = load_certs(path)?;
641 for cert in ca_certs {
642 if let Err(e) = root_store.add(cert) {
643 return PostgresTlsConfigSnafu {
644 reason: format!("Failed to add custom CA certificate: {}", e),
645 }
646 .fail();
647 }
648 }
649 common_telemetry::info!("Added custom CA certificate from {}", path);
650 }
651
652 Ok(Arc::new(root_store))
653}
654
655#[async_trait::async_trait]
656impl KvQueryExecutor<PgClient> for PgStore {
657 async fn range_with_query_executor(
658 &self,
659 query_executor: &mut ExecutorImpl<'_, PgClient>,
660 req: RangeRequest,
661 ) -> Result<RangeResponse> {
662 let template_type = range_template(&req.key, &req.range_end);
663 let template = self.sql_template_set.range_template.get(template_type);
664 let params = template_type.build_params(req.key, req.range_end);
665 let params_ref = params.iter().collect::<Vec<_>>();
666 let query =
668 RangeTemplate::with_limit(template, if req.limit == 0 { 0 } else { req.limit + 1 });
669 let limit = req.limit as usize;
670 debug!("query: {:?}, params: {:?}", query, params);
671 let mut kvs = crate::record_rds_sql_execute_elapsed!(
672 query_executor.query(&query, ¶ms_ref).await,
673 PG_STORE_NAME,
674 RDS_STORE_OP_RANGE_QUERY,
675 template_type.as_ref()
676 )?;
677
678 if req.keys_only {
679 kvs.iter_mut().for_each(|kv| kv.value = vec![]);
680 }
681 if limit == 0 || kvs.len() <= limit {
683 return Ok(RangeResponse { kvs, more: false });
684 }
685 let removed = kvs.pop();
687 debug_assert!(removed.is_some());
688 Ok(RangeResponse { kvs, more: true })
689 }
690
691 async fn batch_put_with_query_executor(
692 &self,
693 query_executor: &mut ExecutorImpl<'_, PgClient>,
694 req: BatchPutRequest,
695 ) -> Result<BatchPutResponse> {
696 let mut in_params = Vec::with_capacity(req.kvs.len() * 3);
697 let mut values_params = Vec::with_capacity(req.kvs.len() * 2);
698
699 for kv in &req.kvs {
700 let processed_key = &kv.key;
701 in_params.push(processed_key);
702
703 let processed_value = &kv.value;
704 values_params.push(processed_key);
705 values_params.push(processed_value);
706 }
707 in_params.extend(values_params);
708 let params = in_params.iter().map(|x| x as _).collect::<Vec<_>>();
709 let query = self
710 .sql_template_set
711 .generate_batch_upsert_query(req.kvs.len());
712
713 let kvs = crate::record_rds_sql_execute_elapsed!(
714 query_executor.query(&query, ¶ms).await,
715 PG_STORE_NAME,
716 RDS_STORE_OP_BATCH_PUT,
717 ""
718 )?;
719 if req.prev_kv {
720 Ok(BatchPutResponse { prev_kvs: kvs })
721 } else {
722 Ok(BatchPutResponse::default())
723 }
724 }
725
726 async fn batch_get_with_query_executor(
728 &self,
729 query_executor: &mut ExecutorImpl<'_, PgClient>,
730 req: BatchGetRequest,
731 ) -> Result<BatchGetResponse> {
732 if req.keys.is_empty() {
733 return Ok(BatchGetResponse { kvs: vec![] });
734 }
735 let query = self
736 .sql_template_set
737 .generate_batch_get_query(req.keys.len());
738 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
739 let kvs = crate::record_rds_sql_execute_elapsed!(
740 query_executor.query(&query, ¶ms).await,
741 PG_STORE_NAME,
742 RDS_STORE_OP_BATCH_GET,
743 ""
744 )?;
745 Ok(BatchGetResponse { kvs })
746 }
747
748 async fn delete_range_with_query_executor(
749 &self,
750 query_executor: &mut ExecutorImpl<'_, PgClient>,
751 req: DeleteRangeRequest,
752 ) -> Result<DeleteRangeResponse> {
753 let template_type = range_template(&req.key, &req.range_end);
754 let template = self.sql_template_set.delete_template.get(template_type);
755 let params = template_type.build_params(req.key, req.range_end);
756 let params_ref = params.iter().map(|x| x as _).collect::<Vec<_>>();
757 let kvs = crate::record_rds_sql_execute_elapsed!(
758 query_executor.query(template, ¶ms_ref).await,
759 PG_STORE_NAME,
760 RDS_STORE_OP_RANGE_DELETE,
761 template_type.as_ref()
762 )?;
763 let mut resp = DeleteRangeResponse::new(kvs.len() as i64);
764 if req.prev_kv {
765 resp.with_prev_kvs(kvs);
766 }
767 Ok(resp)
768 }
769
770 async fn batch_delete_with_query_executor(
771 &self,
772 query_executor: &mut ExecutorImpl<'_, PgClient>,
773 req: BatchDeleteRequest,
774 ) -> Result<BatchDeleteResponse> {
775 if req.keys.is_empty() {
776 return Ok(BatchDeleteResponse::default());
777 }
778 let query = self
779 .sql_template_set
780 .generate_batch_delete_query(req.keys.len());
781 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
782
783 let kvs = crate::record_rds_sql_execute_elapsed!(
784 query_executor.query(&query, ¶ms).await,
785 PG_STORE_NAME,
786 RDS_STORE_OP_BATCH_DELETE,
787 ""
788 )?;
789 if req.prev_kv {
790 Ok(BatchDeleteResponse { prev_kvs: kvs })
791 } else {
792 Ok(BatchDeleteResponse::default())
793 }
794 }
795}
796
797impl PgStore {
798 pub async fn with_url_and_tls(
807 url: &str,
808 table_name: &str,
809 max_txn_ops: usize,
810 tls_config: Option<TlsOption>,
811 ) -> Result<KvBackendRef> {
812 let mut cfg = Config::new();
813 cfg.url = Some(url.to_string());
814
815 let pool = match tls_config {
816 Some(tls_config) if tls_config.mode != TlsMode::Disable => {
817 match create_postgres_tls_connector(&tls_config) {
818 Ok(tls_connector) => cfg
819 .create_pool(Some(Runtime::Tokio1), tls_connector)
820 .context(CreatePostgresPoolSnafu)?,
821 Err(e) => {
822 if tls_config.mode == TlsMode::Prefer {
823 common_telemetry::info!("Failed to create TLS connector, falling back to insecure connection");
825 cfg.create_pool(Some(Runtime::Tokio1), NoTls)
826 .context(CreatePostgresPoolSnafu)?
827 } else {
828 return Err(e);
829 }
830 }
831 }
832 }
833 _ => cfg
834 .create_pool(Some(Runtime::Tokio1), NoTls)
835 .context(CreatePostgresPoolSnafu)?,
836 };
837
838 Self::with_pg_pool(pool, table_name, max_txn_ops).await
839 }
840
841 pub async fn with_url(url: &str, table_name: &str, max_txn_ops: usize) -> Result<KvBackendRef> {
843 Self::with_url_and_tls(url, table_name, max_txn_ops, None).await
844 }
845
846 pub async fn with_pg_pool(
848 pool: Pool,
849 table_name: &str,
850 max_txn_ops: usize,
851 ) -> Result<KvBackendRef> {
852 let client = match pool.get().await {
856 Ok(client) => client,
857 Err(e) => {
858 return GetPostgresConnectionSnafu {
859 reason: e.to_string(),
860 }
861 .fail();
862 }
863 };
864 let template_factory = PgSqlTemplateFactory::new(table_name);
865 let sql_template_set = template_factory.build();
866 client
867 .execute(&sql_template_set.create_table_statement, &[])
868 .await
869 .with_context(|_| PostgresExecutionSnafu {
870 sql: sql_template_set.create_table_statement.to_string(),
871 })?;
872 Ok(Arc::new(Self {
873 max_txn_ops,
874 sql_template_set,
875 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
876 executor_factory: PgExecutorFactory { pool },
877 _phantom: PhantomData,
878 }))
879 }
880}
881
// Integration tests against a real PostgreSQL instance. Each test creates (if needed) its
// own table via `build_pg_kv_backend`, which returns `None` when `GT_POSTGRES_ENDPOINTS` is
// unset. NOTE(review): `maybe_skip_postgres_integration_test!` presumably returns early in
// that case so the `.unwrap()` below is never reached — confirm in the macro definition.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kv_backend::test::{
        prepare_kv_with_prefix, test_kv_batch_delete_with_prefix, test_kv_batch_get_with_prefix,
        test_kv_compare_and_put_with_prefix, test_kv_delete_range_with_prefix,
        test_kv_put_with_prefix, test_kv_range_2_with_prefix, test_kv_range_with_prefix,
        test_simple_kv_range, test_txn_compare_equal, test_txn_compare_greater,
        test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
        text_txn_multi_compare_op, unprepare_kv,
    };
    use crate::maybe_skip_postgres_integration_test;

    /// Builds a plaintext `PgStore` on `table_name` using `GT_POSTGRES_ENDPOINTS` as the URL,
    /// creating the table first. Returns `None` when the env var is empty or unset.
    /// Panics on any pool or SQL error (test-only helper).
    async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
        let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
        if endpoints.is_empty() {
            return None;
        }

        let mut cfg = Config::new();
        cfg.url = Some(endpoints);
        let pool = cfg
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .context(CreatePostgresPoolSnafu)
            .unwrap();
        let client = pool.get().await.unwrap();
        let template_factory = PgSqlTemplateFactory::new(table_name);
        let sql_templates = template_factory.build();
        client
            .execute(&sql_templates.create_table_statement, &[])
            .await
            .context(PostgresExecutionSnafu {
                sql: sql_templates.create_table_statement.to_string(),
            })
            .unwrap();
        Some(PgStore {
            max_txn_ops: 128,
            sql_template_set: sql_templates,
            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
            executor_factory: PgExecutorFactory { pool },
            _phantom: PhantomData,
        })
    }

    #[tokio::test]
    async fn test_pg_put() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("put_test").await.unwrap();
        let prefix = b"put/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_put_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("range_test").await.unwrap();
        let prefix = b"range/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_range_2() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("range2_test").await.unwrap();
        let prefix = b"range2/";
        test_kv_range_2_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    // Uses an empty prefix, so it exercises whole-table range queries.
    #[tokio::test]
    async fn test_pg_all_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("simple_range_test").await.unwrap();
        let prefix = b"";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_simple_kv_range(&kv_backend).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_batch_get() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_get_test").await.unwrap();
        let prefix = b"batch_get/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_batch_get_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    // NOTE(review): despite its name this runs `test_kv_delete_range_with_prefix`, the same
    // check as `test_pg_delete_range`; batch delete is covered by
    // `test_pg_batch_delete_with_prefix`. Confirm whether this is intentional.
    #[tokio::test]
    async fn test_pg_batch_delete() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_delete_test").await.unwrap();
        let prefix = b"batch_delete/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_batch_delete_with_prefix() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_delete_with_prefix_test")
            .await
            .unwrap();
        let prefix = b"batch_delete/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_delete_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("delete_range_test").await.unwrap();
        let prefix = b"delete_range/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    // The shared compare-and-put helper takes an `Arc`, hence the wrapping.
    #[tokio::test]
    async fn test_pg_compare_and_put() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("compare_and_put_test").await.unwrap();
        let prefix = b"compare_and_put/";
        let kv_backend = Arc::new(kv_backend);
        test_kv_compare_and_put_with_prefix(kv_backend.clone(), prefix.to_vec()).await;
    }

    // Runs all shared transaction-compare scenarios against one table, in order.
    #[tokio::test]
    async fn test_pg_txn() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("txn_test").await.unwrap();
        test_txn_one_compare_op(&kv_backend).await;
        text_txn_multi_compare_op(&kv_backend).await;
        test_txn_compare_equal(&kv_backend).await;
        test_txn_compare_greater(&kv_backend).await;
        test_txn_compare_less(&kv_backend).await;
        test_txn_compare_not_equal(&kv_backend).await;
    }
}