1use std::fs::File;
16use std::io::BufReader;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
20use common_telemetry::debug;
21use deadpool_postgres::{Config, Pool, Runtime};
22use rustls::ClientConfig;
24use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
25use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
26use rustls::server::ParsedCertificate;
27use rustls::{DigitallySignedStruct, Error as TlsError, SignatureScheme};
28use rustls_pemfile::{certs, private_key};
29use snafu::ResultExt;
30use strum::AsRefStr;
31use tokio_postgres::types::ToSql;
32use tokio_postgres::{IsolationLevel, NoTls, Row};
33use tokio_postgres_rustls::MakeRustlsConnect;
34
35use crate::error::{
36 CreatePostgresPoolSnafu, GetPostgresConnectionSnafu, LoadTlsCertificateSnafu,
37 PostgresExecutionSnafu, PostgresTlsConfigSnafu, PostgresTransactionSnafu, Result,
38};
39use crate::kv_backend::KvBackendRef;
40use crate::kv_backend::rds::{
41 Executor, ExecutorFactory, ExecutorImpl, KvQueryExecutor, RDS_STORE_OP_BATCH_DELETE,
42 RDS_STORE_OP_BATCH_GET, RDS_STORE_OP_BATCH_PUT, RDS_STORE_OP_RANGE_DELETE,
43 RDS_STORE_OP_RANGE_QUERY, RDS_STORE_TXN_RETRY_COUNT, RdsStore, Transaction,
44};
45use crate::rpc::KeyValue;
46use crate::rpc::store::{
47 BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
48 BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, RangeRequest, RangeResponse,
49};
50
/// TLS mode for PostgreSQL connections.
///
/// `Prefer` is the default. `Disable` means no TLS connector is created at all.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum TlsMode {
    /// Plaintext connection only.
    Disable,
    /// Use TLS without server authentication; may fall back to plaintext.
    #[default]
    Prefer,
    /// Use TLS without server authentication; never fall back.
    Require,
    /// Use TLS and verify the certificate chain, but not the hostname.
    VerifyCa,
    /// Use TLS and verify both the certificate chain and the hostname.
    VerifyFull,
}

/// TLS settings for PostgreSQL connections.
///
/// Empty path strings mean "not configured". The default is `Prefer` mode with
/// no paths set and `watch` disabled.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TlsOption {
    pub mode: TlsMode,
    pub cert_path: String,
    pub key_path: String,
    pub ca_cert_path: String,
    pub watch: bool,
}
85
/// Store name passed to `record_rds_sql_execute_elapsed!` for every query this backend issues.
const PG_STORE_NAME: &str = "pg_store";

/// A connection checked out from the deadpool PostgreSQL pool.
pub struct PgClient(deadpool::managed::Object<deadpool_postgres::Manager>);
/// An open transaction on a [`PgClient`] connection (started as `SERIALIZABLE` by
/// `PgClient::txn_executor`).
pub struct PgTxnClient<'a>(deadpool_postgres::Transaction<'a>);
90
91fn key_value_from_row(r: Row) -> KeyValue {
93 KeyValue {
94 key: r.get(0),
95 value: r.get(1),
96 }
97}
98
/// A single NUL byte, used as a sentinel by `range_template`.
///
/// A `range_end` equal to this selects every key `>= key`. When `key` also equals it,
/// every key is selected.
const EMPTY: &[u8] = &[0];

/// The shape of a range/delete request, used to pick a SQL template and build its parameters.
///
/// The variant name (via `AsRefStr`) is passed to `record_rds_sql_execute_elapsed!`.
#[derive(Debug, Clone, Copy, AsRefStr)]
enum RangeTemplateType {
    /// `range_end` is empty: exact match on `key`.
    Point,
    /// Half-open interval `[key, range_end)`.
    Range,
    /// `key` and `range_end` are both `EMPTY`: every key.
    Full,
    /// `range_end` is `EMPTY`: every key `>= key`.
    LeftBounded,
    /// `range_end` is `key` with its last byte incremented: every key starting with `key`.
    Prefix,
}
110
111impl RangeTemplateType {
113 fn build_params(&self, mut key: Vec<u8>, range_end: Vec<u8>) -> Vec<Vec<u8>> {
116 match self {
117 RangeTemplateType::Point => vec![key],
118 RangeTemplateType::Range => vec![key, range_end],
119 RangeTemplateType::Full => vec![],
120 RangeTemplateType::LeftBounded => vec![key],
121 RangeTemplateType::Prefix => {
122 key.push(b'%');
123 vec![key]
124 }
125 }
126 }
127}
128
129#[derive(Debug, Clone)]
131struct RangeTemplate {
132 point: String,
133 range: String,
134 full: String,
135 left_bounded: String,
136 prefix: String,
137}
138
139impl RangeTemplate {
140 fn get(&self, typ: RangeTemplateType) -> &str {
142 match typ {
143 RangeTemplateType::Point => &self.point,
144 RangeTemplateType::Range => &self.range,
145 RangeTemplateType::Full => &self.full,
146 RangeTemplateType::LeftBounded => &self.left_bounded,
147 RangeTemplateType::Prefix => &self.prefix,
148 }
149 }
150
151 fn with_limit(template: &str, limit: i64) -> String {
153 if limit == 0 {
154 return format!("{};", template);
155 }
156 format!("{} LIMIT {};", template, limit)
157 }
158}
159
/// Returns `true` when the half-open range `[start, end)` selects exactly the keys that
/// start with `start`.
///
/// That is the case when both slices have the same length, share every byte except the
/// last, and `end`'s last byte is `start`'s last byte plus one. Empty slices are never a
/// prefix range. A final byte of `0xFF` cannot be incremented, so it never qualifies
/// (checked arithmetic avoids a debug panic or a wrapped false positive).
fn is_prefix_range(start: &[u8], end: &[u8]) -> bool {
    if start.len() != end.len() {
        return false;
    }
    match (start.split_last(), end.split_last()) {
        (Some((start_last, start_head)), Some((end_last, end_head))) => {
            start_head == end_head && start_last.checked_add(1) == Some(*end_last)
        }
        _ => false,
    }
}
171
172fn range_template(key: &[u8], range_end: &[u8]) -> RangeTemplateType {
174 match (key, range_end) {
175 (_, &[]) => RangeTemplateType::Point,
176 (EMPTY, EMPTY) => RangeTemplateType::Full,
177 (_, EMPTY) => RangeTemplateType::LeftBounded,
178 (start, end) => {
179 if is_prefix_range(start, end) {
180 RangeTemplateType::Prefix
181 } else {
182 RangeTemplateType::Range
183 }
184 }
185 }
186}
187
/// Produces the PostgreSQL positional placeholders `$from ..= $to`.
///
/// Returns an empty vector when `from > to`.
fn pg_generate_in_placeholders(from: usize, to: usize) -> Vec<String> {
    let mut placeholders = Vec::new();
    for index in from..=to {
        placeholders.push(format!("${index}"));
    }
    placeholders
}
192
193struct PgSqlTemplateFactory<'a> {
195 schema_name: Option<&'a str>,
196 table_name: &'a str,
197}
198
199impl<'a> PgSqlTemplateFactory<'a> {
200 fn new(schema_name: Option<&'a str>, table_name: &'a str) -> Self {
202 Self {
203 schema_name,
204 table_name,
205 }
206 }
207
208 fn build(&self) -> PgSqlTemplateSet {
210 let table_ident = Self::format_table_ident(self.schema_name, self.table_name);
211 PgSqlTemplateSet {
213 table_ident: table_ident.clone(),
214 create_table_statement: format!(
216 "CREATE TABLE IF NOT EXISTS {table_ident}(k bytea PRIMARY KEY, v bytea)",
217 ),
218 range_template: RangeTemplate {
219 point: format!("SELECT k, v FROM {table_ident} WHERE k = $1"),
220 range: format!(
221 "SELECT k, v FROM {table_ident} WHERE k >= $1 AND k < $2 ORDER BY k"
222 ),
223 full: format!("SELECT k, v FROM {table_ident} ORDER BY k"),
224 left_bounded: format!("SELECT k, v FROM {table_ident} WHERE k >= $1 ORDER BY k"),
225 prefix: format!("SELECT k, v FROM {table_ident} WHERE k LIKE $1 ORDER BY k"),
226 },
227 delete_template: RangeTemplate {
228 point: format!("DELETE FROM {table_ident} WHERE k = $1 RETURNING k,v;"),
229 range: format!("DELETE FROM {table_ident} WHERE k >= $1 AND k < $2 RETURNING k,v;"),
230 full: format!("DELETE FROM {table_ident} RETURNING k,v"),
231 left_bounded: format!("DELETE FROM {table_ident} WHERE k >= $1 RETURNING k,v;"),
232 prefix: format!("DELETE FROM {table_ident} WHERE k LIKE $1 RETURNING k,v;"),
233 },
234 }
235 }
236
237 fn format_table_ident(schema_name: Option<&str>, table_name: &str) -> String {
239 match schema_name {
240 Some(s) if !s.is_empty() => format!("\"{}\".\"{}\"", s, table_name),
241 _ => format!("\"{}\"", table_name),
242 }
243 }
244}
245
246#[derive(Debug, Clone)]
248pub struct PgSqlTemplateSet {
249 table_ident: String,
250 create_table_statement: String,
251 range_template: RangeTemplate,
252 delete_template: RangeTemplate,
253}
254
255impl PgSqlTemplateSet {
256 fn generate_batch_get_query(&self, key_len: usize) -> String {
258 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
259 format!(
260 "SELECT k, v FROM {} WHERE k in ({});",
261 self.table_ident, in_clause
262 )
263 }
264
265 fn generate_batch_delete_query(&self, key_len: usize) -> String {
267 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
268 format!(
269 "DELETE FROM {} WHERE k in ({}) RETURNING k,v;",
270 self.table_ident, in_clause
271 )
272 }
273
274 fn generate_batch_upsert_query(&self, kv_len: usize) -> String {
276 let in_placeholders: Vec<String> = (1..=kv_len).map(|i| format!("${}", i)).collect();
277 let in_clause = in_placeholders.join(", ");
278 let mut param_index = kv_len + 1;
279 let mut values_placeholders = Vec::new();
280 for _ in 0..kv_len {
281 values_placeholders.push(format!("(${0}, ${1})", param_index, param_index + 1));
282 param_index += 2;
283 }
284 let values_clause = values_placeholders.join(", ");
285
286 format!(
287 r#"
288 WITH prev AS (
289 SELECT k,v FROM {table} WHERE k IN ({in_clause})
290 ), update AS (
291 INSERT INTO {table} (k, v) VALUES
292 {values_clause}
293 ON CONFLICT (
294 k
295 ) DO UPDATE SET
296 v = excluded.v
297 )
298
299 SELECT k, v FROM prev;
300 "#,
301 table = self.table_ident,
302 in_clause = in_clause,
303 values_clause = values_clause
304 )
305 }
306}
307
308#[async_trait::async_trait]
309impl Executor for PgClient {
310 type Transaction<'a>
311 = PgTxnClient<'a>
312 where
313 Self: 'a;
314
315 fn name() -> &'static str {
316 "Postgres"
317 }
318
319 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
320 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
321 let stmt = self
322 .0
323 .prepare_cached(query)
324 .await
325 .context(PostgresExecutionSnafu { sql: query })?;
326 let rows = self
327 .0
328 .query(&stmt, ¶ms)
329 .await
330 .context(PostgresExecutionSnafu { sql: query })?;
331 Ok(rows.into_iter().map(key_value_from_row).collect())
332 }
333
334 async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>> {
335 let txn = self
336 .0
337 .build_transaction()
338 .isolation_level(IsolationLevel::Serializable)
339 .start()
340 .await
341 .context(PostgresTransactionSnafu {
342 operation: "begin".to_string(),
343 })?;
344 Ok(PgTxnClient(txn))
345 }
346}
347
348#[async_trait::async_trait]
349impl<'a> Transaction<'a> for PgTxnClient<'a> {
350 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
351 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
352 let stmt = self
353 .0
354 .prepare_cached(query)
355 .await
356 .context(PostgresExecutionSnafu { sql: query })?;
357 let rows = self
358 .0
359 .query(&stmt, ¶ms)
360 .await
361 .context(PostgresExecutionSnafu { sql: query })?;
362 Ok(rows.into_iter().map(key_value_from_row).collect())
363 }
364
365 async fn commit(self) -> Result<()> {
366 self.0.commit().await.context(PostgresTransactionSnafu {
367 operation: "commit",
368 })?;
369 Ok(())
370 }
371}
372
373pub struct PgExecutorFactory {
374 pool: Pool,
375}
376
377impl PgExecutorFactory {
378 async fn client(&self) -> Result<PgClient> {
379 match self.pool.get().await {
380 Ok(client) => Ok(PgClient(client)),
381 Err(e) => GetPostgresConnectionSnafu {
382 reason: e.to_string(),
383 }
384 .fail(),
385 }
386 }
387}
388
389#[async_trait::async_trait]
390impl ExecutorFactory<PgClient> for PgExecutorFactory {
391 async fn default_executor(&self) -> Result<PgClient> {
392 self.client().await
393 }
394
395 async fn txn_executor<'a>(
396 &self,
397 default_executor: &'a mut PgClient,
398 ) -> Result<PgTxnClient<'a>> {
399 default_executor.txn_executor().await
400 }
401}
402
/// The generic RDS KV store specialised to PostgreSQL connections, executor factory and SQL templates.
pub type PgStore = RdsStore<PgClient, PgExecutorFactory, PgSqlTemplateSet>;
406
407pub fn create_postgres_tls_connector(tls_config: &TlsOption) -> Result<MakeRustlsConnect> {
421 common_telemetry::info!(
422 "Creating PostgreSQL TLS connector with mode: {:?}",
423 tls_config.mode
424 );
425
426 let config_builder = match tls_config.mode {
427 TlsMode::Disable => {
428 return PostgresTlsConfigSnafu {
429 reason: "Cannot create TLS connector for Disable mode".to_string(),
430 }
431 .fail();
432 }
433 TlsMode::Prefer | TlsMode::Require => {
434 let verifier = Arc::new(AcceptAnyVerifier);
436 ClientConfig::builder()
437 .dangerous()
438 .with_custom_certificate_verifier(verifier)
439 }
440 TlsMode::VerifyCa => {
441 let ca_store = load_ca(&tls_config.ca_cert_path)?;
443 let verifier = Arc::new(NoHostnameVerification { roots: ca_store });
444 ClientConfig::builder()
445 .dangerous()
446 .with_custom_certificate_verifier(verifier)
447 }
448 TlsMode::VerifyFull => {
449 let ca_store = load_ca(&tls_config.ca_cert_path)?;
450 ClientConfig::builder().with_root_certificates(ca_store)
451 }
452 };
453
454 let client_config = if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
456 common_telemetry::info!("Loading client certificate for mutual TLS");
458 let cert_chain = load_certs(&tls_config.cert_path)?;
459 let private_key = load_private_key(&tls_config.key_path)?;
460
461 config_builder
462 .with_client_auth_cert(cert_chain, private_key)
463 .map_err(|e| {
464 PostgresTlsConfigSnafu {
465 reason: format!("Failed to configure client authentication: {}", e),
466 }
467 .build()
468 })?
469 } else {
470 common_telemetry::info!("No client certificate provided, skip client authentication");
471 config_builder.with_no_client_auth()
472 };
473
474 common_telemetry::info!("Successfully created PostgreSQL TLS connector");
475 Ok(MakeRustlsConnect::new(client_config))
476}
477
478#[derive(Debug)]
480struct AcceptAnyVerifier;
481
482impl ServerCertVerifier for AcceptAnyVerifier {
483 fn verify_server_cert(
484 &self,
485 _end_entity: &CertificateDer<'_>,
486 _intermediates: &[CertificateDer<'_>],
487 _server_name: &ServerName<'_>,
488 _ocsp_response: &[u8],
489 _now: UnixTime,
490 ) -> std::result::Result<ServerCertVerified, TlsError> {
491 common_telemetry::debug!(
492 "Accepting server certificate without verification (Prefer/Require mode)"
493 );
494 Ok(ServerCertVerified::assertion())
495 }
496
497 fn verify_tls12_signature(
498 &self,
499 _message: &[u8],
500 _cert: &CertificateDer<'_>,
501 _dss: &DigitallySignedStruct,
502 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
503 Ok(HandshakeSignatureValid::assertion())
505 }
506
507 fn verify_tls13_signature(
508 &self,
509 _message: &[u8],
510 _cert: &CertificateDer<'_>,
511 _dss: &DigitallySignedStruct,
512 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
513 Ok(HandshakeSignatureValid::assertion())
515 }
516
517 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
518 rustls::crypto::ring::default_provider()
520 .signature_verification_algorithms
521 .supported_schemes()
522 }
523}
524
525#[derive(Debug)]
528struct NoHostnameVerification {
529 roots: Arc<rustls::RootCertStore>,
530}
531
532impl ServerCertVerifier for NoHostnameVerification {
533 fn verify_server_cert(
534 &self,
535 end_entity: &CertificateDer<'_>,
536 intermediates: &[CertificateDer<'_>],
537 _server_name: &ServerName<'_>,
538 _ocsp_response: &[u8],
539 now: UnixTime,
540 ) -> std::result::Result<ServerCertVerified, TlsError> {
541 let cert = ParsedCertificate::try_from(end_entity)?;
542 rustls::client::verify_server_cert_signed_by_trust_anchor(
543 &cert,
544 &self.roots,
545 intermediates,
546 now,
547 rustls::crypto::ring::default_provider()
548 .signature_verification_algorithms
549 .all,
550 )?;
551
552 Ok(ServerCertVerified::assertion())
553 }
554
555 fn verify_tls12_signature(
556 &self,
557 message: &[u8],
558 cert: &CertificateDer<'_>,
559 dss: &DigitallySignedStruct,
560 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
561 rustls::crypto::verify_tls12_signature(
562 message,
563 cert,
564 dss,
565 &rustls::crypto::ring::default_provider().signature_verification_algorithms,
566 )
567 }
568
569 fn verify_tls13_signature(
570 &self,
571 message: &[u8],
572 cert: &CertificateDer<'_>,
573 dss: &DigitallySignedStruct,
574 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
575 rustls::crypto::verify_tls13_signature(
576 message,
577 cert,
578 dss,
579 &rustls::crypto::ring::default_provider().signature_verification_algorithms,
580 )
581 }
582
583 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
584 rustls::crypto::ring::default_provider()
586 .signature_verification_algorithms
587 .supported_schemes()
588 }
589}
590
591fn load_certs(path: &str) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
592 let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
593 let mut reader = BufReader::new(file);
594 let certs = certs(&mut reader)
595 .collect::<std::result::Result<Vec<_>, _>>()
596 .map_err(|e| {
597 PostgresTlsConfigSnafu {
598 reason: format!("Failed to parse certificates from {}: {}", path, e),
599 }
600 .build()
601 })?;
602 Ok(certs)
603}
604
605fn load_private_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
606 let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
607 let mut reader = BufReader::new(file);
608 let key = private_key(&mut reader)
609 .map_err(|e| {
610 PostgresTlsConfigSnafu {
611 reason: format!("Failed to parse private key from {}: {}", path, e),
612 }
613 .build()
614 })?
615 .ok_or_else(|| {
616 PostgresTlsConfigSnafu {
617 reason: format!("No private key found in {}", path),
618 }
619 .build()
620 })?;
621 Ok(key)
622}
623
624fn load_ca(path: &str) -> Result<Arc<rustls::RootCertStore>> {
625 let mut root_store = rustls::RootCertStore::empty();
626
627 match rustls_native_certs::load_native_certs() {
629 Ok(certs) => {
630 let num_certs = certs.len();
631 for cert in certs {
632 if let Err(e) = root_store.add(cert) {
633 return PostgresTlsConfigSnafu {
634 reason: format!("Failed to add root certificate: {}", e),
635 }
636 .fail();
637 }
638 }
639 common_telemetry::info!("Loaded {num_certs} system root certificates successfully");
640 }
641 Err(e) => {
642 return PostgresTlsConfigSnafu {
643 reason: format!("Failed to load system root certificates: {}", e),
644 }
645 .fail();
646 }
647 }
648
649 if !path.is_empty() {
651 let ca_certs = load_certs(path)?;
652 for cert in ca_certs {
653 if let Err(e) = root_store.add(cert) {
654 return PostgresTlsConfigSnafu {
655 reason: format!("Failed to add custom CA certificate: {}", e),
656 }
657 .fail();
658 }
659 }
660 common_telemetry::info!("Added custom CA certificate from {}", path);
661 }
662
663 Ok(Arc::new(root_store))
664}
665
666#[async_trait::async_trait]
667impl KvQueryExecutor<PgClient> for PgStore {
668 async fn range_with_query_executor(
669 &self,
670 query_executor: &mut ExecutorImpl<'_, PgClient>,
671 req: RangeRequest,
672 ) -> Result<RangeResponse> {
673 let template_type = range_template(&req.key, &req.range_end);
674 let template = self.sql_template_set.range_template.get(template_type);
675 let params = template_type.build_params(req.key, req.range_end);
676 let params_ref = params.iter().collect::<Vec<_>>();
677 let query =
679 RangeTemplate::with_limit(template, if req.limit == 0 { 0 } else { req.limit + 1 });
680 let limit = req.limit as usize;
681 debug!("query: {:?}, params: {:?}", query, params);
682 let mut kvs = crate::record_rds_sql_execute_elapsed!(
683 query_executor.query(&query, ¶ms_ref).await,
684 PG_STORE_NAME,
685 RDS_STORE_OP_RANGE_QUERY,
686 template_type.as_ref()
687 )?;
688
689 if req.keys_only {
690 kvs.iter_mut().for_each(|kv| kv.value = vec![]);
691 }
692 if limit == 0 || kvs.len() <= limit {
694 return Ok(RangeResponse { kvs, more: false });
695 }
696 let removed = kvs.pop();
698 debug_assert!(removed.is_some());
699 Ok(RangeResponse { kvs, more: true })
700 }
701
702 async fn batch_put_with_query_executor(
703 &self,
704 query_executor: &mut ExecutorImpl<'_, PgClient>,
705 req: BatchPutRequest,
706 ) -> Result<BatchPutResponse> {
707 let mut in_params = Vec::with_capacity(req.kvs.len() * 3);
708 let mut values_params = Vec::with_capacity(req.kvs.len() * 2);
709
710 for kv in &req.kvs {
711 let processed_key = &kv.key;
712 in_params.push(processed_key);
713
714 let processed_value = &kv.value;
715 values_params.push(processed_key);
716 values_params.push(processed_value);
717 }
718 in_params.extend(values_params);
719 let params = in_params.iter().map(|x| x as _).collect::<Vec<_>>();
720 let query = self
721 .sql_template_set
722 .generate_batch_upsert_query(req.kvs.len());
723
724 let kvs = crate::record_rds_sql_execute_elapsed!(
725 query_executor.query(&query, ¶ms).await,
726 PG_STORE_NAME,
727 RDS_STORE_OP_BATCH_PUT,
728 ""
729 )?;
730 if req.prev_kv {
731 Ok(BatchPutResponse { prev_kvs: kvs })
732 } else {
733 Ok(BatchPutResponse::default())
734 }
735 }
736
737 async fn batch_get_with_query_executor(
739 &self,
740 query_executor: &mut ExecutorImpl<'_, PgClient>,
741 req: BatchGetRequest,
742 ) -> Result<BatchGetResponse> {
743 if req.keys.is_empty() {
744 return Ok(BatchGetResponse { kvs: vec![] });
745 }
746 let query = self
747 .sql_template_set
748 .generate_batch_get_query(req.keys.len());
749 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
750 let kvs = crate::record_rds_sql_execute_elapsed!(
751 query_executor.query(&query, ¶ms).await,
752 PG_STORE_NAME,
753 RDS_STORE_OP_BATCH_GET,
754 ""
755 )?;
756 Ok(BatchGetResponse { kvs })
757 }
758
759 async fn delete_range_with_query_executor(
760 &self,
761 query_executor: &mut ExecutorImpl<'_, PgClient>,
762 req: DeleteRangeRequest,
763 ) -> Result<DeleteRangeResponse> {
764 let template_type = range_template(&req.key, &req.range_end);
765 let template = self.sql_template_set.delete_template.get(template_type);
766 let params = template_type.build_params(req.key, req.range_end);
767 let params_ref = params.iter().map(|x| x as _).collect::<Vec<_>>();
768 let kvs = crate::record_rds_sql_execute_elapsed!(
769 query_executor.query(template, ¶ms_ref).await,
770 PG_STORE_NAME,
771 RDS_STORE_OP_RANGE_DELETE,
772 template_type.as_ref()
773 )?;
774 let mut resp = DeleteRangeResponse::new(kvs.len() as i64);
775 if req.prev_kv {
776 resp.with_prev_kvs(kvs);
777 }
778 Ok(resp)
779 }
780
781 async fn batch_delete_with_query_executor(
782 &self,
783 query_executor: &mut ExecutorImpl<'_, PgClient>,
784 req: BatchDeleteRequest,
785 ) -> Result<BatchDeleteResponse> {
786 if req.keys.is_empty() {
787 return Ok(BatchDeleteResponse::default());
788 }
789 let query = self
790 .sql_template_set
791 .generate_batch_delete_query(req.keys.len());
792 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
793
794 let kvs = crate::record_rds_sql_execute_elapsed!(
795 query_executor.query(&query, ¶ms).await,
796 PG_STORE_NAME,
797 RDS_STORE_OP_BATCH_DELETE,
798 ""
799 )?;
800 if req.prev_kv {
801 Ok(BatchDeleteResponse { prev_kvs: kvs })
802 } else {
803 Ok(BatchDeleteResponse::default())
804 }
805 }
806}
807
808impl PgStore {
809 pub async fn with_url_and_tls(
818 url: &str,
819 table_name: &str,
820 max_txn_ops: usize,
821 tls_config: Option<TlsOption>,
822 ) -> Result<KvBackendRef> {
823 let mut cfg = Config::new();
824 cfg.url = Some(url.to_string());
825
826 let pool = match tls_config {
827 Some(tls_config) if tls_config.mode != TlsMode::Disable => {
828 match create_postgres_tls_connector(&tls_config) {
829 Ok(tls_connector) => cfg
830 .create_pool(Some(Runtime::Tokio1), tls_connector)
831 .context(CreatePostgresPoolSnafu)?,
832 Err(e) => {
833 if tls_config.mode == TlsMode::Prefer {
834 common_telemetry::info!(
836 "Failed to create TLS connector, falling back to insecure connection"
837 );
838 cfg.create_pool(Some(Runtime::Tokio1), NoTls)
839 .context(CreatePostgresPoolSnafu)?
840 } else {
841 return Err(e);
842 }
843 }
844 }
845 }
846 _ => cfg
847 .create_pool(Some(Runtime::Tokio1), NoTls)
848 .context(CreatePostgresPoolSnafu)?,
849 };
850
851 Self::with_pg_pool(pool, None, table_name, max_txn_ops).await
852 }
853
854 pub async fn with_url(url: &str, table_name: &str, max_txn_ops: usize) -> Result<KvBackendRef> {
856 Self::with_url_and_tls(url, table_name, max_txn_ops, None).await
857 }
858
859 pub async fn with_pg_pool(
861 pool: Pool,
862 schema_name: Option<&str>,
863 table_name: &str,
864 max_txn_ops: usize,
865 ) -> Result<KvBackendRef> {
866 let client = match pool.get().await {
868 Ok(client) => client,
869 Err(e) => {
870 return GetPostgresConnectionSnafu {
871 reason: e.to_string(),
872 }
873 .fail();
874 }
875 };
876 let template_factory = PgSqlTemplateFactory::new(schema_name, table_name);
877 let sql_template_set = template_factory.build();
878 client
880 .execute(&sql_template_set.create_table_statement, &[])
881 .await
882 .with_context(|_| PostgresExecutionSnafu {
883 sql: sql_template_set.create_table_statement.to_string(),
884 })?;
885 Ok(Arc::new(Self {
886 max_txn_ops,
887 sql_template_set,
888 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
889 executor_factory: PgExecutorFactory { pool },
890 _phantom: PhantomData,
891 }))
892 }
893}
894
895#[cfg(test)]
896mod tests {
897 use super::*;
898 use crate::kv_backend::test::{
899 prepare_kv_with_prefix, test_kv_batch_delete_with_prefix, test_kv_batch_get_with_prefix,
900 test_kv_compare_and_put_with_prefix, test_kv_delete_range_with_prefix,
901 test_kv_put_with_prefix, test_kv_range_2_with_prefix, test_kv_range_with_prefix,
902 test_simple_kv_range, test_txn_compare_equal, test_txn_compare_greater,
903 test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
904 text_txn_multi_compare_op, unprepare_kv,
905 };
906 use crate::{maybe_skip_postgres_integration_test, maybe_skip_postgres15_integration_test};
907
908 async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
909 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
910 if endpoints.is_empty() {
911 return None;
912 }
913
914 let mut cfg = Config::new();
915 cfg.url = Some(endpoints);
916 let pool = cfg
917 .create_pool(Some(Runtime::Tokio1), NoTls)
918 .context(CreatePostgresPoolSnafu)
919 .unwrap();
920 let client = pool.get().await.unwrap();
921 let template_factory = PgSqlTemplateFactory::new(None, table_name);
923 let sql_templates = template_factory.build();
924 client
926 .execute(&sql_templates.create_table_statement, &[])
927 .await
928 .context(PostgresExecutionSnafu {
929 sql: sql_templates.create_table_statement.to_string(),
930 })
931 .unwrap();
932 Some(PgStore {
933 max_txn_ops: 128,
934 sql_template_set: sql_templates,
935 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
936 executor_factory: PgExecutorFactory { pool },
937 _phantom: PhantomData,
938 })
939 }
940
941 async fn build_pg15_pool() -> Option<Pool> {
942 let url = std::env::var("GT_POSTGRES15_ENDPOINTS").unwrap_or_default();
943 if url.is_empty() {
944 return None;
945 }
946 let mut cfg = Config::new();
947 cfg.url = Some(url);
948 let pool = cfg
949 .create_pool(Some(Runtime::Tokio1), NoTls)
950 .context(CreatePostgresPoolSnafu)
951 .ok()?;
952 Some(pool)
953 }
954
955 #[tokio::test]
956 async fn test_pg15_create_table_in_public_should_fail() {
957 maybe_skip_postgres15_integration_test!();
958 let Some(pool) = build_pg15_pool().await else {
959 return;
960 };
961 let res = PgStore::with_pg_pool(pool, None, "pg15_public_should_fail", 128).await;
962 assert!(
963 res.is_err(),
964 "creating table in public should fail for test_user"
965 );
966 }
967
968 #[tokio::test]
969 async fn test_pg15_create_table_in_test_schema_and_crud_should_succeed() {
970 maybe_skip_postgres15_integration_test!();
971 let Some(pool) = build_pg15_pool().await else {
972 return;
973 };
974 let schema_name = std::env::var("GT_POSTGRES15_SCHEMA").unwrap();
975 let client = pool.get().await.unwrap();
976 let factory = PgSqlTemplateFactory::new(Some(&schema_name), "pg15_ok");
977 let templates = factory.build();
978 client
979 .execute(&templates.create_table_statement, &[])
980 .await
981 .unwrap();
982 let kv = PgStore {
983 max_txn_ops: 128,
984 sql_template_set: templates,
985 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
986 executor_factory: PgExecutorFactory { pool },
987 _phantom: PhantomData,
988 };
989 let prefix = b"pg15_crud/";
990 prepare_kv_with_prefix(&kv, prefix.to_vec()).await;
991 test_kv_put_with_prefix(&kv, prefix.to_vec()).await;
992 test_kv_batch_get_with_prefix(&kv, prefix.to_vec()).await;
993 unprepare_kv(&kv, prefix).await;
994 }
995
996 #[tokio::test]
997 async fn test_pg_put() {
998 maybe_skip_postgres_integration_test!();
999 let kv_backend = build_pg_kv_backend("put_test").await.unwrap();
1000 let prefix = b"put/";
1001 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1002 test_kv_put_with_prefix(&kv_backend, prefix.to_vec()).await;
1003 unprepare_kv(&kv_backend, prefix).await;
1004 }
1005
1006 #[tokio::test]
1007 async fn test_pg_range() {
1008 maybe_skip_postgres_integration_test!();
1009 let kv_backend = build_pg_kv_backend("range_test").await.unwrap();
1010 let prefix = b"range/";
1011 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1012 test_kv_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1013 unprepare_kv(&kv_backend, prefix).await;
1014 }
1015
1016 #[tokio::test]
1017 async fn test_pg_range_2() {
1018 maybe_skip_postgres_integration_test!();
1019 let kv_backend = build_pg_kv_backend("range2_test").await.unwrap();
1020 let prefix = b"range2/";
1021 test_kv_range_2_with_prefix(&kv_backend, prefix.to_vec()).await;
1022 unprepare_kv(&kv_backend, prefix).await;
1023 }
1024
1025 #[tokio::test]
1026 async fn test_pg_all_range() {
1027 maybe_skip_postgres_integration_test!();
1028 let kv_backend = build_pg_kv_backend("simple_range_test").await.unwrap();
1029 let prefix = b"";
1030 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1031 test_simple_kv_range(&kv_backend).await;
1032 unprepare_kv(&kv_backend, prefix).await;
1033 }
1034
1035 #[tokio::test]
1036 async fn test_pg_batch_get() {
1037 maybe_skip_postgres_integration_test!();
1038 let kv_backend = build_pg_kv_backend("batch_get_test").await.unwrap();
1039 let prefix = b"batch_get/";
1040 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1041 test_kv_batch_get_with_prefix(&kv_backend, prefix.to_vec()).await;
1042 unprepare_kv(&kv_backend, prefix).await;
1043 }
1044
1045 #[tokio::test]
1046 async fn test_pg_batch_delete() {
1047 maybe_skip_postgres_integration_test!();
1048 let kv_backend = build_pg_kv_backend("batch_delete_test").await.unwrap();
1049 let prefix = b"batch_delete/";
1050 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1051 test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1052 unprepare_kv(&kv_backend, prefix).await;
1053 }
1054
1055 #[tokio::test]
1056 async fn test_pg_batch_delete_with_prefix() {
1057 maybe_skip_postgres_integration_test!();
1058 let kv_backend = build_pg_kv_backend("batch_delete_with_prefix_test")
1059 .await
1060 .unwrap();
1061 let prefix = b"batch_delete/";
1062 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1063 test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
1064 unprepare_kv(&kv_backend, prefix).await;
1065 }
1066
1067 #[tokio::test]
1068 async fn test_pg_delete_range() {
1069 maybe_skip_postgres_integration_test!();
1070 let kv_backend = build_pg_kv_backend("delete_range_test").await.unwrap();
1071 let prefix = b"delete_range/";
1072 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1073 test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1074 unprepare_kv(&kv_backend, prefix).await;
1075 }
1076
1077 #[tokio::test]
1078 async fn test_pg_compare_and_put() {
1079 maybe_skip_postgres_integration_test!();
1080 let kv_backend = build_pg_kv_backend("compare_and_put_test").await.unwrap();
1081 let prefix = b"compare_and_put/";
1082 let kv_backend = Arc::new(kv_backend);
1083 test_kv_compare_and_put_with_prefix(kv_backend.clone(), prefix.to_vec()).await;
1084 }
1085
1086 #[tokio::test]
1087 async fn test_pg_txn() {
1088 maybe_skip_postgres_integration_test!();
1089 let kv_backend = build_pg_kv_backend("txn_test").await.unwrap();
1090 test_txn_one_compare_op(&kv_backend).await;
1091 text_txn_multi_compare_op(&kv_backend).await;
1092 test_txn_compare_equal(&kv_backend).await;
1093 test_txn_compare_greater(&kv_backend).await;
1094 test_txn_compare_less(&kv_backend).await;
1095 test_txn_compare_not_equal(&kv_backend).await;
1096 }
1097
1098 #[test]
1099 fn test_pg_template_with_schema() {
1100 let factory = PgSqlTemplateFactory::new(Some("test_schema"), "greptime_metakv");
1101 let t = factory.build();
1102 assert!(
1103 t.create_table_statement
1104 .contains("\"test_schema\".\"greptime_metakv\"")
1105 );
1106 let upsert = t.generate_batch_upsert_query(1);
1107 assert!(upsert.contains("\"test_schema\".\"greptime_metakv\""));
1108 let get = t.generate_batch_get_query(1);
1109 assert!(get.contains("\"test_schema\".\"greptime_metakv\""));
1110 let del = t.generate_batch_delete_query(1);
1111 assert!(del.contains("\"test_schema\".\"greptime_metakv\""));
1112 }
1113
1114 #[test]
1115 fn test_format_table_ident() {
1116 let t = PgSqlTemplateFactory::format_table_ident(None, "test_table");
1117 assert_eq!(t, "\"test_table\"");
1118
1119 let t = PgSqlTemplateFactory::format_table_ident(Some("test_schema"), "test_table");
1120 assert_eq!(t, "\"test_schema\".\"test_table\"");
1121
1122 let t = PgSqlTemplateFactory::format_table_ident(Some(""), "test_table");
1123 assert_eq!(t, "\"test_table\"");
1124 }
1125}