1use std::fs::File;
16use std::io::BufReader;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
20use common_telemetry::debug;
21use deadpool_postgres::{Config, Pool, Runtime};
22use rustls::ClientConfig;
24use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
25use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
26use rustls::server::ParsedCertificate;
27use rustls::{DigitallySignedStruct, Error as TlsError, SignatureScheme};
28use rustls_pemfile::{certs, private_key};
29use snafu::ResultExt;
30use strum::AsRefStr;
31use tokio_postgres::types::ToSql;
32use tokio_postgres::{IsolationLevel, NoTls, Row};
33use tokio_postgres_rustls::MakeRustlsConnect;
34
35use crate::error::{
36 CreatePostgresPoolSnafu, GetPostgresConnectionSnafu, LoadTlsCertificateSnafu,
37 PostgresExecutionSnafu, PostgresTlsConfigSnafu, PostgresTransactionSnafu, Result,
38};
39use crate::kv_backend::KvBackendRef;
40use crate::kv_backend::rds::{
41 Executor, ExecutorFactory, ExecutorImpl, KvQueryExecutor, RDS_STORE_OP_BATCH_DELETE,
42 RDS_STORE_OP_BATCH_GET, RDS_STORE_OP_BATCH_PUT, RDS_STORE_OP_RANGE_DELETE,
43 RDS_STORE_OP_RANGE_QUERY, RDS_STORE_TXN_RETRY_COUNT, RdsStore, Transaction,
44 ensure_rustls_crypto_provider_installed,
45};
46use crate::rpc::KeyValue;
47use crate::rpc::store::{
48 BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
49 BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, RangeRequest, RangeResponse,
50};
51
/// How TLS is used for PostgreSQL connections.
///
/// The variant names follow libpq's `sslmode` values; see
/// `create_postgres_tls_connector` and `PgStore::with_url_and_tls` for how each
/// mode is applied in this crate.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum TlsMode {
    /// Connect without TLS.
    Disable,
    /// Default. TLS connector that does not verify the server certificate;
    /// falls back to a non-TLS pool if the connector cannot be built.
    #[default]
    Prefer,
    /// TLS connector that does not verify the server certificate.
    Require,
    /// Verify the server certificate chain against trusted CAs, but not the hostname.
    VerifyCa,
    /// Verify both the server certificate chain and the hostname.
    VerifyFull,
}

/// TLS settings for the PostgreSQL connection.
///
/// `Default` yields `TlsMode::Prefer`, empty paths and `watch == false`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct TlsOption {
    /// How TLS is negotiated and how the server certificate is checked.
    pub mode: TlsMode,
    /// PEM client certificate for mutual TLS; empty when not used.
    pub cert_path: String,
    /// PEM client private key for mutual TLS; empty when not used.
    pub key_path: String,
    /// PEM file with CA certificates for the verifying modes; may be empty.
    pub ca_cert_path: String,
    /// NOTE(review): not read anywhere in this file; presumably enables
    /// reloading certificates when the files change — confirm with callers.
    pub watch: bool,
}
86
/// Store label passed to `record_rds_sql_execute_elapsed!` for every query.
const PG_STORE_NAME: &str = "pg_store";

/// A connection checked out of the deadpool pool; the non-transactional executor.
pub struct PgClient(deadpool::managed::Object<deadpool_postgres::Manager>);
/// A transaction opened on (and borrowing) a [`PgClient`] connection.
pub struct PgTxnClient<'a>(deadpool_postgres::Transaction<'a>);
91
92fn key_value_from_row(r: Row) -> KeyValue {
94 KeyValue {
95 key: r.get(0),
96 value: r.get(1),
97 }
98}
99
/// The single-zero-byte key `\0`. As `range_end` it means "no upper bound"
/// (every key `>= key`); when both `key` and `range_end` are `\0`, the request
/// covers the whole table.
const EMPTY: &[u8] = &[0];
101
102#[derive(Debug, Clone, Copy, AsRefStr)]
104enum RangeTemplateType {
105 Point,
106 Range,
107 Full,
108 LeftBounded,
109 Prefix,
110}
111
112impl RangeTemplateType {
114 fn build_params(&self, mut key: Vec<u8>, range_end: Vec<u8>) -> Vec<Vec<u8>> {
117 match self {
118 RangeTemplateType::Point => vec![key],
119 RangeTemplateType::Range => vec![key, range_end],
120 RangeTemplateType::Full => vec![],
121 RangeTemplateType::LeftBounded => vec![key],
122 RangeTemplateType::Prefix => {
123 key.push(b'%');
124 vec![key]
125 }
126 }
127 }
128}
129
130#[derive(Debug, Clone)]
132struct RangeTemplate {
133 point: String,
134 range: String,
135 full: String,
136 left_bounded: String,
137 prefix: String,
138}
139
140impl RangeTemplate {
141 fn get(&self, typ: RangeTemplateType) -> &str {
143 match typ {
144 RangeTemplateType::Point => &self.point,
145 RangeTemplateType::Range => &self.range,
146 RangeTemplateType::Full => &self.full,
147 RangeTemplateType::LeftBounded => &self.left_bounded,
148 RangeTemplateType::Prefix => &self.prefix,
149 }
150 }
151
152 fn with_limit(template: &str, limit: i64) -> String {
154 if limit == 0 {
155 return format!("{};", template);
156 }
157 format!("{} LIMIT {};", template, limit)
158 }
159}
160
/// Returns `true` when `[start, end)` covers exactly the keys prefixed by `start`.
///
/// This holds when both keys have the same length, share every byte except the
/// last one, and `end`'s last byte is `start`'s last byte plus one.
///
/// Edge cases return `false` instead of panicking:
/// - empty keys;
/// - a `start` whose last byte is `0xff`, because that byte has no successor.
fn is_prefix_range(start: &[u8], end: &[u8]) -> bool {
    if start.len() != end.len() {
        return false;
    }
    match (start.split_last(), end.split_last()) {
        (Some((start_last, start_head)), Some((end_last, end_head))) => {
            // `checked_add` avoids u8 overflow (panic in debug, wrap in release).
            start_head == end_head && start_last.checked_add(1) == Some(*end_last)
        }
        _ => false,
    }
}
172
173fn range_template(key: &[u8], range_end: &[u8]) -> RangeTemplateType {
175 match (key, range_end) {
176 (_, &[]) => RangeTemplateType::Point,
177 (EMPTY, EMPTY) => RangeTemplateType::Full,
178 (_, EMPTY) => RangeTemplateType::LeftBounded,
179 (start, end) => {
180 if is_prefix_range(start, end) {
181 RangeTemplateType::Prefix
182 } else {
183 RangeTemplateType::Range
184 }
185 }
186 }
187}
188
/// Returns the positional placeholders `$from`, `$from + 1`, ..., `$to` (inclusive).
///
/// The result is empty when `from > to`.
fn pg_generate_in_placeholders(from: usize, to: usize) -> Vec<String> {
    let mut placeholders = Vec::new();
    for index in from..=to {
        placeholders.push(format!("${index}"));
    }
    placeholders
}
193
194struct PgSqlTemplateFactory<'a> {
196 schema_name: Option<&'a str>,
197 table_name: &'a str,
198}
199
200impl<'a> PgSqlTemplateFactory<'a> {
201 fn new(schema_name: Option<&'a str>, table_name: &'a str) -> Self {
203 Self {
204 schema_name,
205 table_name,
206 }
207 }
208
209 fn build(&self) -> PgSqlTemplateSet {
211 let table_ident = Self::format_table_ident(self.schema_name, self.table_name);
212 PgSqlTemplateSet {
214 table_ident: table_ident.clone(),
215 create_table_statement: format!(
217 "CREATE TABLE IF NOT EXISTS {table_ident}(k bytea PRIMARY KEY, v bytea)",
218 ),
219 range_template: RangeTemplate {
220 point: format!("SELECT k, v FROM {table_ident} WHERE k = $1"),
221 range: format!(
222 "SELECT k, v FROM {table_ident} WHERE k >= $1 AND k < $2 ORDER BY k"
223 ),
224 full: format!("SELECT k, v FROM {table_ident} ORDER BY k"),
225 left_bounded: format!("SELECT k, v FROM {table_ident} WHERE k >= $1 ORDER BY k"),
226 prefix: format!("SELECT k, v FROM {table_ident} WHERE k LIKE $1 ORDER BY k"),
227 },
228 delete_template: RangeTemplate {
229 point: format!("DELETE FROM {table_ident} WHERE k = $1 RETURNING k,v;"),
230 range: format!("DELETE FROM {table_ident} WHERE k >= $1 AND k < $2 RETURNING k,v;"),
231 full: format!("DELETE FROM {table_ident} RETURNING k,v"),
232 left_bounded: format!("DELETE FROM {table_ident} WHERE k >= $1 RETURNING k,v;"),
233 prefix: format!("DELETE FROM {table_ident} WHERE k LIKE $1 RETURNING k,v;"),
234 },
235 }
236 }
237
238 fn format_table_ident(schema_name: Option<&str>, table_name: &str) -> String {
240 match schema_name {
241 Some(s) if !s.is_empty() => format!("\"{}\".\"{}\"", s, table_name),
242 _ => format!("\"{}\"", table_name),
243 }
244 }
245}
246
247#[derive(Debug, Clone)]
249pub struct PgSqlTemplateSet {
250 table_ident: String,
251 create_table_statement: String,
252 range_template: RangeTemplate,
253 delete_template: RangeTemplate,
254}
255
256impl PgSqlTemplateSet {
257 fn generate_batch_get_query(&self, key_len: usize) -> String {
259 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
260 format!(
261 "SELECT k, v FROM {} WHERE k in ({});",
262 self.table_ident, in_clause
263 )
264 }
265
266 fn generate_batch_delete_query(&self, key_len: usize) -> String {
268 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
269 format!(
270 "DELETE FROM {} WHERE k in ({}) RETURNING k,v;",
271 self.table_ident, in_clause
272 )
273 }
274
275 fn generate_batch_upsert_query(&self, kv_len: usize) -> String {
277 let in_placeholders: Vec<String> = (1..=kv_len).map(|i| format!("${}", i)).collect();
278 let in_clause = in_placeholders.join(", ");
279 let mut param_index = kv_len + 1;
280 let mut values_placeholders = Vec::new();
281 for _ in 0..kv_len {
282 values_placeholders.push(format!("(${0}, ${1})", param_index, param_index + 1));
283 param_index += 2;
284 }
285 let values_clause = values_placeholders.join(", ");
286
287 format!(
288 r#"
289 WITH prev AS (
290 SELECT k,v FROM {table} WHERE k IN ({in_clause})
291 ), update AS (
292 INSERT INTO {table} (k, v) VALUES
293 {values_clause}
294 ON CONFLICT (
295 k
296 ) DO UPDATE SET
297 v = excluded.v
298 )
299
300 SELECT k, v FROM prev;
301 "#,
302 table = self.table_ident,
303 in_clause = in_clause,
304 values_clause = values_clause
305 )
306 }
307}
308
309#[async_trait::async_trait]
310impl Executor for PgClient {
311 type Transaction<'a>
312 = PgTxnClient<'a>
313 where
314 Self: 'a;
315
316 fn name() -> &'static str {
317 "Postgres"
318 }
319
320 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
321 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
322 let stmt = self
323 .0
324 .prepare_cached(query)
325 .await
326 .context(PostgresExecutionSnafu { sql: query })?;
327 let rows = self
328 .0
329 .query(&stmt, ¶ms)
330 .await
331 .context(PostgresExecutionSnafu { sql: query })?;
332 Ok(rows.into_iter().map(key_value_from_row).collect())
333 }
334
335 async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>> {
336 let txn = self
337 .0
338 .build_transaction()
339 .isolation_level(IsolationLevel::Serializable)
340 .start()
341 .await
342 .context(PostgresTransactionSnafu {
343 operation: "begin".to_string(),
344 })?;
345 Ok(PgTxnClient(txn))
346 }
347}
348
349#[async_trait::async_trait]
350impl<'a> Transaction<'a> for PgTxnClient<'a> {
351 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
352 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
353 let stmt = self
354 .0
355 .prepare_cached(query)
356 .await
357 .context(PostgresExecutionSnafu { sql: query })?;
358 let rows = self
359 .0
360 .query(&stmt, ¶ms)
361 .await
362 .context(PostgresExecutionSnafu { sql: query })?;
363 Ok(rows.into_iter().map(key_value_from_row).collect())
364 }
365
366 async fn commit(self) -> Result<()> {
367 self.0.commit().await.context(PostgresTransactionSnafu {
368 operation: "commit",
369 })?;
370 Ok(())
371 }
372}
373
374pub struct PgExecutorFactory {
375 pool: Pool,
376}
377
378impl PgExecutorFactory {
379 async fn client(&self) -> Result<PgClient> {
380 match self.pool.get().await {
381 Ok(client) => Ok(PgClient(client)),
382 Err(e) => GetPostgresConnectionSnafu {
383 reason: e.to_string(),
384 }
385 .fail(),
386 }
387 }
388}
389
390#[async_trait::async_trait]
391impl ExecutorFactory<PgClient> for PgExecutorFactory {
392 async fn default_executor(&self) -> Result<PgClient> {
393 self.client().await
394 }
395
396 async fn txn_executor<'a>(
397 &self,
398 default_executor: &'a mut PgClient,
399 ) -> Result<PgTxnClient<'a>> {
400 default_executor.txn_executor().await
401 }
402}
403
/// PostgreSQL-backed key-value store: the generic RDS store specialized with the
/// PostgreSQL executor, executor factory and SQL templates from this file.
pub type PgStore = RdsStore<PgClient, PgExecutorFactory, PgSqlTemplateSet>;
407
408pub fn create_postgres_tls_connector(tls_config: &TlsOption) -> Result<MakeRustlsConnect> {
422 common_telemetry::info!(
423 "Creating PostgreSQL TLS connector with mode: {:?}",
424 tls_config.mode
425 );
426 ensure_rustls_crypto_provider_installed()?;
427
428 let config_builder = match tls_config.mode {
429 TlsMode::Disable => {
430 return PostgresTlsConfigSnafu {
431 reason: "Cannot create TLS connector for Disable mode".to_string(),
432 }
433 .fail();
434 }
435 TlsMode::Prefer | TlsMode::Require => {
436 let verifier = Arc::new(AcceptAnyVerifier);
438 ClientConfig::builder()
439 .dangerous()
440 .with_custom_certificate_verifier(verifier)
441 }
442 TlsMode::VerifyCa => {
443 let ca_store = load_ca(&tls_config.ca_cert_path)?;
445 let verifier = Arc::new(NoHostnameVerification { roots: ca_store });
446 ClientConfig::builder()
447 .dangerous()
448 .with_custom_certificate_verifier(verifier)
449 }
450 TlsMode::VerifyFull => {
451 let ca_store = load_ca(&tls_config.ca_cert_path)?;
452 ClientConfig::builder().with_root_certificates(ca_store)
453 }
454 };
455
456 let client_config = if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
458 common_telemetry::info!("Loading client certificate for mutual TLS");
460 let cert_chain = load_certs(&tls_config.cert_path)?;
461 let private_key = load_private_key(&tls_config.key_path)?;
462
463 config_builder
464 .with_client_auth_cert(cert_chain, private_key)
465 .map_err(|e| {
466 PostgresTlsConfigSnafu {
467 reason: format!("Failed to configure client authentication: {}", e),
468 }
469 .build()
470 })?
471 } else {
472 common_telemetry::info!("No client certificate provided, skip client authentication");
473 config_builder.with_no_client_auth()
474 };
475
476 common_telemetry::info!("Successfully created PostgreSQL TLS connector");
477 Ok(MakeRustlsConnect::new(client_config))
478}
479
480#[derive(Debug)]
482struct AcceptAnyVerifier;
483
484impl ServerCertVerifier for AcceptAnyVerifier {
485 fn verify_server_cert(
486 &self,
487 _end_entity: &CertificateDer<'_>,
488 _intermediates: &[CertificateDer<'_>],
489 _server_name: &ServerName<'_>,
490 _ocsp_response: &[u8],
491 _now: UnixTime,
492 ) -> std::result::Result<ServerCertVerified, TlsError> {
493 common_telemetry::debug!(
494 "Accepting server certificate without verification (Prefer/Require mode)"
495 );
496 Ok(ServerCertVerified::assertion())
497 }
498
499 fn verify_tls12_signature(
500 &self,
501 _message: &[u8],
502 _cert: &CertificateDer<'_>,
503 _dss: &DigitallySignedStruct,
504 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
505 Ok(HandshakeSignatureValid::assertion())
507 }
508
509 fn verify_tls13_signature(
510 &self,
511 _message: &[u8],
512 _cert: &CertificateDer<'_>,
513 _dss: &DigitallySignedStruct,
514 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
515 Ok(HandshakeSignatureValid::assertion())
517 }
518
519 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
520 rustls::crypto::aws_lc_rs::default_provider()
522 .signature_verification_algorithms
523 .supported_schemes()
524 }
525}
526
527#[derive(Debug)]
530struct NoHostnameVerification {
531 roots: Arc<rustls::RootCertStore>,
532}
533
534impl ServerCertVerifier for NoHostnameVerification {
535 fn verify_server_cert(
536 &self,
537 end_entity: &CertificateDer<'_>,
538 intermediates: &[CertificateDer<'_>],
539 _server_name: &ServerName<'_>,
540 _ocsp_response: &[u8],
541 now: UnixTime,
542 ) -> std::result::Result<ServerCertVerified, TlsError> {
543 let cert = ParsedCertificate::try_from(end_entity)?;
544 rustls::client::verify_server_cert_signed_by_trust_anchor(
545 &cert,
546 &self.roots,
547 intermediates,
548 now,
549 rustls::crypto::aws_lc_rs::default_provider()
550 .signature_verification_algorithms
551 .all,
552 )?;
553
554 Ok(ServerCertVerified::assertion())
555 }
556
557 fn verify_tls12_signature(
558 &self,
559 message: &[u8],
560 cert: &CertificateDer<'_>,
561 dss: &DigitallySignedStruct,
562 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
563 rustls::crypto::verify_tls12_signature(
564 message,
565 cert,
566 dss,
567 &rustls::crypto::aws_lc_rs::default_provider().signature_verification_algorithms,
568 )
569 }
570
571 fn verify_tls13_signature(
572 &self,
573 message: &[u8],
574 cert: &CertificateDer<'_>,
575 dss: &DigitallySignedStruct,
576 ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
577 rustls::crypto::verify_tls13_signature(
578 message,
579 cert,
580 dss,
581 &rustls::crypto::aws_lc_rs::default_provider().signature_verification_algorithms,
582 )
583 }
584
585 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
586 rustls::crypto::aws_lc_rs::default_provider()
588 .signature_verification_algorithms
589 .supported_schemes()
590 }
591}
592
593fn load_certs(path: &str) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
594 let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
595 let mut reader = BufReader::new(file);
596 let certs = certs(&mut reader)
597 .collect::<std::result::Result<Vec<_>, _>>()
598 .map_err(|e| {
599 PostgresTlsConfigSnafu {
600 reason: format!("Failed to parse certificates from {}: {}", path, e),
601 }
602 .build()
603 })?;
604 Ok(certs)
605}
606
607fn load_private_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
608 let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
609 let mut reader = BufReader::new(file);
610 let key = private_key(&mut reader)
611 .map_err(|e| {
612 PostgresTlsConfigSnafu {
613 reason: format!("Failed to parse private key from {}: {}", path, e),
614 }
615 .build()
616 })?
617 .ok_or_else(|| {
618 PostgresTlsConfigSnafu {
619 reason: format!("No private key found in {}", path),
620 }
621 .build()
622 })?;
623 Ok(key)
624}
625
626fn load_ca(path: &str) -> Result<Arc<rustls::RootCertStore>> {
627 let mut root_store = rustls::RootCertStore::empty();
628
629 match rustls_native_certs::load_native_certs() {
631 Ok(certs) => {
632 let num_certs = certs.len();
633 for cert in certs {
634 if let Err(e) = root_store.add(cert) {
635 return PostgresTlsConfigSnafu {
636 reason: format!("Failed to add root certificate: {}", e),
637 }
638 .fail();
639 }
640 }
641 common_telemetry::info!("Loaded {num_certs} system root certificates successfully");
642 }
643 Err(e) => {
644 return PostgresTlsConfigSnafu {
645 reason: format!("Failed to load system root certificates: {}", e),
646 }
647 .fail();
648 }
649 }
650
651 if !path.is_empty() {
653 let ca_certs = load_certs(path)?;
654 for cert in ca_certs {
655 if let Err(e) = root_store.add(cert) {
656 return PostgresTlsConfigSnafu {
657 reason: format!("Failed to add custom CA certificate: {}", e),
658 }
659 .fail();
660 }
661 }
662 common_telemetry::info!("Added custom CA certificate from {}", path);
663 }
664
665 Ok(Arc::new(root_store))
666}
667
668#[async_trait::async_trait]
669impl KvQueryExecutor<PgClient> for PgStore {
670 async fn range_with_query_executor(
671 &self,
672 query_executor: &mut ExecutorImpl<'_, PgClient>,
673 req: RangeRequest,
674 ) -> Result<RangeResponse> {
675 let template_type = range_template(&req.key, &req.range_end);
676 let template = self.sql_template_set.range_template.get(template_type);
677 let params = template_type.build_params(req.key, req.range_end);
678 let params_ref = params.iter().collect::<Vec<_>>();
679 let query =
681 RangeTemplate::with_limit(template, if req.limit == 0 { 0 } else { req.limit + 1 });
682 let limit = req.limit as usize;
683 debug!("query: {:?}, params: {:?}", query, params);
684 let mut kvs = crate::record_rds_sql_execute_elapsed!(
685 query_executor.query(&query, ¶ms_ref).await,
686 PG_STORE_NAME,
687 RDS_STORE_OP_RANGE_QUERY,
688 template_type.as_ref()
689 )?;
690
691 if req.keys_only {
692 kvs.iter_mut().for_each(|kv| kv.value = vec![]);
693 }
694 if limit == 0 || kvs.len() <= limit {
696 return Ok(RangeResponse { kvs, more: false });
697 }
698 let removed = kvs.pop();
700 debug_assert!(removed.is_some());
701 Ok(RangeResponse { kvs, more: true })
702 }
703
704 async fn batch_put_with_query_executor(
705 &self,
706 query_executor: &mut ExecutorImpl<'_, PgClient>,
707 req: BatchPutRequest,
708 ) -> Result<BatchPutResponse> {
709 let mut in_params = Vec::with_capacity(req.kvs.len() * 3);
710 let mut values_params = Vec::with_capacity(req.kvs.len() * 2);
711
712 for kv in &req.kvs {
713 let processed_key = &kv.key;
714 in_params.push(processed_key);
715
716 let processed_value = &kv.value;
717 values_params.push(processed_key);
718 values_params.push(processed_value);
719 }
720 in_params.extend(values_params);
721 let params = in_params.iter().map(|x| x as _).collect::<Vec<_>>();
722 let query = self
723 .sql_template_set
724 .generate_batch_upsert_query(req.kvs.len());
725
726 let kvs = crate::record_rds_sql_execute_elapsed!(
727 query_executor.query(&query, ¶ms).await,
728 PG_STORE_NAME,
729 RDS_STORE_OP_BATCH_PUT,
730 ""
731 )?;
732 if req.prev_kv {
733 Ok(BatchPutResponse { prev_kvs: kvs })
734 } else {
735 Ok(BatchPutResponse::default())
736 }
737 }
738
739 async fn batch_get_with_query_executor(
741 &self,
742 query_executor: &mut ExecutorImpl<'_, PgClient>,
743 req: BatchGetRequest,
744 ) -> Result<BatchGetResponse> {
745 if req.keys.is_empty() {
746 return Ok(BatchGetResponse { kvs: vec![] });
747 }
748 let query = self
749 .sql_template_set
750 .generate_batch_get_query(req.keys.len());
751 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
752 let kvs = crate::record_rds_sql_execute_elapsed!(
753 query_executor.query(&query, ¶ms).await,
754 PG_STORE_NAME,
755 RDS_STORE_OP_BATCH_GET,
756 ""
757 )?;
758 Ok(BatchGetResponse { kvs })
759 }
760
761 async fn delete_range_with_query_executor(
762 &self,
763 query_executor: &mut ExecutorImpl<'_, PgClient>,
764 req: DeleteRangeRequest,
765 ) -> Result<DeleteRangeResponse> {
766 let template_type = range_template(&req.key, &req.range_end);
767 let template = self.sql_template_set.delete_template.get(template_type);
768 let params = template_type.build_params(req.key, req.range_end);
769 let params_ref = params.iter().map(|x| x as _).collect::<Vec<_>>();
770 let kvs = crate::record_rds_sql_execute_elapsed!(
771 query_executor.query(template, ¶ms_ref).await,
772 PG_STORE_NAME,
773 RDS_STORE_OP_RANGE_DELETE,
774 template_type.as_ref()
775 )?;
776 let mut resp = DeleteRangeResponse::new(kvs.len() as i64);
777 if req.prev_kv {
778 resp.with_prev_kvs(kvs);
779 }
780 Ok(resp)
781 }
782
783 async fn batch_delete_with_query_executor(
784 &self,
785 query_executor: &mut ExecutorImpl<'_, PgClient>,
786 req: BatchDeleteRequest,
787 ) -> Result<BatchDeleteResponse> {
788 if req.keys.is_empty() {
789 return Ok(BatchDeleteResponse::default());
790 }
791 let query = self
792 .sql_template_set
793 .generate_batch_delete_query(req.keys.len());
794 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
795
796 let kvs = crate::record_rds_sql_execute_elapsed!(
797 query_executor.query(&query, ¶ms).await,
798 PG_STORE_NAME,
799 RDS_STORE_OP_BATCH_DELETE,
800 ""
801 )?;
802 if req.prev_kv {
803 Ok(BatchDeleteResponse { prev_kvs: kvs })
804 } else {
805 Ok(BatchDeleteResponse::default())
806 }
807 }
808}
809
810impl PgStore {
811 pub async fn with_url_and_tls(
820 url: &str,
821 table_name: &str,
822 max_txn_ops: usize,
823 tls_config: Option<TlsOption>,
824 ) -> Result<KvBackendRef> {
825 let mut cfg = Config::new();
826 cfg.url = Some(url.to_string());
827
828 let pool = match tls_config {
829 Some(tls_config) if tls_config.mode != TlsMode::Disable => {
830 match create_postgres_tls_connector(&tls_config) {
831 Ok(tls_connector) => cfg
832 .create_pool(Some(Runtime::Tokio1), tls_connector)
833 .context(CreatePostgresPoolSnafu)?,
834 Err(e) => {
835 if tls_config.mode == TlsMode::Prefer {
836 common_telemetry::info!(
838 "Failed to create TLS connector, falling back to insecure connection"
839 );
840 cfg.create_pool(Some(Runtime::Tokio1), NoTls)
841 .context(CreatePostgresPoolSnafu)?
842 } else {
843 return Err(e);
844 }
845 }
846 }
847 }
848 _ => cfg
849 .create_pool(Some(Runtime::Tokio1), NoTls)
850 .context(CreatePostgresPoolSnafu)?,
851 };
852
853 Self::with_pg_pool(pool, None, table_name, max_txn_ops, false).await
854 }
855
856 pub async fn with_url(url: &str, table_name: &str, max_txn_ops: usize) -> Result<KvBackendRef> {
858 Self::with_url_and_tls(url, table_name, max_txn_ops, None).await
859 }
860
861 pub async fn with_pg_pool(
863 pool: Pool,
864 schema_name: Option<&str>,
865 table_name: &str,
866 max_txn_ops: usize,
867 auto_create_schema: bool,
868 ) -> Result<KvBackendRef> {
869 let client = match pool.get().await {
871 Ok(client) => client,
872 Err(e) => {
873 common_telemetry::error!(e; "Failed to get Postgres connection.");
875 return GetPostgresConnectionSnafu {
876 reason: e.to_string(),
877 }
878 .fail();
879 }
880 };
881
882 if auto_create_schema
884 && let Some(schema) = schema_name
885 && !schema.is_empty()
886 {
887 let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS \"{}\"", schema);
888 client
889 .execute(&create_schema_sql, &[])
890 .await
891 .with_context(|_| PostgresExecutionSnafu {
892 sql: create_schema_sql.clone(),
893 })?;
894 }
895
896 let template_factory = PgSqlTemplateFactory::new(schema_name, table_name);
897 let sql_template_set = template_factory.build();
898 client
899 .execute(&sql_template_set.create_table_statement, &[])
900 .await
901 .with_context(|_| PostgresExecutionSnafu {
902 sql: sql_template_set.create_table_statement.clone(),
903 })?;
904 Ok(Arc::new(Self {
905 max_txn_ops,
906 sql_template_set,
907 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
908 executor_factory: PgExecutorFactory { pool },
909 _phantom: PhantomData,
910 }))
911 }
912}
913
914#[cfg(test)]
915mod tests {
916 use super::*;
917 use crate::kv_backend::test::{
918 prepare_kv_with_prefix, test_kv_batch_delete_with_prefix, test_kv_batch_get_with_prefix,
919 test_kv_compare_and_put_with_prefix, test_kv_delete_range_with_prefix,
920 test_kv_put_with_prefix, test_kv_range_2_with_prefix, test_kv_range_with_prefix,
921 test_simple_kv_range, test_txn_compare_equal, test_txn_compare_greater,
922 test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
923 text_txn_multi_compare_op, unprepare_kv,
924 };
925 use crate::test_util::test_certs_dir;
926 use crate::{maybe_skip_postgres_integration_test, maybe_skip_postgres15_integration_test};
927
928 async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
929 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
930 if endpoints.is_empty() {
931 return None;
932 }
933
934 let mut cfg = Config::new();
935 cfg.url = Some(endpoints);
936 let pool = cfg
937 .create_pool(Some(Runtime::Tokio1), NoTls)
938 .context(CreatePostgresPoolSnafu)
939 .unwrap();
940 let client = pool.get().await.unwrap();
941 let template_factory = PgSqlTemplateFactory::new(None, table_name);
943 let sql_templates = template_factory.build();
944 client
946 .execute(&sql_templates.create_table_statement, &[])
947 .await
948 .with_context(|_| PostgresExecutionSnafu {
949 sql: sql_templates.create_table_statement.clone(),
950 })
951 .unwrap();
952 Some(PgStore {
953 max_txn_ops: 128,
954 sql_template_set: sql_templates,
955 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
956 executor_factory: PgExecutorFactory { pool },
957 _phantom: PhantomData,
958 })
959 }
960
961 async fn build_pg15_pool() -> Option<Pool> {
962 let url = std::env::var("GT_POSTGRES15_ENDPOINTS").unwrap_or_default();
963 if url.is_empty() {
964 return None;
965 }
966 let mut cfg = Config::new();
967 cfg.url = Some(url);
968 let pool = cfg
969 .create_pool(Some(Runtime::Tokio1), NoTls)
970 .context(CreatePostgresPoolSnafu)
971 .ok()?;
972 Some(pool)
973 }
974
975 #[tokio::test]
976 async fn test_pg15_create_table_in_public_should_fail() {
977 maybe_skip_postgres15_integration_test!();
978 let Some(pool) = build_pg15_pool().await else {
979 return;
980 };
981 let res = PgStore::with_pg_pool(pool, None, "pg15_public_should_fail", 128, false).await;
982 assert!(
983 res.is_err(),
984 "creating table in public should fail for test_user"
985 );
986 }
987
988 #[tokio::test]
989 async fn test_pg15_create_table_in_test_schema_and_crud_should_succeed() {
990 maybe_skip_postgres15_integration_test!();
991 let Some(pool) = build_pg15_pool().await else {
992 return;
993 };
994 let schema_name = std::env::var("GT_POSTGRES15_SCHEMA").unwrap();
995 let client = pool.get().await.unwrap();
996 let factory = PgSqlTemplateFactory::new(Some(&schema_name), "pg15_ok");
997 let templates = factory.build();
998 client
999 .execute(&templates.create_table_statement, &[])
1000 .await
1001 .unwrap();
1002 let kv = PgStore {
1003 max_txn_ops: 128,
1004 sql_template_set: templates,
1005 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
1006 executor_factory: PgExecutorFactory { pool },
1007 _phantom: PhantomData,
1008 };
1009 let prefix = b"pg15_crud/";
1010 prepare_kv_with_prefix(&kv, prefix.to_vec()).await;
1011 test_kv_put_with_prefix(&kv, prefix.to_vec()).await;
1012 test_kv_batch_get_with_prefix(&kv, prefix.to_vec()).await;
1013 unprepare_kv(&kv, prefix).await;
1014 }
1015
1016 #[tokio::test]
1017 async fn test_pg_with_tls() {
1018 common_telemetry::init_default_ut_logging();
1019 maybe_skip_postgres_integration_test!();
1020 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1021 let tls_connector = create_postgres_tls_connector(&TlsOption {
1022 mode: TlsMode::Require,
1023 cert_path: String::new(),
1024 key_path: String::new(),
1025 ca_cert_path: String::new(),
1026 watch: false,
1027 })
1028 .unwrap();
1029 let mut cfg = Config::new();
1030 cfg.url = Some(endpoints);
1031 let pool = cfg
1032 .create_pool(Some(Runtime::Tokio1), tls_connector)
1033 .unwrap();
1034 let client = pool.get().await.unwrap();
1035 client.execute("SELECT 1", &[]).await.unwrap();
1036 }
1037
1038 #[tokio::test]
1039 async fn test_pg_with_mtls() {
1040 common_telemetry::init_default_ut_logging();
1041 maybe_skip_postgres_integration_test!();
1042 let certs_dir = test_certs_dir();
1043 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1044 let tls_connector = create_postgres_tls_connector(&TlsOption {
1045 mode: TlsMode::Require,
1046 cert_path: certs_dir.join("client.crt").display().to_string(),
1047 key_path: certs_dir.join("client.key").display().to_string(),
1048 ca_cert_path: String::new(),
1049 watch: false,
1050 })
1051 .unwrap();
1052 let mut cfg = Config::new();
1053 cfg.url = Some(endpoints);
1054 let pool = cfg
1055 .create_pool(Some(Runtime::Tokio1), tls_connector)
1056 .unwrap();
1057 let client = pool.get().await.unwrap();
1058 client.execute("SELECT 1", &[]).await.unwrap();
1059 }
1060
1061 #[tokio::test]
1062 async fn test_pg_verify_ca() {
1063 common_telemetry::init_default_ut_logging();
1064 maybe_skip_postgres_integration_test!();
1065 let certs_dir = test_certs_dir();
1066 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1067 let tls_connector = create_postgres_tls_connector(&TlsOption {
1068 mode: TlsMode::VerifyCa,
1069 cert_path: certs_dir.join("client.crt").display().to_string(),
1070 key_path: certs_dir.join("client.key").display().to_string(),
1071 ca_cert_path: certs_dir.join("root.crt").display().to_string(),
1072 watch: false,
1073 })
1074 .unwrap();
1075 let mut cfg = Config::new();
1076 cfg.url = Some(endpoints);
1077 let pool = cfg
1078 .create_pool(Some(Runtime::Tokio1), tls_connector)
1079 .unwrap();
1080 let client = pool.get().await.unwrap();
1081 client.execute("SELECT 1", &[]).await.unwrap();
1082 }
1083
1084 #[tokio::test]
1085 async fn test_pg_verify_full() {
1086 common_telemetry::init_default_ut_logging();
1087 maybe_skip_postgres_integration_test!();
1088 let certs_dir = test_certs_dir();
1089 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1090 let tls_connector = create_postgres_tls_connector(&TlsOption {
1091 mode: TlsMode::VerifyFull,
1092 cert_path: certs_dir.join("client.crt").display().to_string(),
1093 key_path: certs_dir.join("client.key").display().to_string(),
1094 ca_cert_path: certs_dir.join("root.crt").display().to_string(),
1095 watch: false,
1096 })
1097 .unwrap();
1098 let mut cfg = Config::new();
1099 cfg.url = Some(endpoints);
1100 let pool = cfg
1101 .create_pool(Some(Runtime::Tokio1), tls_connector)
1102 .unwrap();
1103 let client = pool.get().await.unwrap();
1104 client.execute("SELECT 1", &[]).await.unwrap();
1105 }
1106
1107 #[tokio::test]
1108 async fn test_pg_put() {
1109 maybe_skip_postgres_integration_test!();
1110 let kv_backend = build_pg_kv_backend("put-test").await.unwrap();
1111 let prefix = b"put/";
1112 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1113 test_kv_put_with_prefix(&kv_backend, prefix.to_vec()).await;
1114 unprepare_kv(&kv_backend, prefix).await;
1115 }
1116
1117 #[tokio::test]
1118 async fn test_pg_range() {
1119 maybe_skip_postgres_integration_test!();
1120 let kv_backend = build_pg_kv_backend("range-test").await.unwrap();
1121 let prefix = b"range/";
1122 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1123 test_kv_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1124 unprepare_kv(&kv_backend, prefix).await;
1125 }
1126
1127 #[tokio::test]
1128 async fn test_pg_range_2() {
1129 maybe_skip_postgres_integration_test!();
1130 let kv_backend = build_pg_kv_backend("range2-test").await.unwrap();
1131 let prefix = b"range2/";
1132 test_kv_range_2_with_prefix(&kv_backend, prefix.to_vec()).await;
1133 unprepare_kv(&kv_backend, prefix).await;
1134 }
1135
1136 #[tokio::test]
1137 async fn test_pg_all_range() {
1138 maybe_skip_postgres_integration_test!();
1139 let kv_backend = build_pg_kv_backend("simple_range-test").await.unwrap();
1140 let prefix = b"";
1141 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1142 test_simple_kv_range(&kv_backend).await;
1143 unprepare_kv(&kv_backend, prefix).await;
1144 }
1145
1146 #[tokio::test]
1147 async fn test_pg_batch_get() {
1148 maybe_skip_postgres_integration_test!();
1149 let kv_backend = build_pg_kv_backend("batch_get-test").await.unwrap();
1150 let prefix = b"batch_get/";
1151 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1152 test_kv_batch_get_with_prefix(&kv_backend, prefix.to_vec()).await;
1153 unprepare_kv(&kv_backend, prefix).await;
1154 }
1155
1156 #[tokio::test]
1157 async fn test_pg_batch_delete() {
1158 maybe_skip_postgres_integration_test!();
1159 let kv_backend = build_pg_kv_backend("batch_delete-test").await.unwrap();
1160 let prefix = b"batch_delete/";
1161 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1162 test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1163 unprepare_kv(&kv_backend, prefix).await;
1164 }
1165
1166 #[tokio::test]
1167 async fn test_pg_batch_delete_with_prefix() {
1168 maybe_skip_postgres_integration_test!();
1169 let kv_backend = build_pg_kv_backend("batch_delete_with_prefix-test")
1170 .await
1171 .unwrap();
1172 let prefix = b"batch_delete/";
1173 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1174 test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
1175 unprepare_kv(&kv_backend, prefix).await;
1176 }
1177
1178 #[tokio::test]
1179 async fn test_pg_delete_range() {
1180 maybe_skip_postgres_integration_test!();
1181 let kv_backend = build_pg_kv_backend("delete_range-test").await.unwrap();
1182 let prefix = b"delete_range/";
1183 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1184 test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1185 unprepare_kv(&kv_backend, prefix).await;
1186 }
1187
1188 #[tokio::test]
1189 async fn test_pg_compare_and_put() {
1190 maybe_skip_postgres_integration_test!();
1191 let kv_backend = build_pg_kv_backend("compare_and_put-test").await.unwrap();
1192 let prefix = b"compare_and_put/";
1193 let kv_backend = Arc::new(kv_backend);
1194 test_kv_compare_and_put_with_prefix(kv_backend.clone(), prefix.to_vec()).await;
1195 }
1196
1197 #[tokio::test]
1198 async fn test_pg_txn() {
1199 maybe_skip_postgres_integration_test!();
1200 let kv_backend = build_pg_kv_backend("txn-test").await.unwrap();
1201 test_txn_one_compare_op(&kv_backend).await;
1202 text_txn_multi_compare_op(&kv_backend).await;
1203 test_txn_compare_equal(&kv_backend).await;
1204 test_txn_compare_greater(&kv_backend).await;
1205 test_txn_compare_less(&kv_backend).await;
1206 test_txn_compare_not_equal(&kv_backend).await;
1207 }
1208
1209 #[test]
1210 fn test_pg_template_with_schema() {
1211 let factory = PgSqlTemplateFactory::new(Some("test_schema"), "greptime_metakv");
1212 let t = factory.build();
1213 assert!(
1214 t.create_table_statement
1215 .contains("\"test_schema\".\"greptime_metakv\"")
1216 );
1217 let upsert = t.generate_batch_upsert_query(1);
1218 assert!(upsert.contains("\"test_schema\".\"greptime_metakv\""));
1219 let get = t.generate_batch_get_query(1);
1220 assert!(get.contains("\"test_schema\".\"greptime_metakv\""));
1221 let del = t.generate_batch_delete_query(1);
1222 assert!(del.contains("\"test_schema\".\"greptime_metakv\""));
1223 }
1224
1225 #[test]
1226 fn test_format_table_ident() {
1227 let t = PgSqlTemplateFactory::format_table_ident(None, "test_table");
1228 assert_eq!(t, "\"test_table\"");
1229
1230 let t = PgSqlTemplateFactory::format_table_ident(Some("test_schema"), "test_table");
1231 assert_eq!(t, "\"test_schema\".\"test_table\"");
1232
1233 let t = PgSqlTemplateFactory::format_table_ident(Some(""), "test_table");
1234 assert_eq!(t, "\"test_table\"");
1235 }
1236
1237 #[tokio::test]
1238 async fn test_auto_create_schema_enabled() {
1239 common_telemetry::init_default_ut_logging();
1240 maybe_skip_postgres_integration_test!();
1241 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1242 let mut cfg = Config::new();
1243 cfg.url = Some(endpoints);
1244 let pool = cfg
1245 .create_pool(Some(Runtime::Tokio1), NoTls)
1246 .context(CreatePostgresPoolSnafu)
1247 .unwrap();
1248
1249 let schema_name = "test_auto_create_enabled";
1250 let table_name = "test_table";
1251
1252 let client = pool.get().await.unwrap();
1254 let _ = client
1255 .execute(
1256 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1257 &[],
1258 )
1259 .await;
1260
1261 let _ = PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, true)
1263 .await
1264 .unwrap();
1265
1266 let row = client
1268 .query_one(
1269 "SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1",
1270 &[&schema_name],
1271 )
1272 .await
1273 .unwrap();
1274 let created_schema: String = row.get(0);
1275 assert_eq!(created_schema, schema_name);
1276
1277 let row = client
1279 .query_one(
1280 "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
1281 &[&schema_name, &table_name],
1282 )
1283 .await
1284 .unwrap();
1285 let created_table_schema: String = row.get(0);
1286 let created_table_name: String = row.get(1);
1287 assert_eq!(created_table_schema, schema_name);
1288 assert_eq!(created_table_name, table_name);
1289
1290 let _ = client
1292 .execute(
1293 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1294 &[],
1295 )
1296 .await;
1297 }
1298
1299 #[tokio::test]
1300 async fn test_auto_create_schema_disabled() {
1301 common_telemetry::init_default_ut_logging();
1302 maybe_skip_postgres_integration_test!();
1303 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1304 let mut cfg = Config::new();
1305 cfg.url = Some(endpoints);
1306 let pool = cfg
1307 .create_pool(Some(Runtime::Tokio1), NoTls)
1308 .context(CreatePostgresPoolSnafu)
1309 .unwrap();
1310
1311 let schema_name = "test_auto_create_disabled";
1312 let table_name = "test_table";
1313
1314 let client = pool.get().await.unwrap();
1316 let _ = client
1317 .execute(
1318 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1319 &[],
1320 )
1321 .await;
1322
1323 let result =
1325 PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, false).await;
1326
1327 assert!(
1329 result.is_err(),
1330 "Expected error when schema doesn't exist and auto_create_schema is disabled"
1331 );
1332 }
1333
1334 #[tokio::test]
1335 async fn test_auto_create_schema_already_exists() {
1336 common_telemetry::init_default_ut_logging();
1337 maybe_skip_postgres_integration_test!();
1338 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1339 let mut cfg = Config::new();
1340 cfg.url = Some(endpoints);
1341 let pool = cfg
1342 .create_pool(Some(Runtime::Tokio1), NoTls)
1343 .context(CreatePostgresPoolSnafu)
1344 .unwrap();
1345
1346 let schema_name = "test_auto_create_existing";
1347 let table_name = "test_table";
1348
1349 let client = pool.get().await.unwrap();
1351 let _ = client
1352 .execute(
1353 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1354 &[],
1355 )
1356 .await;
1357 client
1358 .execute(&format!("CREATE SCHEMA \"{}\"", schema_name), &[])
1359 .await
1360 .unwrap();
1361
1362 let _ = PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, true)
1364 .await
1365 .unwrap();
1366
1367 let row = client
1369 .query_one(
1370 "SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1",
1371 &[&schema_name],
1372 )
1373 .await
1374 .unwrap();
1375 let created_schema: String = row.get(0);
1376 assert_eq!(created_schema, schema_name);
1377
1378 let row = client
1380 .query_one(
1381 "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
1382 &[&schema_name, &table_name],
1383 )
1384 .await
1385 .unwrap();
1386 let created_table_schema: String = row.get(0);
1387 let created_table_name: String = row.get(1);
1388 assert_eq!(created_table_schema, schema_name);
1389 assert_eq!(created_table_name, table_name);
1390
1391 let _ = client
1393 .execute(
1394 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1395 &[],
1396 )
1397 .await;
1398 }
1399
1400 #[tokio::test]
1401 async fn test_auto_create_schema_no_schema_name() {
1402 common_telemetry::init_default_ut_logging();
1403 maybe_skip_postgres_integration_test!();
1404 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1405 let mut cfg = Config::new();
1406 cfg.url = Some(endpoints);
1407 let pool = cfg
1408 .create_pool(Some(Runtime::Tokio1), NoTls)
1409 .context(CreatePostgresPoolSnafu)
1410 .unwrap();
1411
1412 let table_name = "test_table_no_schema";
1413
1414 let _ = PgStore::with_pg_pool(pool.clone(), None, table_name, 128, true)
1417 .await
1418 .unwrap();
1419
1420 let client = pool.get().await.unwrap();
1422 let row = client
1423 .query_one(
1424 "SELECT table_schema, table_name FROM information_schema.tables WHERE table_name = $1",
1425 &[&table_name],
1426 )
1427 .await
1428 .unwrap();
1429 let created_table_schema: String = row.get(0);
1430 let created_table_name: String = row.get(1);
1431 assert_eq!(created_table_name, table_name);
1432 assert!(created_table_schema == "public" || !created_table_schema.is_empty());
1434
1435 let _ = client
1437 .execute(&format!("DROP TABLE IF EXISTS \"{}\"", table_name), &[])
1438 .await;
1439 }
1440
1441 #[tokio::test]
1442 async fn test_auto_create_schema_with_empty_schema_name() {
1443 common_telemetry::init_default_ut_logging();
1444 maybe_skip_postgres_integration_test!();
1445 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1446 let mut cfg = Config::new();
1447 cfg.url = Some(endpoints);
1448 let pool = cfg
1449 .create_pool(Some(Runtime::Tokio1), NoTls)
1450 .context(CreatePostgresPoolSnafu)
1451 .unwrap();
1452
1453 let table_name = "test_table_empty_schema";
1454
1455 let _ = PgStore::with_pg_pool(pool.clone(), Some(""), table_name, 128, true)
1458 .await
1459 .unwrap();
1460
1461 let client = pool.get().await.unwrap();
1463 let row = client
1464 .query_one(
1465 "SELECT table_schema, table_name FROM information_schema.tables WHERE table_name = $1",
1466 &[&table_name],
1467 )
1468 .await
1469 .unwrap();
1470 let created_table_schema: String = row.get(0);
1471 let created_table_name: String = row.get(1);
1472 assert_eq!(created_table_name, table_name);
1473 assert!(created_table_schema == "public" || !created_table_schema.is_empty());
1475
1476 let _ = client
1478 .execute(&format!("DROP TABLE IF EXISTS \"{}\"", table_name), &[])
1479 .await;
1480 }
1481}