1use std::fs::File;
16use std::io::BufReader;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
20use common_telemetry::debug;
21use deadpool_postgres::{Config, Pool, Runtime};
22use rustls::ClientConfig;
24use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
25use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
26use rustls::server::ParsedCertificate;
27use rustls::{DigitallySignedStruct, Error as TlsError, SignatureScheme};
28use rustls_pemfile::{certs, private_key};
29use snafu::ResultExt;
30use strum::AsRefStr;
31use tokio_postgres::types::ToSql;
32use tokio_postgres::{IsolationLevel, NoTls, Row};
33use tokio_postgres_rustls::MakeRustlsConnect;
34
35use crate::error::{
36 CreatePostgresPoolSnafu, GetPostgresConnectionSnafu, LoadTlsCertificateSnafu,
37 PostgresExecutionSnafu, PostgresTlsConfigSnafu, PostgresTransactionSnafu, Result,
38};
39use crate::kv_backend::KvBackendRef;
40use crate::kv_backend::rds::{
41 Executor, ExecutorFactory, ExecutorImpl, KvQueryExecutor, RDS_STORE_OP_BATCH_DELETE,
42 RDS_STORE_OP_BATCH_GET, RDS_STORE_OP_BATCH_PUT, RDS_STORE_OP_RANGE_DELETE,
43 RDS_STORE_OP_RANGE_QUERY, RDS_STORE_TXN_RETRY_COUNT, RdsStore, Transaction,
44};
45use crate::rpc::KeyValue;
46use crate::rpc::store::{
47 BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
48 BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, RangeRequest, RangeResponse,
49};
50
/// TLS mode used for PostgreSQL connections.
///
/// Variant names mirror libpq's `sslmode` values. How each mode behaves is decided by
/// `create_postgres_tls_connector` and `PgStore::with_url_and_tls`:
/// - `Disable`: no TLS connector is built; a plaintext pool is used.
/// - `Prefer` (default): TLS with any server certificate accepted; falls back to a
///   plaintext pool if the connector cannot be built.
/// - `Require`: TLS with any server certificate accepted.
/// - `VerifyCa`: the certificate chain is verified, the host name is not.
/// - `VerifyFull`: both the certificate chain and the host name are verified.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum TlsMode {
    Disable,
    #[default]
    Prefer,
    Require,
    VerifyCa,
    VerifyFull,
}

/// TLS settings for PostgreSQL connections.
///
/// `TlsOption::default()` yields `TlsMode::Prefer`, empty paths and `watch == false`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TlsOption {
    /// Which TLS mode to use.
    pub mode: TlsMode,
    /// PEM client certificate chain for mutual TLS. It is only used when `key_path` is also set.
    pub cert_path: String,
    /// PEM client private key for mutual TLS. It is only used when `cert_path` is also set.
    pub key_path: String,
    /// PEM CA certificate(s) used by `VerifyCa` / `VerifyFull`. Empty means "not set".
    pub ca_cert_path: String,
    /// Not read in this file; presumably toggles reloading of the certificate files — TODO confirm.
    pub watch: bool,
}
85
/// Store name passed to `record_rds_sql_execute_elapsed!` by the executors in this file.
const PG_STORE_NAME: &str = "pg_store";

/// A connection checked out from a `deadpool_postgres` pool.
pub struct PgClient(deadpool::managed::Object<deadpool_postgres::Manager>);
/// An open transaction that borrows a pooled connection for lifetime `'a`.
pub struct PgTxnClient<'a>(deadpool_postgres::Transaction<'a>);
90
91fn key_value_from_row(r: Row) -> KeyValue {
93 KeyValue {
94 key: r.get(0),
95 value: r.get(1),
96 }
97}
98
/// The single-NUL-byte key `b"\0"`.
///
/// In `range_template`, a `range_end` equal to this value means "no upper bound". When `key`
/// is also this value, the whole table is selected.
const EMPTY: &[u8] = &[0];
100
/// Shape of a range or delete-range request; selects the matching SQL template.
///
/// The variant name (via `strum::AsRefStr`) is also passed to
/// `record_rds_sql_execute_elapsed!` by the range and delete-range executors.
#[derive(Debug, Clone, Copy, AsRefStr)]
enum RangeTemplateType {
    /// `range_end` is empty: exact match on `key`.
    Point,
    /// Generic half-open range `[key, range_end)`.
    Range,
    /// `key` and `range_end` are both `EMPTY`: every row.
    Full,
    /// `range_end` is `EMPTY`: every key `>= key`.
    LeftBounded,
    /// `range_end` is `key` with its last byte incremented: every key starting with `key`.
    Prefix,
}
110
111impl RangeTemplateType {
113 fn build_params(&self, mut key: Vec<u8>, range_end: Vec<u8>) -> Vec<Vec<u8>> {
116 match self {
117 RangeTemplateType::Point => vec![key],
118 RangeTemplateType::Range => vec![key, range_end],
119 RangeTemplateType::Full => vec![],
120 RangeTemplateType::LeftBounded => vec![key],
121 RangeTemplateType::Prefix => {
122 key.push(b'%');
123 vec![key]
124 }
125 }
126 }
127}
128
129#[derive(Debug, Clone)]
131struct RangeTemplate {
132 point: String,
133 range: String,
134 full: String,
135 left_bounded: String,
136 prefix: String,
137}
138
139impl RangeTemplate {
140 fn get(&self, typ: RangeTemplateType) -> &str {
142 match typ {
143 RangeTemplateType::Point => &self.point,
144 RangeTemplateType::Range => &self.range,
145 RangeTemplateType::Full => &self.full,
146 RangeTemplateType::LeftBounded => &self.left_bounded,
147 RangeTemplateType::Prefix => &self.prefix,
148 }
149 }
150
151 fn with_limit(template: &str, limit: i64) -> String {
153 if limit == 0 {
154 return format!("{};", template);
155 }
156 format!("{} LIMIT {};", template, limit)
157 }
158}
159
/// Returns `true` when `[start, end)` selects exactly the keys that begin with `start`.
///
/// That is the case when both slices have the same non-zero length, share everything but
/// the last byte, and `end`'s last byte is `start`'s last byte plus one. Empty inputs and a
/// `start` ending in `0xFF` (which cannot be incremented without a carry) return `false`
/// instead of panicking or wrapping.
fn is_prefix_range(start: &[u8], end: &[u8]) -> bool {
    if start.len() != end.len() {
        return false;
    }
    match (start.split_last(), end.split_last()) {
        (Some((start_last, start_head)), Some((end_last, end_head))) => {
            start_head == end_head && start_last.checked_add(1) == Some(*end_last)
        }
        _ => false,
    }
}
171
172fn range_template(key: &[u8], range_end: &[u8]) -> RangeTemplateType {
174 match (key, range_end) {
175 (_, &[]) => RangeTemplateType::Point,
176 (EMPTY, EMPTY) => RangeTemplateType::Full,
177 (_, EMPTY) => RangeTemplateType::LeftBounded,
178 (start, end) => {
179 if is_prefix_range(start, end) {
180 RangeTemplateType::Prefix
181 } else {
182 RangeTemplateType::Range
183 }
184 }
185 }
186}
187
/// Returns the PostgreSQL positional placeholders `$from, ..., $to` (inclusive).
///
/// Returns an empty vector when `from > to`.
fn pg_generate_in_placeholders(from: usize, to: usize) -> Vec<String> {
    let mut placeholders = Vec::new();
    for index in from..=to {
        placeholders.push(format!("${index}"));
    }
    placeholders
}
192
193struct PgSqlTemplateFactory<'a> {
195 schema_name: Option<&'a str>,
196 table_name: &'a str,
197}
198
199impl<'a> PgSqlTemplateFactory<'a> {
200 fn new(schema_name: Option<&'a str>, table_name: &'a str) -> Self {
202 Self {
203 schema_name,
204 table_name,
205 }
206 }
207
208 fn build(&self) -> PgSqlTemplateSet {
210 let table_ident = Self::format_table_ident(self.schema_name, self.table_name);
211 PgSqlTemplateSet {
213 table_ident: table_ident.clone(),
214 create_table_statement: format!(
216 "CREATE TABLE IF NOT EXISTS {table_ident}(k bytea PRIMARY KEY, v bytea)",
217 ),
218 range_template: RangeTemplate {
219 point: format!("SELECT k, v FROM {table_ident} WHERE k = $1"),
220 range: format!(
221 "SELECT k, v FROM {table_ident} WHERE k >= $1 AND k < $2 ORDER BY k"
222 ),
223 full: format!("SELECT k, v FROM {table_ident} ORDER BY k"),
224 left_bounded: format!("SELECT k, v FROM {table_ident} WHERE k >= $1 ORDER BY k"),
225 prefix: format!("SELECT k, v FROM {table_ident} WHERE k LIKE $1 ORDER BY k"),
226 },
227 delete_template: RangeTemplate {
228 point: format!("DELETE FROM {table_ident} WHERE k = $1 RETURNING k,v;"),
229 range: format!("DELETE FROM {table_ident} WHERE k >= $1 AND k < $2 RETURNING k,v;"),
230 full: format!("DELETE FROM {table_ident} RETURNING k,v"),
231 left_bounded: format!("DELETE FROM {table_ident} WHERE k >= $1 RETURNING k,v;"),
232 prefix: format!("DELETE FROM {table_ident} WHERE k LIKE $1 RETURNING k,v;"),
233 },
234 }
235 }
236
237 fn format_table_ident(schema_name: Option<&str>, table_name: &str) -> String {
239 match schema_name {
240 Some(s) if !s.is_empty() => format!("\"{}\".\"{}\"", s, table_name),
241 _ => format!("\"{}\"", table_name),
242 }
243 }
244}
245
246#[derive(Debug, Clone)]
248pub struct PgSqlTemplateSet {
249 table_ident: String,
250 create_table_statement: String,
251 range_template: RangeTemplate,
252 delete_template: RangeTemplate,
253}
254
255impl PgSqlTemplateSet {
256 fn generate_batch_get_query(&self, key_len: usize) -> String {
258 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
259 format!(
260 "SELECT k, v FROM {} WHERE k in ({});",
261 self.table_ident, in_clause
262 )
263 }
264
265 fn generate_batch_delete_query(&self, key_len: usize) -> String {
267 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
268 format!(
269 "DELETE FROM {} WHERE k in ({}) RETURNING k,v;",
270 self.table_ident, in_clause
271 )
272 }
273
274 fn generate_batch_upsert_query(&self, kv_len: usize) -> String {
276 let in_placeholders: Vec<String> = (1..=kv_len).map(|i| format!("${}", i)).collect();
277 let in_clause = in_placeholders.join(", ");
278 let mut param_index = kv_len + 1;
279 let mut values_placeholders = Vec::new();
280 for _ in 0..kv_len {
281 values_placeholders.push(format!("(${0}, ${1})", param_index, param_index + 1));
282 param_index += 2;
283 }
284 let values_clause = values_placeholders.join(", ");
285
286 format!(
287 r#"
288 WITH prev AS (
289 SELECT k,v FROM {table} WHERE k IN ({in_clause})
290 ), update AS (
291 INSERT INTO {table} (k, v) VALUES
292 {values_clause}
293 ON CONFLICT (
294 k
295 ) DO UPDATE SET
296 v = excluded.v
297 )
298
299 SELECT k, v FROM prev;
300 "#,
301 table = self.table_ident,
302 in_clause = in_clause,
303 values_clause = values_clause
304 )
305 }
306}
307
308#[async_trait::async_trait]
309impl Executor for PgClient {
310 type Transaction<'a>
311 = PgTxnClient<'a>
312 where
313 Self: 'a;
314
315 fn name() -> &'static str {
316 "Postgres"
317 }
318
319 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
320 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
321 let stmt = self
322 .0
323 .prepare_cached(query)
324 .await
325 .context(PostgresExecutionSnafu { sql: query })?;
326 let rows = self
327 .0
328 .query(&stmt, ¶ms)
329 .await
330 .context(PostgresExecutionSnafu { sql: query })?;
331 Ok(rows.into_iter().map(key_value_from_row).collect())
332 }
333
334 async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>> {
335 let txn = self
336 .0
337 .build_transaction()
338 .isolation_level(IsolationLevel::Serializable)
339 .start()
340 .await
341 .context(PostgresTransactionSnafu {
342 operation: "begin".to_string(),
343 })?;
344 Ok(PgTxnClient(txn))
345 }
346}
347
348#[async_trait::async_trait]
349impl<'a> Transaction<'a> for PgTxnClient<'a> {
350 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
351 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
352 let stmt = self
353 .0
354 .prepare_cached(query)
355 .await
356 .context(PostgresExecutionSnafu { sql: query })?;
357 let rows = self
358 .0
359 .query(&stmt, ¶ms)
360 .await
361 .context(PostgresExecutionSnafu { sql: query })?;
362 Ok(rows.into_iter().map(key_value_from_row).collect())
363 }
364
365 async fn commit(self) -> Result<()> {
366 self.0.commit().await.context(PostgresTransactionSnafu {
367 operation: "commit",
368 })?;
369 Ok(())
370 }
371}
372
373pub struct PgExecutorFactory {
374 pool: Pool,
375}
376
377impl PgExecutorFactory {
378 async fn client(&self) -> Result<PgClient> {
379 match self.pool.get().await {
380 Ok(client) => Ok(PgClient(client)),
381 Err(e) => GetPostgresConnectionSnafu {
382 reason: e.to_string(),
383 }
384 .fail(),
385 }
386 }
387}
388
389#[async_trait::async_trait]
390impl ExecutorFactory<PgClient> for PgExecutorFactory {
391 async fn default_executor(&self) -> Result<PgClient> {
392 self.client().await
393 }
394
395 async fn txn_executor<'a>(
396 &self,
397 default_executor: &'a mut PgClient,
398 ) -> Result<PgTxnClient<'a>> {
399 default_executor.txn_executor().await
400 }
401}
402
/// PostgreSQL-backed KV store: [`RdsStore`] specialized with the PostgreSQL client,
/// executor factory and SQL template set defined in this file.
pub type PgStore = RdsStore<PgClient, PgExecutorFactory, PgSqlTemplateSet>;
406
407pub fn create_postgres_tls_connector(tls_config: &TlsOption) -> Result<MakeRustlsConnect> {
421 common_telemetry::info!(
422 "Creating PostgreSQL TLS connector with mode: {:?}",
423 tls_config.mode
424 );
425
426 let config_builder = match tls_config.mode {
427 TlsMode::Disable => {
428 return PostgresTlsConfigSnafu {
429 reason: "Cannot create TLS connector for Disable mode".to_string(),
430 }
431 .fail();
432 }
433 TlsMode::Prefer | TlsMode::Require => {
434 let verifier = Arc::new(AcceptAnyVerifier);
436 ClientConfig::builder()
437 .dangerous()
438 .with_custom_certificate_verifier(verifier)
439 }
440 TlsMode::VerifyCa => {
441 let ca_store = load_ca(&tls_config.ca_cert_path)?;
443 let verifier = Arc::new(NoHostnameVerification { roots: ca_store });
444 ClientConfig::builder()
445 .dangerous()
446 .with_custom_certificate_verifier(verifier)
447 }
448 TlsMode::VerifyFull => {
449 let ca_store = load_ca(&tls_config.ca_cert_path)?;
450 ClientConfig::builder().with_root_certificates(ca_store)
451 }
452 };
453
454 let client_config = if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
456 common_telemetry::info!("Loading client certificate for mutual TLS");
458 let cert_chain = load_certs(&tls_config.cert_path)?;
459 let private_key = load_private_key(&tls_config.key_path)?;
460
461 config_builder
462 .with_client_auth_cert(cert_chain, private_key)
463 .map_err(|e| {
464 PostgresTlsConfigSnafu {
465 reason: format!("Failed to configure client authentication: {}", e),
466 }
467 .build()
468 })?
469 } else {
470 common_telemetry::info!("No client certificate provided, skip client authentication");
471 config_builder.with_no_client_auth()
472 };
473
474 common_telemetry::info!("Successfully created PostgreSQL TLS connector");
475 Ok(MakeRustlsConnect::new(client_config))
476}
477
/// Server certificate verifier used for the `Prefer` and `Require` modes.
///
/// It accepts every server certificate and every handshake signature without checking
/// them. The connection is encrypted, but the server is NOT authenticated: neither its
/// certificate chain nor its possession of the certificate's private key is verified, so
/// an active man-in-the-middle is not detected.
#[derive(Debug)]
struct AcceptAnyVerifier;

impl ServerCertVerifier for AcceptAnyVerifier {
    /// Accepts any certificate for any server name.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, TlsError> {
        common_telemetry::debug!(
            "Accepting server certificate without verification (Prefer/Require mode)"
        );
        Ok(ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Advertises the signature schemes supported by the `ring` crypto provider.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        rustls::crypto::ring::default_provider()
            .signature_verification_algorithms
            .supported_schemes()
    }
}
524
/// Server certificate verifier used for the `VerifyCa` mode.
///
/// It checks, with `rustls::client::verify_server_cert_signed_by_trust_anchor` and the
/// `ring` provider's algorithms, that the server certificate chains to one of `roots` at
/// time `now`. Handshake signatures are verified with the same algorithms. The server name
/// is NOT compared against the certificate, and the OCSP response is ignored.
#[derive(Debug)]
struct NoHostnameVerification {
    /// Trust anchors the server certificate must chain to.
    roots: Arc<rustls::RootCertStore>,
}

impl ServerCertVerifier for NoHostnameVerification {
    /// Verifies the chain `end_entity` + `intermediates` against `self.roots`; the server
    /// name and OCSP response are intentionally unused.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, TlsError> {
        let cert = ParsedCertificate::try_from(end_entity)?;
        rustls::client::verify_server_cert_signed_by_trust_anchor(
            &cert,
            &self.roots,
            intermediates,
            now,
            rustls::crypto::ring::default_provider()
                .signature_verification_algorithms
                .all,
        )?;

        Ok(ServerCertVerified::assertion())
    }

    /// Verifies a TLS 1.2 handshake signature with the `ring` provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        rustls::crypto::verify_tls12_signature(
            message,
            cert,
            dss,
            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
        )
    }

    /// Verifies a TLS 1.3 handshake signature with the `ring` provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        rustls::crypto::verify_tls13_signature(
            message,
            cert,
            dss,
            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
        )
    }

    /// Advertises the signature schemes supported by the `ring` crypto provider.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        rustls::crypto::ring::default_provider()
            .signature_verification_algorithms
            .supported_schemes()
    }
}
590
591fn load_certs(path: &str) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
592 let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
593 let mut reader = BufReader::new(file);
594 let certs = certs(&mut reader)
595 .collect::<std::result::Result<Vec<_>, _>>()
596 .map_err(|e| {
597 PostgresTlsConfigSnafu {
598 reason: format!("Failed to parse certificates from {}: {}", path, e),
599 }
600 .build()
601 })?;
602 Ok(certs)
603}
604
605fn load_private_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
606 let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
607 let mut reader = BufReader::new(file);
608 let key = private_key(&mut reader)
609 .map_err(|e| {
610 PostgresTlsConfigSnafu {
611 reason: format!("Failed to parse private key from {}: {}", path, e),
612 }
613 .build()
614 })?
615 .ok_or_else(|| {
616 PostgresTlsConfigSnafu {
617 reason: format!("No private key found in {}", path),
618 }
619 .build()
620 })?;
621 Ok(key)
622}
623
624fn load_ca(path: &str) -> Result<Arc<rustls::RootCertStore>> {
625 let mut root_store = rustls::RootCertStore::empty();
626
627 match rustls_native_certs::load_native_certs() {
629 Ok(certs) => {
630 let num_certs = certs.len();
631 for cert in certs {
632 if let Err(e) = root_store.add(cert) {
633 return PostgresTlsConfigSnafu {
634 reason: format!("Failed to add root certificate: {}", e),
635 }
636 .fail();
637 }
638 }
639 common_telemetry::info!("Loaded {num_certs} system root certificates successfully");
640 }
641 Err(e) => {
642 return PostgresTlsConfigSnafu {
643 reason: format!("Failed to load system root certificates: {}", e),
644 }
645 .fail();
646 }
647 }
648
649 if !path.is_empty() {
651 let ca_certs = load_certs(path)?;
652 for cert in ca_certs {
653 if let Err(e) = root_store.add(cert) {
654 return PostgresTlsConfigSnafu {
655 reason: format!("Failed to add custom CA certificate: {}", e),
656 }
657 .fail();
658 }
659 }
660 common_telemetry::info!("Added custom CA certificate from {}", path);
661 }
662
663 Ok(Arc::new(root_store))
664}
665
666#[async_trait::async_trait]
667impl KvQueryExecutor<PgClient> for PgStore {
668 async fn range_with_query_executor(
669 &self,
670 query_executor: &mut ExecutorImpl<'_, PgClient>,
671 req: RangeRequest,
672 ) -> Result<RangeResponse> {
673 let template_type = range_template(&req.key, &req.range_end);
674 let template = self.sql_template_set.range_template.get(template_type);
675 let params = template_type.build_params(req.key, req.range_end);
676 let params_ref = params.iter().collect::<Vec<_>>();
677 let query =
679 RangeTemplate::with_limit(template, if req.limit == 0 { 0 } else { req.limit + 1 });
680 let limit = req.limit as usize;
681 debug!("query: {:?}, params: {:?}", query, params);
682 let mut kvs = crate::record_rds_sql_execute_elapsed!(
683 query_executor.query(&query, ¶ms_ref).await,
684 PG_STORE_NAME,
685 RDS_STORE_OP_RANGE_QUERY,
686 template_type.as_ref()
687 )?;
688
689 if req.keys_only {
690 kvs.iter_mut().for_each(|kv| kv.value = vec![]);
691 }
692 if limit == 0 || kvs.len() <= limit {
694 return Ok(RangeResponse { kvs, more: false });
695 }
696 let removed = kvs.pop();
698 debug_assert!(removed.is_some());
699 Ok(RangeResponse { kvs, more: true })
700 }
701
702 async fn batch_put_with_query_executor(
703 &self,
704 query_executor: &mut ExecutorImpl<'_, PgClient>,
705 req: BatchPutRequest,
706 ) -> Result<BatchPutResponse> {
707 let mut in_params = Vec::with_capacity(req.kvs.len() * 3);
708 let mut values_params = Vec::with_capacity(req.kvs.len() * 2);
709
710 for kv in &req.kvs {
711 let processed_key = &kv.key;
712 in_params.push(processed_key);
713
714 let processed_value = &kv.value;
715 values_params.push(processed_key);
716 values_params.push(processed_value);
717 }
718 in_params.extend(values_params);
719 let params = in_params.iter().map(|x| x as _).collect::<Vec<_>>();
720 let query = self
721 .sql_template_set
722 .generate_batch_upsert_query(req.kvs.len());
723
724 let kvs = crate::record_rds_sql_execute_elapsed!(
725 query_executor.query(&query, ¶ms).await,
726 PG_STORE_NAME,
727 RDS_STORE_OP_BATCH_PUT,
728 ""
729 )?;
730 if req.prev_kv {
731 Ok(BatchPutResponse { prev_kvs: kvs })
732 } else {
733 Ok(BatchPutResponse::default())
734 }
735 }
736
737 async fn batch_get_with_query_executor(
739 &self,
740 query_executor: &mut ExecutorImpl<'_, PgClient>,
741 req: BatchGetRequest,
742 ) -> Result<BatchGetResponse> {
743 if req.keys.is_empty() {
744 return Ok(BatchGetResponse { kvs: vec![] });
745 }
746 let query = self
747 .sql_template_set
748 .generate_batch_get_query(req.keys.len());
749 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
750 let kvs = crate::record_rds_sql_execute_elapsed!(
751 query_executor.query(&query, ¶ms).await,
752 PG_STORE_NAME,
753 RDS_STORE_OP_BATCH_GET,
754 ""
755 )?;
756 Ok(BatchGetResponse { kvs })
757 }
758
759 async fn delete_range_with_query_executor(
760 &self,
761 query_executor: &mut ExecutorImpl<'_, PgClient>,
762 req: DeleteRangeRequest,
763 ) -> Result<DeleteRangeResponse> {
764 let template_type = range_template(&req.key, &req.range_end);
765 let template = self.sql_template_set.delete_template.get(template_type);
766 let params = template_type.build_params(req.key, req.range_end);
767 let params_ref = params.iter().map(|x| x as _).collect::<Vec<_>>();
768 let kvs = crate::record_rds_sql_execute_elapsed!(
769 query_executor.query(template, ¶ms_ref).await,
770 PG_STORE_NAME,
771 RDS_STORE_OP_RANGE_DELETE,
772 template_type.as_ref()
773 )?;
774 let mut resp = DeleteRangeResponse::new(kvs.len() as i64);
775 if req.prev_kv {
776 resp.with_prev_kvs(kvs);
777 }
778 Ok(resp)
779 }
780
781 async fn batch_delete_with_query_executor(
782 &self,
783 query_executor: &mut ExecutorImpl<'_, PgClient>,
784 req: BatchDeleteRequest,
785 ) -> Result<BatchDeleteResponse> {
786 if req.keys.is_empty() {
787 return Ok(BatchDeleteResponse::default());
788 }
789 let query = self
790 .sql_template_set
791 .generate_batch_delete_query(req.keys.len());
792 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
793
794 let kvs = crate::record_rds_sql_execute_elapsed!(
795 query_executor.query(&query, ¶ms).await,
796 PG_STORE_NAME,
797 RDS_STORE_OP_BATCH_DELETE,
798 ""
799 )?;
800 if req.prev_kv {
801 Ok(BatchDeleteResponse { prev_kvs: kvs })
802 } else {
803 Ok(BatchDeleteResponse::default())
804 }
805 }
806}
807
808impl PgStore {
809 pub async fn with_url_and_tls(
818 url: &str,
819 table_name: &str,
820 max_txn_ops: usize,
821 tls_config: Option<TlsOption>,
822 ) -> Result<KvBackendRef> {
823 let mut cfg = Config::new();
824 cfg.url = Some(url.to_string());
825
826 let pool = match tls_config {
827 Some(tls_config) if tls_config.mode != TlsMode::Disable => {
828 match create_postgres_tls_connector(&tls_config) {
829 Ok(tls_connector) => cfg
830 .create_pool(Some(Runtime::Tokio1), tls_connector)
831 .context(CreatePostgresPoolSnafu)?,
832 Err(e) => {
833 if tls_config.mode == TlsMode::Prefer {
834 common_telemetry::info!(
836 "Failed to create TLS connector, falling back to insecure connection"
837 );
838 cfg.create_pool(Some(Runtime::Tokio1), NoTls)
839 .context(CreatePostgresPoolSnafu)?
840 } else {
841 return Err(e);
842 }
843 }
844 }
845 }
846 _ => cfg
847 .create_pool(Some(Runtime::Tokio1), NoTls)
848 .context(CreatePostgresPoolSnafu)?,
849 };
850
851 Self::with_pg_pool(pool, None, table_name, max_txn_ops, false).await
852 }
853
854 pub async fn with_url(url: &str, table_name: &str, max_txn_ops: usize) -> Result<KvBackendRef> {
856 Self::with_url_and_tls(url, table_name, max_txn_ops, None).await
857 }
858
859 pub async fn with_pg_pool(
861 pool: Pool,
862 schema_name: Option<&str>,
863 table_name: &str,
864 max_txn_ops: usize,
865 auto_create_schema: bool,
866 ) -> Result<KvBackendRef> {
867 let client = match pool.get().await {
869 Ok(client) => client,
870 Err(e) => {
871 common_telemetry::error!(e; "Failed to get Postgres connection.");
873 return GetPostgresConnectionSnafu {
874 reason: e.to_string(),
875 }
876 .fail();
877 }
878 };
879
880 if auto_create_schema
882 && let Some(schema) = schema_name
883 && !schema.is_empty()
884 {
885 let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS \"{}\"", schema);
886 client
887 .execute(&create_schema_sql, &[])
888 .await
889 .with_context(|_| PostgresExecutionSnafu {
890 sql: create_schema_sql.clone(),
891 })?;
892 }
893
894 let template_factory = PgSqlTemplateFactory::new(schema_name, table_name);
895 let sql_template_set = template_factory.build();
896 client
897 .execute(&sql_template_set.create_table_statement, &[])
898 .await
899 .with_context(|_| PostgresExecutionSnafu {
900 sql: sql_template_set.create_table_statement.clone(),
901 })?;
902 Ok(Arc::new(Self {
903 max_txn_ops,
904 sql_template_set,
905 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
906 executor_factory: PgExecutorFactory { pool },
907 _phantom: PhantomData,
908 }))
909 }
910}
911
912#[cfg(test)]
913mod tests {
914 use super::*;
915 use crate::kv_backend::test::{
916 prepare_kv_with_prefix, test_kv_batch_delete_with_prefix, test_kv_batch_get_with_prefix,
917 test_kv_compare_and_put_with_prefix, test_kv_delete_range_with_prefix,
918 test_kv_put_with_prefix, test_kv_range_2_with_prefix, test_kv_range_with_prefix,
919 test_simple_kv_range, test_txn_compare_equal, test_txn_compare_greater,
920 test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
921 text_txn_multi_compare_op, unprepare_kv,
922 };
923 use crate::test_util::test_certs_dir;
924 use crate::{maybe_skip_postgres_integration_test, maybe_skip_postgres15_integration_test};
925
926 async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
927 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
928 if endpoints.is_empty() {
929 return None;
930 }
931
932 let mut cfg = Config::new();
933 cfg.url = Some(endpoints);
934 let pool = cfg
935 .create_pool(Some(Runtime::Tokio1), NoTls)
936 .context(CreatePostgresPoolSnafu)
937 .unwrap();
938 let client = pool.get().await.unwrap();
939 let template_factory = PgSqlTemplateFactory::new(None, table_name);
941 let sql_templates = template_factory.build();
942 client
944 .execute(&sql_templates.create_table_statement, &[])
945 .await
946 .with_context(|_| PostgresExecutionSnafu {
947 sql: sql_templates.create_table_statement.clone(),
948 })
949 .unwrap();
950 Some(PgStore {
951 max_txn_ops: 128,
952 sql_template_set: sql_templates,
953 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
954 executor_factory: PgExecutorFactory { pool },
955 _phantom: PhantomData,
956 })
957 }
958
959 async fn build_pg15_pool() -> Option<Pool> {
960 let url = std::env::var("GT_POSTGRES15_ENDPOINTS").unwrap_or_default();
961 if url.is_empty() {
962 return None;
963 }
964 let mut cfg = Config::new();
965 cfg.url = Some(url);
966 let pool = cfg
967 .create_pool(Some(Runtime::Tokio1), NoTls)
968 .context(CreatePostgresPoolSnafu)
969 .ok()?;
970 Some(pool)
971 }
972
973 #[tokio::test]
974 async fn test_pg15_create_table_in_public_should_fail() {
975 maybe_skip_postgres15_integration_test!();
976 let Some(pool) = build_pg15_pool().await else {
977 return;
978 };
979 let res = PgStore::with_pg_pool(pool, None, "pg15_public_should_fail", 128, false).await;
980 assert!(
981 res.is_err(),
982 "creating table in public should fail for test_user"
983 );
984 }
985
986 #[tokio::test]
987 async fn test_pg15_create_table_in_test_schema_and_crud_should_succeed() {
988 maybe_skip_postgres15_integration_test!();
989 let Some(pool) = build_pg15_pool().await else {
990 return;
991 };
992 let schema_name = std::env::var("GT_POSTGRES15_SCHEMA").unwrap();
993 let client = pool.get().await.unwrap();
994 let factory = PgSqlTemplateFactory::new(Some(&schema_name), "pg15_ok");
995 let templates = factory.build();
996 client
997 .execute(&templates.create_table_statement, &[])
998 .await
999 .unwrap();
1000 let kv = PgStore {
1001 max_txn_ops: 128,
1002 sql_template_set: templates,
1003 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
1004 executor_factory: PgExecutorFactory { pool },
1005 _phantom: PhantomData,
1006 };
1007 let prefix = b"pg15_crud/";
1008 prepare_kv_with_prefix(&kv, prefix.to_vec()).await;
1009 test_kv_put_with_prefix(&kv, prefix.to_vec()).await;
1010 test_kv_batch_get_with_prefix(&kv, prefix.to_vec()).await;
1011 unprepare_kv(&kv, prefix).await;
1012 }
1013
1014 #[tokio::test]
1015 async fn test_pg_with_tls() {
1016 common_telemetry::init_default_ut_logging();
1017 maybe_skip_postgres_integration_test!();
1018 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1019 let tls_connector = create_postgres_tls_connector(&TlsOption {
1020 mode: TlsMode::Require,
1021 cert_path: String::new(),
1022 key_path: String::new(),
1023 ca_cert_path: String::new(),
1024 watch: false,
1025 })
1026 .unwrap();
1027 let mut cfg = Config::new();
1028 cfg.url = Some(endpoints);
1029 let pool = cfg
1030 .create_pool(Some(Runtime::Tokio1), tls_connector)
1031 .unwrap();
1032 let client = pool.get().await.unwrap();
1033 client.execute("SELECT 1", &[]).await.unwrap();
1034 }
1035
1036 #[tokio::test]
1037 async fn test_pg_with_mtls() {
1038 common_telemetry::init_default_ut_logging();
1039 maybe_skip_postgres_integration_test!();
1040 let certs_dir = test_certs_dir();
1041 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1042 let tls_connector = create_postgres_tls_connector(&TlsOption {
1043 mode: TlsMode::Require,
1044 cert_path: certs_dir.join("client.crt").display().to_string(),
1045 key_path: certs_dir.join("client.key").display().to_string(),
1046 ca_cert_path: String::new(),
1047 watch: false,
1048 })
1049 .unwrap();
1050 let mut cfg = Config::new();
1051 cfg.url = Some(endpoints);
1052 let pool = cfg
1053 .create_pool(Some(Runtime::Tokio1), tls_connector)
1054 .unwrap();
1055 let client = pool.get().await.unwrap();
1056 client.execute("SELECT 1", &[]).await.unwrap();
1057 }
1058
1059 #[tokio::test]
1060 async fn test_pg_verify_ca() {
1061 common_telemetry::init_default_ut_logging();
1062 maybe_skip_postgres_integration_test!();
1063 let certs_dir = test_certs_dir();
1064 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1065 let tls_connector = create_postgres_tls_connector(&TlsOption {
1066 mode: TlsMode::VerifyCa,
1067 cert_path: certs_dir.join("client.crt").display().to_string(),
1068 key_path: certs_dir.join("client.key").display().to_string(),
1069 ca_cert_path: certs_dir.join("root.crt").display().to_string(),
1070 watch: false,
1071 })
1072 .unwrap();
1073 let mut cfg = Config::new();
1074 cfg.url = Some(endpoints);
1075 let pool = cfg
1076 .create_pool(Some(Runtime::Tokio1), tls_connector)
1077 .unwrap();
1078 let client = pool.get().await.unwrap();
1079 client.execute("SELECT 1", &[]).await.unwrap();
1080 }
1081
1082 #[tokio::test]
1083 async fn test_pg_verify_full() {
1084 common_telemetry::init_default_ut_logging();
1085 maybe_skip_postgres_integration_test!();
1086 let certs_dir = test_certs_dir();
1087 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1088 let tls_connector = create_postgres_tls_connector(&TlsOption {
1089 mode: TlsMode::VerifyFull,
1090 cert_path: certs_dir.join("client.crt").display().to_string(),
1091 key_path: certs_dir.join("client.key").display().to_string(),
1092 ca_cert_path: certs_dir.join("root.crt").display().to_string(),
1093 watch: false,
1094 })
1095 .unwrap();
1096 let mut cfg = Config::new();
1097 cfg.url = Some(endpoints);
1098 let pool = cfg
1099 .create_pool(Some(Runtime::Tokio1), tls_connector)
1100 .unwrap();
1101 let client = pool.get().await.unwrap();
1102 client.execute("SELECT 1", &[]).await.unwrap();
1103 }
1104
1105 #[tokio::test]
1106 async fn test_pg_put() {
1107 maybe_skip_postgres_integration_test!();
1108 let kv_backend = build_pg_kv_backend("put-test").await.unwrap();
1109 let prefix = b"put/";
1110 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1111 test_kv_put_with_prefix(&kv_backend, prefix.to_vec()).await;
1112 unprepare_kv(&kv_backend, prefix).await;
1113 }
1114
1115 #[tokio::test]
1116 async fn test_pg_range() {
1117 maybe_skip_postgres_integration_test!();
1118 let kv_backend = build_pg_kv_backend("range-test").await.unwrap();
1119 let prefix = b"range/";
1120 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1121 test_kv_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1122 unprepare_kv(&kv_backend, prefix).await;
1123 }
1124
1125 #[tokio::test]
1126 async fn test_pg_range_2() {
1127 maybe_skip_postgres_integration_test!();
1128 let kv_backend = build_pg_kv_backend("range2-test").await.unwrap();
1129 let prefix = b"range2/";
1130 test_kv_range_2_with_prefix(&kv_backend, prefix.to_vec()).await;
1131 unprepare_kv(&kv_backend, prefix).await;
1132 }
1133
1134 #[tokio::test]
1135 async fn test_pg_all_range() {
1136 maybe_skip_postgres_integration_test!();
1137 let kv_backend = build_pg_kv_backend("simple_range-test").await.unwrap();
1138 let prefix = b"";
1139 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1140 test_simple_kv_range(&kv_backend).await;
1141 unprepare_kv(&kv_backend, prefix).await;
1142 }
1143
1144 #[tokio::test]
1145 async fn test_pg_batch_get() {
1146 maybe_skip_postgres_integration_test!();
1147 let kv_backend = build_pg_kv_backend("batch_get-test").await.unwrap();
1148 let prefix = b"batch_get/";
1149 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1150 test_kv_batch_get_with_prefix(&kv_backend, prefix.to_vec()).await;
1151 unprepare_kv(&kv_backend, prefix).await;
1152 }
1153
1154 #[tokio::test]
1155 async fn test_pg_batch_delete() {
1156 maybe_skip_postgres_integration_test!();
1157 let kv_backend = build_pg_kv_backend("batch_delete-test").await.unwrap();
1158 let prefix = b"batch_delete/";
1159 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1160 test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1161 unprepare_kv(&kv_backend, prefix).await;
1162 }
1163
1164 #[tokio::test]
1165 async fn test_pg_batch_delete_with_prefix() {
1166 maybe_skip_postgres_integration_test!();
1167 let kv_backend = build_pg_kv_backend("batch_delete_with_prefix-test")
1168 .await
1169 .unwrap();
1170 let prefix = b"batch_delete/";
1171 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1172 test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
1173 unprepare_kv(&kv_backend, prefix).await;
1174 }
1175
1176 #[tokio::test]
1177 async fn test_pg_delete_range() {
1178 maybe_skip_postgres_integration_test!();
1179 let kv_backend = build_pg_kv_backend("delete_range-test").await.unwrap();
1180 let prefix = b"delete_range/";
1181 prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1182 test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1183 unprepare_kv(&kv_backend, prefix).await;
1184 }
1185
1186 #[tokio::test]
1187 async fn test_pg_compare_and_put() {
1188 maybe_skip_postgres_integration_test!();
1189 let kv_backend = build_pg_kv_backend("compare_and_put-test").await.unwrap();
1190 let prefix = b"compare_and_put/";
1191 let kv_backend = Arc::new(kv_backend);
1192 test_kv_compare_and_put_with_prefix(kv_backend.clone(), prefix.to_vec()).await;
1193 }
1194
1195 #[tokio::test]
1196 async fn test_pg_txn() {
1197 maybe_skip_postgres_integration_test!();
1198 let kv_backend = build_pg_kv_backend("txn-test").await.unwrap();
1199 test_txn_one_compare_op(&kv_backend).await;
1200 text_txn_multi_compare_op(&kv_backend).await;
1201 test_txn_compare_equal(&kv_backend).await;
1202 test_txn_compare_greater(&kv_backend).await;
1203 test_txn_compare_less(&kv_backend).await;
1204 test_txn_compare_not_equal(&kv_backend).await;
1205 }
1206
/// Every statement produced by a template built with a schema must reference
/// the table through its quoted, schema-qualified identifier.
#[test]
fn test_pg_template_with_schema() {
    let qualified = "\"test_schema\".\"greptime_metakv\"";
    let factory = PgSqlTemplateFactory::new(Some("test_schema"), "greptime_metakv");
    let template = factory.build();

    assert!(template.create_table_statement.contains(qualified));

    let generated = [
        template.generate_batch_upsert_query(1),
        template.generate_batch_get_query(1),
        template.generate_batch_delete_query(1),
    ];
    for query in &generated {
        assert!(query.contains(qualified));
    }
}
1222
/// `format_table_ident` double-quotes each identifier part. It omits the schema
/// qualifier when the schema is `None` or an empty string.
#[test]
fn test_format_table_ident() {
    let cases = [
        (None, "\"test_table\""),
        (Some("test_schema"), "\"test_schema\".\"test_table\""),
        (Some(""), "\"test_table\""),
    ];
    for (schema, expected) in cases {
        let ident = PgSqlTemplateFactory::format_table_ident(schema, "test_table");
        assert_eq!(ident, expected);
    }
}
1234
1235 #[tokio::test]
1236 async fn test_auto_create_schema_enabled() {
1237 common_telemetry::init_default_ut_logging();
1238 maybe_skip_postgres_integration_test!();
1239 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1240 let mut cfg = Config::new();
1241 cfg.url = Some(endpoints);
1242 let pool = cfg
1243 .create_pool(Some(Runtime::Tokio1), NoTls)
1244 .context(CreatePostgresPoolSnafu)
1245 .unwrap();
1246
1247 let schema_name = "test_auto_create_enabled";
1248 let table_name = "test_table";
1249
1250 let client = pool.get().await.unwrap();
1252 let _ = client
1253 .execute(
1254 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1255 &[],
1256 )
1257 .await;
1258
1259 let _ = PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, true)
1261 .await
1262 .unwrap();
1263
1264 let row = client
1266 .query_one(
1267 "SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1",
1268 &[&schema_name],
1269 )
1270 .await
1271 .unwrap();
1272 let created_schema: String = row.get(0);
1273 assert_eq!(created_schema, schema_name);
1274
1275 let row = client
1277 .query_one(
1278 "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
1279 &[&schema_name, &table_name],
1280 )
1281 .await
1282 .unwrap();
1283 let created_table_schema: String = row.get(0);
1284 let created_table_name: String = row.get(1);
1285 assert_eq!(created_table_schema, schema_name);
1286 assert_eq!(created_table_name, table_name);
1287
1288 let _ = client
1290 .execute(
1291 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1292 &[],
1293 )
1294 .await;
1295 }
1296
1297 #[tokio::test]
1298 async fn test_auto_create_schema_disabled() {
1299 common_telemetry::init_default_ut_logging();
1300 maybe_skip_postgres_integration_test!();
1301 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1302 let mut cfg = Config::new();
1303 cfg.url = Some(endpoints);
1304 let pool = cfg
1305 .create_pool(Some(Runtime::Tokio1), NoTls)
1306 .context(CreatePostgresPoolSnafu)
1307 .unwrap();
1308
1309 let schema_name = "test_auto_create_disabled";
1310 let table_name = "test_table";
1311
1312 let client = pool.get().await.unwrap();
1314 let _ = client
1315 .execute(
1316 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1317 &[],
1318 )
1319 .await;
1320
1321 let result =
1323 PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, false).await;
1324
1325 assert!(
1327 result.is_err(),
1328 "Expected error when schema doesn't exist and auto_create_schema is disabled"
1329 );
1330 }
1331
1332 #[tokio::test]
1333 async fn test_auto_create_schema_already_exists() {
1334 common_telemetry::init_default_ut_logging();
1335 maybe_skip_postgres_integration_test!();
1336 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1337 let mut cfg = Config::new();
1338 cfg.url = Some(endpoints);
1339 let pool = cfg
1340 .create_pool(Some(Runtime::Tokio1), NoTls)
1341 .context(CreatePostgresPoolSnafu)
1342 .unwrap();
1343
1344 let schema_name = "test_auto_create_existing";
1345 let table_name = "test_table";
1346
1347 let client = pool.get().await.unwrap();
1349 let _ = client
1350 .execute(
1351 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1352 &[],
1353 )
1354 .await;
1355 client
1356 .execute(&format!("CREATE SCHEMA \"{}\"", schema_name), &[])
1357 .await
1358 .unwrap();
1359
1360 let _ = PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, true)
1362 .await
1363 .unwrap();
1364
1365 let row = client
1367 .query_one(
1368 "SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1",
1369 &[&schema_name],
1370 )
1371 .await
1372 .unwrap();
1373 let created_schema: String = row.get(0);
1374 assert_eq!(created_schema, schema_name);
1375
1376 let row = client
1378 .query_one(
1379 "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
1380 &[&schema_name, &table_name],
1381 )
1382 .await
1383 .unwrap();
1384 let created_table_schema: String = row.get(0);
1385 let created_table_name: String = row.get(1);
1386 assert_eq!(created_table_schema, schema_name);
1387 assert_eq!(created_table_name, table_name);
1388
1389 let _ = client
1391 .execute(
1392 &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
1393 &[],
1394 )
1395 .await;
1396 }
1397
1398 #[tokio::test]
1399 async fn test_auto_create_schema_no_schema_name() {
1400 common_telemetry::init_default_ut_logging();
1401 maybe_skip_postgres_integration_test!();
1402 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1403 let mut cfg = Config::new();
1404 cfg.url = Some(endpoints);
1405 let pool = cfg
1406 .create_pool(Some(Runtime::Tokio1), NoTls)
1407 .context(CreatePostgresPoolSnafu)
1408 .unwrap();
1409
1410 let table_name = "test_table_no_schema";
1411
1412 let _ = PgStore::with_pg_pool(pool.clone(), None, table_name, 128, true)
1415 .await
1416 .unwrap();
1417
1418 let client = pool.get().await.unwrap();
1420 let row = client
1421 .query_one(
1422 "SELECT table_schema, table_name FROM information_schema.tables WHERE table_name = $1",
1423 &[&table_name],
1424 )
1425 .await
1426 .unwrap();
1427 let created_table_schema: String = row.get(0);
1428 let created_table_name: String = row.get(1);
1429 assert_eq!(created_table_name, table_name);
1430 assert!(created_table_schema == "public" || !created_table_schema.is_empty());
1432
1433 let _ = client
1435 .execute(&format!("DROP TABLE IF EXISTS \"{}\"", table_name), &[])
1436 .await;
1437 }
1438
1439 #[tokio::test]
1440 async fn test_auto_create_schema_with_empty_schema_name() {
1441 common_telemetry::init_default_ut_logging();
1442 maybe_skip_postgres_integration_test!();
1443 let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
1444 let mut cfg = Config::new();
1445 cfg.url = Some(endpoints);
1446 let pool = cfg
1447 .create_pool(Some(Runtime::Tokio1), NoTls)
1448 .context(CreatePostgresPoolSnafu)
1449 .unwrap();
1450
1451 let table_name = "test_table_empty_schema";
1452
1453 let _ = PgStore::with_pg_pool(pool.clone(), Some(""), table_name, 128, true)
1456 .await
1457 .unwrap();
1458
1459 let client = pool.get().await.unwrap();
1461 let row = client
1462 .query_one(
1463 "SELECT table_schema, table_name FROM information_schema.tables WHERE table_name = $1",
1464 &[&table_name],
1465 )
1466 .await
1467 .unwrap();
1468 let created_table_schema: String = row.get(0);
1469 let created_table_name: String = row.get(1);
1470 assert_eq!(created_table_name, table_name);
1471 assert!(created_table_schema == "public" || !created_table_schema.is_empty());
1473
1474 let _ = client
1476 .execute(&format!("DROP TABLE IF EXISTS \"{}\"", table_name), &[])
1477 .await;
1478 }
1479}