Skip to main content

common_meta/kv_backend/rds/
postgres.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::BufReader;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
20use common_telemetry::debug;
21use deadpool_postgres::{Config, Pool, Runtime};
22// TLS-related imports (feature-gated)
23use rustls::ClientConfig;
24use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
25use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
26use rustls::server::ParsedCertificate;
27use rustls::{DigitallySignedStruct, Error as TlsError, SignatureScheme};
28use rustls_pemfile::{certs, private_key};
29use snafu::ResultExt;
30use strum::AsRefStr;
31use tokio_postgres::types::ToSql;
32use tokio_postgres::{IsolationLevel, NoTls, Row};
33use tokio_postgres_rustls::MakeRustlsConnect;
34
35use crate::error::{
36    CreatePostgresPoolSnafu, GetPostgresConnectionSnafu, LoadTlsCertificateSnafu,
37    PostgresExecutionSnafu, PostgresTlsConfigSnafu, PostgresTransactionSnafu, Result,
38};
39use crate::kv_backend::KvBackendRef;
40use crate::kv_backend::rds::{
41    Executor, ExecutorFactory, ExecutorImpl, KvQueryExecutor, RDS_STORE_OP_BATCH_DELETE,
42    RDS_STORE_OP_BATCH_GET, RDS_STORE_OP_BATCH_PUT, RDS_STORE_OP_RANGE_DELETE,
43    RDS_STORE_OP_RANGE_QUERY, RDS_STORE_TXN_RETRY_COUNT, RdsStore, Transaction,
44    ensure_rustls_crypto_provider_installed,
45};
46use crate::rpc::KeyValue;
47use crate::rpc::store::{
48    BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
49    BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, RangeRequest, RangeResponse,
50};
51
/// TLS mode for PostgreSQL connections.
///
/// Mirrors `TlsMode` from `servers::tls`; it is duplicated here to avoid a
/// circular dependency between crates.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum TlsMode {
    /// Never use TLS.
    Disable,
    /// Prefer TLS, falling back to plaintext when the TLS connector cannot be
    /// built. This is the default mode.
    #[default]
    Prefer,
    /// Use TLS without verifying the server certificate.
    Require,
    /// Use TLS and verify the server certificate chain, but not the host name.
    VerifyCa,
    /// Use TLS and verify both the certificate chain and the host name.
    VerifyFull,
}

/// TLS settings for PostgreSQL connections.
///
/// Mirrors `TlsOption` from `servers::tls`; duplicated here to avoid a circular
/// dependency between crates. The default is [`TlsMode::Prefer`] with every path
/// empty and `watch` disabled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TlsOption {
    /// Requested TLS mode.
    pub mode: TlsMode,
    /// Client certificate (PEM). Client authentication is only configured when
    /// both `cert_path` and `key_path` are non-empty.
    pub cert_path: String,
    /// Client private key (PEM); see `cert_path`.
    pub key_path: String,
    /// Extra CA certificate bundle (PEM); may be empty.
    pub ca_cert_path: String,
    /// Presumably: reload certificates when the files change (mirrored from
    /// `servers::tls`) — TODO confirm, it is not read in this file section.
    pub watch: bool,
}
86
87const PG_STORE_NAME: &str = "pg_store";
88
89pub struct PgClient(deadpool::managed::Object<deadpool_postgres::Manager>);
90pub struct PgTxnClient<'a>(deadpool_postgres::Transaction<'a>);
91
92/// Converts a row to a [`KeyValue`].
93fn key_value_from_row(r: Row) -> KeyValue {
94    KeyValue {
95        key: r.get(0),
96        value: r.get(1),
97    }
98}
99
100const EMPTY: &[u8] = &[0];
101
102/// Type of range template.
103#[derive(Debug, Clone, Copy, AsRefStr)]
104enum RangeTemplateType {
105    Point,
106    Range,
107    Full,
108    LeftBounded,
109    Prefix,
110}
111
112/// Builds params for the given range template type.
113impl RangeTemplateType {
114    /// Builds the parameters for the given range template type.
115    /// You can check out the conventions at [RangeRequest]
116    fn build_params(&self, mut key: Vec<u8>, range_end: Vec<u8>) -> Vec<Vec<u8>> {
117        match self {
118            RangeTemplateType::Point => vec![key],
119            RangeTemplateType::Range => vec![key, range_end],
120            RangeTemplateType::Full => vec![],
121            RangeTemplateType::LeftBounded => vec![key],
122            RangeTemplateType::Prefix => {
123                key.push(b'%');
124                vec![key]
125            }
126        }
127    }
128}
129
130/// Templates for range request.
131#[derive(Debug, Clone)]
132struct RangeTemplate {
133    point: String,
134    range: String,
135    full: String,
136    left_bounded: String,
137    prefix: String,
138}
139
140impl RangeTemplate {
141    /// Gets the template for the given type.
142    fn get(&self, typ: RangeTemplateType) -> &str {
143        match typ {
144            RangeTemplateType::Point => &self.point,
145            RangeTemplateType::Range => &self.range,
146            RangeTemplateType::Full => &self.full,
147            RangeTemplateType::LeftBounded => &self.left_bounded,
148            RangeTemplateType::Prefix => &self.prefix,
149        }
150    }
151
152    /// Adds limit to the template.
153    fn with_limit(template: &str, limit: i64) -> String {
154        if limit == 0 {
155            return format!("{};", template);
156        }
157        format!("{} LIMIT {};", template, limit)
158    }
159}
160
161fn is_prefix_range(start: &[u8], end: &[u8]) -> bool {
162    if start.len() != end.len() {
163        return false;
164    }
165    let l = start.len();
166    let same_prefix = start[0..l - 1] == end[0..l - 1];
167    if let (Some(rhs), Some(lhs)) = (start.last(), end.last()) {
168        return same_prefix && (*rhs + 1) == *lhs;
169    }
170    false
171}
172
173/// Determine the template type for range request.
174fn range_template(key: &[u8], range_end: &[u8]) -> RangeTemplateType {
175    match (key, range_end) {
176        (_, &[]) => RangeTemplateType::Point,
177        (EMPTY, EMPTY) => RangeTemplateType::Full,
178        (_, EMPTY) => RangeTemplateType::LeftBounded,
179        (start, end) => {
180            if is_prefix_range(start, end) {
181                RangeTemplateType::Prefix
182            } else {
183                RangeTemplateType::Range
184            }
185        }
186    }
187}
188
189/// Generate in placeholders for PostgreSQL.
190fn pg_generate_in_placeholders(from: usize, to: usize) -> Vec<String> {
191    (from..=to).map(|i| format!("${}", i)).collect()
192}
193
194/// Factory for building sql templates.
195struct PgSqlTemplateFactory<'a> {
196    schema_name: Option<&'a str>,
197    table_name: &'a str,
198}
199
200impl<'a> PgSqlTemplateFactory<'a> {
201    /// Creates a new factory with optional schema.
202    fn new(schema_name: Option<&'a str>, table_name: &'a str) -> Self {
203        Self {
204            schema_name,
205            table_name,
206        }
207    }
208
209    /// Builds the template set for the given table name.
210    fn build(&self) -> PgSqlTemplateSet {
211        let table_ident = Self::format_table_ident(self.schema_name, self.table_name);
212        // Some of queries don't end with `;`, because we need to add `LIMIT` clause.
213        PgSqlTemplateSet {
214            table_ident: table_ident.clone(),
215            // Do not attempt to create schema implicitly to avoid extra privileges requirement.
216            create_table_statement: format!(
217                "CREATE TABLE IF NOT EXISTS {table_ident}(k bytea PRIMARY KEY, v bytea)",
218            ),
219            range_template: RangeTemplate {
220                point: format!("SELECT k, v FROM {table_ident} WHERE k = $1"),
221                range: format!(
222                    "SELECT k, v FROM {table_ident} WHERE k >= $1 AND k < $2 ORDER BY k"
223                ),
224                full: format!("SELECT k, v FROM {table_ident} ORDER BY k"),
225                left_bounded: format!("SELECT k, v FROM {table_ident} WHERE k >= $1 ORDER BY k"),
226                prefix: format!("SELECT k, v FROM {table_ident} WHERE k LIKE $1 ORDER BY k"),
227            },
228            delete_template: RangeTemplate {
229                point: format!("DELETE FROM {table_ident} WHERE k = $1 RETURNING k,v;"),
230                range: format!("DELETE FROM {table_ident} WHERE k >= $1 AND k < $2 RETURNING k,v;"),
231                full: format!("DELETE FROM {table_ident} RETURNING k,v"),
232                left_bounded: format!("DELETE FROM {table_ident} WHERE k >= $1 RETURNING k,v;"),
233                prefix: format!("DELETE FROM {table_ident} WHERE k LIKE $1 RETURNING k,v;"),
234            },
235        }
236    }
237
238    /// Formats the table reference with schema if provided.
239    fn format_table_ident(schema_name: Option<&str>, table_name: &str) -> String {
240        match schema_name {
241            Some(s) if !s.is_empty() => format!("\"{}\".\"{}\"", s, table_name),
242            _ => format!("\"{}\"", table_name),
243        }
244    }
245}
246
247/// Templates for the given table name.
248#[derive(Debug, Clone)]
249pub struct PgSqlTemplateSet {
250    table_ident: String,
251    create_table_statement: String,
252    range_template: RangeTemplate,
253    delete_template: RangeTemplate,
254}
255
256impl PgSqlTemplateSet {
257    /// Generates the sql for batch get.
258    fn generate_batch_get_query(&self, key_len: usize) -> String {
259        let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
260        format!(
261            "SELECT k, v FROM {} WHERE k in ({});",
262            self.table_ident, in_clause
263        )
264    }
265
266    /// Generates the sql for batch delete.
267    fn generate_batch_delete_query(&self, key_len: usize) -> String {
268        let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
269        format!(
270            "DELETE FROM {} WHERE k in ({}) RETURNING k,v;",
271            self.table_ident, in_clause
272        )
273    }
274
275    /// Generates the sql for batch upsert.
276    fn generate_batch_upsert_query(&self, kv_len: usize) -> String {
277        let in_placeholders: Vec<String> = (1..=kv_len).map(|i| format!("${}", i)).collect();
278        let in_clause = in_placeholders.join(", ");
279        let mut param_index = kv_len + 1;
280        let mut values_placeholders = Vec::new();
281        for _ in 0..kv_len {
282            values_placeholders.push(format!("(${0}, ${1})", param_index, param_index + 1));
283            param_index += 2;
284        }
285        let values_clause = values_placeholders.join(", ");
286
287        format!(
288            r#"
289    WITH prev AS (
290        SELECT k,v FROM {table} WHERE k IN ({in_clause})
291    ), update AS (
292    INSERT INTO {table} (k, v) VALUES
293        {values_clause}
294    ON CONFLICT (
295        k
296    ) DO UPDATE SET
297        v = excluded.v
298    )
299
300    SELECT k, v FROM prev;
301    "#,
302            table = self.table_ident,
303            in_clause = in_clause,
304            values_clause = values_clause
305        )
306    }
307}
308
309#[async_trait::async_trait]
310impl Executor for PgClient {
311    type Transaction<'a>
312        = PgTxnClient<'a>
313    where
314        Self: 'a;
315
316    fn name() -> &'static str {
317        "Postgres"
318    }
319
320    async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
321        let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
322        let stmt = self
323            .0
324            .prepare_cached(query)
325            .await
326            .context(PostgresExecutionSnafu { sql: query })?;
327        let rows = self
328            .0
329            .query(&stmt, &params)
330            .await
331            .context(PostgresExecutionSnafu { sql: query })?;
332        Ok(rows.into_iter().map(key_value_from_row).collect())
333    }
334
335    async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>> {
336        let txn = self
337            .0
338            .build_transaction()
339            .isolation_level(IsolationLevel::Serializable)
340            .start()
341            .await
342            .context(PostgresTransactionSnafu {
343                operation: "begin".to_string(),
344            })?;
345        Ok(PgTxnClient(txn))
346    }
347}
348
349#[async_trait::async_trait]
350impl<'a> Transaction<'a> for PgTxnClient<'a> {
351    async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
352        let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
353        let stmt = self
354            .0
355            .prepare_cached(query)
356            .await
357            .context(PostgresExecutionSnafu { sql: query })?;
358        let rows = self
359            .0
360            .query(&stmt, &params)
361            .await
362            .context(PostgresExecutionSnafu { sql: query })?;
363        Ok(rows.into_iter().map(key_value_from_row).collect())
364    }
365
366    async fn commit(self) -> Result<()> {
367        self.0.commit().await.context(PostgresTransactionSnafu {
368            operation: "commit",
369        })?;
370        Ok(())
371    }
372}
373
374pub struct PgExecutorFactory {
375    pool: Pool,
376}
377
378impl PgExecutorFactory {
379    async fn client(&self) -> Result<PgClient> {
380        match self.pool.get().await {
381            Ok(client) => Ok(PgClient(client)),
382            Err(e) => GetPostgresConnectionSnafu {
383                reason: e.to_string(),
384            }
385            .fail(),
386        }
387    }
388}
389
390#[async_trait::async_trait]
391impl ExecutorFactory<PgClient> for PgExecutorFactory {
392    async fn default_executor(&self) -> Result<PgClient> {
393        self.client().await
394    }
395
396    async fn txn_executor<'a>(
397        &self,
398        default_executor: &'a mut PgClient,
399    ) -> Result<PgTxnClient<'a>> {
400        default_executor.txn_executor().await
401    }
402}
403
404/// A PostgreSQL-backed key-value store for metasrv.
405/// It uses [deadpool_postgres::Pool] as the connection pool for [RdsStore].
406pub type PgStore = RdsStore<PgClient, PgExecutorFactory, PgSqlTemplateSet>;
407
408/// Creates a PostgreSQL TLS connector based on the provided configuration.
409///
410/// This function creates a rustls-based TLS connector for PostgreSQL connections,
411/// following PostgreSQL's TLS mode specifications exactly:
412///
413/// # TLS Modes (PostgreSQL Specification)
414///
415/// - `Disable`: No TLS connection attempted
416/// - `Prefer`: Try TLS first, fallback to plaintext if TLS fails (handled by connection logic)
417/// - `Require`: Only TLS connections, but NO certificate verification (accept any cert)
418/// - `VerifyCa`: TLS + verify certificate is signed by trusted CA (no hostname verification)
419/// - `VerifyFull`: TLS + verify CA + verify hostname matches certificate SAN
420///
421pub fn create_postgres_tls_connector(tls_config: &TlsOption) -> Result<MakeRustlsConnect> {
422    common_telemetry::info!(
423        "Creating PostgreSQL TLS connector with mode: {:?}",
424        tls_config.mode
425    );
426    ensure_rustls_crypto_provider_installed()?;
427
428    let config_builder = match tls_config.mode {
429        TlsMode::Disable => {
430            return PostgresTlsConfigSnafu {
431                reason: "Cannot create TLS connector for Disable mode".to_string(),
432            }
433            .fail();
434        }
435        TlsMode::Prefer | TlsMode::Require => {
436            // For Prefer/Require: Accept any certificate (no verification)
437            let verifier = Arc::new(AcceptAnyVerifier);
438            ClientConfig::builder()
439                .dangerous()
440                .with_custom_certificate_verifier(verifier)
441        }
442        TlsMode::VerifyCa => {
443            // For VerifyCa: Verify server cert against CA store, but skip hostname verification
444            let ca_store = load_ca(&tls_config.ca_cert_path)?;
445            let verifier = Arc::new(NoHostnameVerification { roots: ca_store });
446            ClientConfig::builder()
447                .dangerous()
448                .with_custom_certificate_verifier(verifier)
449        }
450        TlsMode::VerifyFull => {
451            let ca_store = load_ca(&tls_config.ca_cert_path)?;
452            ClientConfig::builder().with_root_certificates(ca_store)
453        }
454    };
455
456    // Create the TLS client configuration based on the mode and client cert requirements
457    let client_config = if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
458        // Client certificate authentication required
459        common_telemetry::info!("Loading client certificate for mutual TLS");
460        let cert_chain = load_certs(&tls_config.cert_path)?;
461        let private_key = load_private_key(&tls_config.key_path)?;
462
463        config_builder
464            .with_client_auth_cert(cert_chain, private_key)
465            .map_err(|e| {
466                PostgresTlsConfigSnafu {
467                    reason: format!("Failed to configure client authentication: {}", e),
468                }
469                .build()
470            })?
471    } else {
472        common_telemetry::info!("No client certificate provided, skip client authentication");
473        config_builder.with_no_client_auth()
474    };
475
476    common_telemetry::info!("Successfully created PostgreSQL TLS connector");
477    Ok(MakeRustlsConnect::new(client_config))
478}
479
480/// For Prefer/Require mode, we accept any server certificate without verification.
481#[derive(Debug)]
482struct AcceptAnyVerifier;
483
484impl ServerCertVerifier for AcceptAnyVerifier {
485    fn verify_server_cert(
486        &self,
487        _end_entity: &CertificateDer<'_>,
488        _intermediates: &[CertificateDer<'_>],
489        _server_name: &ServerName<'_>,
490        _ocsp_response: &[u8],
491        _now: UnixTime,
492    ) -> std::result::Result<ServerCertVerified, TlsError> {
493        common_telemetry::debug!(
494            "Accepting server certificate without verification (Prefer/Require mode)"
495        );
496        Ok(ServerCertVerified::assertion())
497    }
498
499    fn verify_tls12_signature(
500        &self,
501        _message: &[u8],
502        _cert: &CertificateDer<'_>,
503        _dss: &DigitallySignedStruct,
504    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
505        // Accept any signature without verification
506        Ok(HandshakeSignatureValid::assertion())
507    }
508
509    fn verify_tls13_signature(
510        &self,
511        _message: &[u8],
512        _cert: &CertificateDer<'_>,
513        _dss: &DigitallySignedStruct,
514    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
515        // Accept any signature without verification
516        Ok(HandshakeSignatureValid::assertion())
517    }
518
519    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
520        // Support all signature schemes
521        rustls::crypto::aws_lc_rs::default_provider()
522            .signature_verification_algorithms
523            .supported_schemes()
524    }
525}
526
527/// For VerifyCa mode, we verify the server certificate against our CA store
528/// and skip verify server's HostName.
529#[derive(Debug)]
530struct NoHostnameVerification {
531    roots: Arc<rustls::RootCertStore>,
532}
533
534impl ServerCertVerifier for NoHostnameVerification {
535    fn verify_server_cert(
536        &self,
537        end_entity: &CertificateDer<'_>,
538        intermediates: &[CertificateDer<'_>],
539        _server_name: &ServerName<'_>,
540        _ocsp_response: &[u8],
541        now: UnixTime,
542    ) -> std::result::Result<ServerCertVerified, TlsError> {
543        let cert = ParsedCertificate::try_from(end_entity)?;
544        rustls::client::verify_server_cert_signed_by_trust_anchor(
545            &cert,
546            &self.roots,
547            intermediates,
548            now,
549            rustls::crypto::aws_lc_rs::default_provider()
550                .signature_verification_algorithms
551                .all,
552        )?;
553
554        Ok(ServerCertVerified::assertion())
555    }
556
557    fn verify_tls12_signature(
558        &self,
559        message: &[u8],
560        cert: &CertificateDer<'_>,
561        dss: &DigitallySignedStruct,
562    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
563        rustls::crypto::verify_tls12_signature(
564            message,
565            cert,
566            dss,
567            &rustls::crypto::aws_lc_rs::default_provider().signature_verification_algorithms,
568        )
569    }
570
571    fn verify_tls13_signature(
572        &self,
573        message: &[u8],
574        cert: &CertificateDer<'_>,
575        dss: &DigitallySignedStruct,
576    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
577        rustls::crypto::verify_tls13_signature(
578            message,
579            cert,
580            dss,
581            &rustls::crypto::aws_lc_rs::default_provider().signature_verification_algorithms,
582        )
583    }
584
585    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
586        // Support all signature schemes
587        rustls::crypto::aws_lc_rs::default_provider()
588            .signature_verification_algorithms
589            .supported_schemes()
590    }
591}
592
593fn load_certs(path: &str) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
594    let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
595    let mut reader = BufReader::new(file);
596    let certs = certs(&mut reader)
597        .collect::<std::result::Result<Vec<_>, _>>()
598        .map_err(|e| {
599            PostgresTlsConfigSnafu {
600                reason: format!("Failed to parse certificates from {}: {}", path, e),
601            }
602            .build()
603        })?;
604    Ok(certs)
605}
606
607fn load_private_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
608    let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
609    let mut reader = BufReader::new(file);
610    let key = private_key(&mut reader)
611        .map_err(|e| {
612            PostgresTlsConfigSnafu {
613                reason: format!("Failed to parse private key from {}: {}", path, e),
614            }
615            .build()
616        })?
617        .ok_or_else(|| {
618            PostgresTlsConfigSnafu {
619                reason: format!("No private key found in {}", path),
620            }
621            .build()
622        })?;
623    Ok(key)
624}
625
626fn load_ca(path: &str) -> Result<Arc<rustls::RootCertStore>> {
627    let mut root_store = rustls::RootCertStore::empty();
628
629    // Add system root certificates
630    match rustls_native_certs::load_native_certs() {
631        Ok(certs) => {
632            let num_certs = certs.len();
633            for cert in certs {
634                if let Err(e) = root_store.add(cert) {
635                    return PostgresTlsConfigSnafu {
636                        reason: format!("Failed to add root certificate: {}", e),
637                    }
638                    .fail();
639                }
640            }
641            common_telemetry::info!("Loaded {num_certs} system root certificates successfully");
642        }
643        Err(e) => {
644            return PostgresTlsConfigSnafu {
645                reason: format!("Failed to load system root certificates: {}", e),
646            }
647            .fail();
648        }
649    }
650
651    // Try add custom CA certificate if provided
652    if !path.is_empty() {
653        let ca_certs = load_certs(path)?;
654        for cert in ca_certs {
655            if let Err(e) = root_store.add(cert) {
656                return PostgresTlsConfigSnafu {
657                    reason: format!("Failed to add custom CA certificate: {}", e),
658                }
659                .fail();
660            }
661        }
662        common_telemetry::info!("Added custom CA certificate from {}", path);
663    }
664
665    Ok(Arc::new(root_store))
666}
667
668#[async_trait::async_trait]
669impl KvQueryExecutor<PgClient> for PgStore {
670    async fn range_with_query_executor(
671        &self,
672        query_executor: &mut ExecutorImpl<'_, PgClient>,
673        req: RangeRequest,
674    ) -> Result<RangeResponse> {
675        let template_type = range_template(&req.key, &req.range_end);
676        let template = self.sql_template_set.range_template.get(template_type);
677        let params = template_type.build_params(req.key, req.range_end);
678        let params_ref = params.iter().collect::<Vec<_>>();
679        // Always add 1 to limit to check if there is more data
680        let query =
681            RangeTemplate::with_limit(template, if req.limit == 0 { 0 } else { req.limit + 1 });
682        let limit = req.limit as usize;
683        debug!("query: {:?}, params: {:?}", query, params);
684        let mut kvs = crate::record_rds_sql_execute_elapsed!(
685            query_executor.query(&query, &params_ref).await,
686            PG_STORE_NAME,
687            RDS_STORE_OP_RANGE_QUERY,
688            template_type.as_ref()
689        )?;
690
691        if req.keys_only {
692            kvs.iter_mut().for_each(|kv| kv.value = vec![]);
693        }
694        // If limit is 0, we always return all data
695        if limit == 0 || kvs.len() <= limit {
696            return Ok(RangeResponse { kvs, more: false });
697        }
698        // If limit is greater than the number of rows, we remove the last row and set more to true
699        let removed = kvs.pop();
700        debug_assert!(removed.is_some());
701        Ok(RangeResponse { kvs, more: true })
702    }
703
704    async fn batch_put_with_query_executor(
705        &self,
706        query_executor: &mut ExecutorImpl<'_, PgClient>,
707        req: BatchPutRequest,
708    ) -> Result<BatchPutResponse> {
709        let mut in_params = Vec::with_capacity(req.kvs.len() * 3);
710        let mut values_params = Vec::with_capacity(req.kvs.len() * 2);
711
712        for kv in &req.kvs {
713            let processed_key = &kv.key;
714            in_params.push(processed_key);
715
716            let processed_value = &kv.value;
717            values_params.push(processed_key);
718            values_params.push(processed_value);
719        }
720        in_params.extend(values_params);
721        let params = in_params.iter().map(|x| x as _).collect::<Vec<_>>();
722        let query = self
723            .sql_template_set
724            .generate_batch_upsert_query(req.kvs.len());
725
726        let kvs = crate::record_rds_sql_execute_elapsed!(
727            query_executor.query(&query, &params).await,
728            PG_STORE_NAME,
729            RDS_STORE_OP_BATCH_PUT,
730            ""
731        )?;
732        if req.prev_kv {
733            Ok(BatchPutResponse { prev_kvs: kvs })
734        } else {
735            Ok(BatchPutResponse::default())
736        }
737    }
738
739    /// Batch get with certain client. It's needed for a client with transaction.
740    async fn batch_get_with_query_executor(
741        &self,
742        query_executor: &mut ExecutorImpl<'_, PgClient>,
743        req: BatchGetRequest,
744    ) -> Result<BatchGetResponse> {
745        if req.keys.is_empty() {
746            return Ok(BatchGetResponse { kvs: vec![] });
747        }
748        let query = self
749            .sql_template_set
750            .generate_batch_get_query(req.keys.len());
751        let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
752        let kvs = crate::record_rds_sql_execute_elapsed!(
753            query_executor.query(&query, &params).await,
754            PG_STORE_NAME,
755            RDS_STORE_OP_BATCH_GET,
756            ""
757        )?;
758        Ok(BatchGetResponse { kvs })
759    }
760
761    async fn delete_range_with_query_executor(
762        &self,
763        query_executor: &mut ExecutorImpl<'_, PgClient>,
764        req: DeleteRangeRequest,
765    ) -> Result<DeleteRangeResponse> {
766        let template_type = range_template(&req.key, &req.range_end);
767        let template = self.sql_template_set.delete_template.get(template_type);
768        let params = template_type.build_params(req.key, req.range_end);
769        let params_ref = params.iter().map(|x| x as _).collect::<Vec<_>>();
770        let kvs = crate::record_rds_sql_execute_elapsed!(
771            query_executor.query(template, &params_ref).await,
772            PG_STORE_NAME,
773            RDS_STORE_OP_RANGE_DELETE,
774            template_type.as_ref()
775        )?;
776        let mut resp = DeleteRangeResponse::new(kvs.len() as i64);
777        if req.prev_kv {
778            resp.with_prev_kvs(kvs);
779        }
780        Ok(resp)
781    }
782
783    async fn batch_delete_with_query_executor(
784        &self,
785        query_executor: &mut ExecutorImpl<'_, PgClient>,
786        req: BatchDeleteRequest,
787    ) -> Result<BatchDeleteResponse> {
788        if req.keys.is_empty() {
789            return Ok(BatchDeleteResponse::default());
790        }
791        let query = self
792            .sql_template_set
793            .generate_batch_delete_query(req.keys.len());
794        let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
795
796        let kvs = crate::record_rds_sql_execute_elapsed!(
797            query_executor.query(&query, &params).await,
798            PG_STORE_NAME,
799            RDS_STORE_OP_BATCH_DELETE,
800            ""
801        )?;
802        if req.prev_kv {
803            Ok(BatchDeleteResponse { prev_kvs: kvs })
804        } else {
805            Ok(BatchDeleteResponse::default())
806        }
807    }
808}
809
810impl PgStore {
811    /// Create [PgStore] impl of [KvBackendRef] from url with optional TLS support.
812    ///
813    /// # Arguments
814    ///
815    /// * `url` - PostgreSQL connection URL
816    /// * `table_name` - Name of the table to use for key-value storage
817    /// * `max_txn_ops` - Maximum number of operations per transaction
818    /// * `tls_config` - Optional TLS configuration. If None, uses plaintext connection.
819    pub async fn with_url_and_tls(
820        url: &str,
821        table_name: &str,
822        max_txn_ops: usize,
823        tls_config: Option<TlsOption>,
824    ) -> Result<KvBackendRef> {
825        let mut cfg = Config::new();
826        cfg.url = Some(url.to_string());
827
828        let pool = match tls_config {
829            Some(tls_config) if tls_config.mode != TlsMode::Disable => {
830                match create_postgres_tls_connector(&tls_config) {
831                    Ok(tls_connector) => cfg
832                        .create_pool(Some(Runtime::Tokio1), tls_connector)
833                        .context(CreatePostgresPoolSnafu)?,
834                    Err(e) => {
835                        if tls_config.mode == TlsMode::Prefer {
836                            // Fallback to insecure connection if TLS fails
837                            common_telemetry::info!(
838                                "Failed to create TLS connector, falling back to insecure connection"
839                            );
840                            cfg.create_pool(Some(Runtime::Tokio1), NoTls)
841                                .context(CreatePostgresPoolSnafu)?
842                        } else {
843                            return Err(e);
844                        }
845                    }
846                }
847            }
848            _ => cfg
849                .create_pool(Some(Runtime::Tokio1), NoTls)
850                .context(CreatePostgresPoolSnafu)?,
851        };
852
853        Self::with_pg_pool(pool, None, table_name, max_txn_ops, false).await
854    }
855
856    /// Create [PgStore] impl of [KvBackendRef] from url (backward compatibility).
857    pub async fn with_url(url: &str, table_name: &str, max_txn_ops: usize) -> Result<KvBackendRef> {
858        Self::with_url_and_tls(url, table_name, max_txn_ops, None).await
859    }
860
861    /// Create [PgStore] impl of [KvBackendRef] from [deadpool_postgres::Pool] with optional schema.
862    pub async fn with_pg_pool(
863        pool: Pool,
864        schema_name: Option<&str>,
865        table_name: &str,
866        max_txn_ops: usize,
867        auto_create_schema: bool,
868    ) -> Result<KvBackendRef> {
869        // Ensure the postgres metadata backend is ready to use.
870        let client = match pool.get().await {
871            Ok(client) => client,
872            Err(e) => {
873                // We need to log the debug for the error to help diagnose the issue.
874                common_telemetry::error!(e; "Failed to get Postgres connection.");
875                return GetPostgresConnectionSnafu {
876                    reason: e.to_string(),
877                }
878                .fail();
879            }
880        };
881
882        // Automatically create schema if enabled and schema_name is provided.
883        if auto_create_schema
884            && let Some(schema) = schema_name
885            && !schema.is_empty()
886        {
887            let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS \"{}\"", schema);
888            client
889                .execute(&create_schema_sql, &[])
890                .await
891                .with_context(|_| PostgresExecutionSnafu {
892                    sql: create_schema_sql.clone(),
893                })?;
894        }
895
896        let template_factory = PgSqlTemplateFactory::new(schema_name, table_name);
897        let sql_template_set = template_factory.build();
898        client
899            .execute(&sql_template_set.create_table_statement, &[])
900            .await
901            .with_context(|_| PostgresExecutionSnafu {
902                sql: sql_template_set.create_table_statement.clone(),
903            })?;
904        Ok(Arc::new(Self {
905            max_txn_ops,
906            sql_template_set,
907            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
908            executor_factory: PgExecutorFactory { pool },
909            _phantom: PhantomData,
910        }))
911    }
912}
913
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kv_backend::test::{
        prepare_kv_with_prefix, test_kv_batch_delete_with_prefix, test_kv_batch_get_with_prefix,
        test_kv_compare_and_put_with_prefix, test_kv_delete_range_with_prefix,
        test_kv_put_with_prefix, test_kv_range_2_with_prefix, test_kv_range_with_prefix,
        test_simple_kv_range, test_txn_compare_equal, test_txn_compare_greater,
        test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
        text_txn_multi_compare_op, unprepare_kv,
    };
    use crate::test_util::test_certs_dir;
    use crate::{maybe_skip_postgres_integration_test, maybe_skip_postgres15_integration_test};

    /// Builds a [PgStore] whose table lives in the default schema of `GT_POSTGRES_ENDPOINTS`.
    ///
    /// Returns `None` when that variable is unset or empty.
    async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
        let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
        if endpoints.is_empty() {
            return None;
        }

        let mut cfg = Config::new();
        cfg.url = Some(endpoints);
        let pool = cfg
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .context(CreatePostgresPoolSnafu)
            .unwrap();
        let client = pool.get().await.unwrap();
        // use the default schema (i.e., public)
        let template_factory = PgSqlTemplateFactory::new(None, table_name);
        let sql_templates = template_factory.build();
        // Do not attempt to create schema implicitly.
        client
            .execute(&sql_templates.create_table_statement, &[])
            .await
            .with_context(|_| PostgresExecutionSnafu {
                sql: sql_templates.create_table_statement.clone(),
            })
            .unwrap();
        Some(PgStore {
            max_txn_ops: 128,
            sql_template_set: sql_templates,
            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
            executor_factory: PgExecutorFactory { pool },
            _phantom: PhantomData,
        })
    }

    /// Builds a plaintext pool for `GT_POSTGRES15_ENDPOINTS`.
    ///
    /// Returns `None` when the variable is unset or empty, or when the pool cannot be created.
    async fn build_pg15_pool() -> Option<Pool> {
        let url = std::env::var("GT_POSTGRES15_ENDPOINTS").unwrap_or_default();
        if url.is_empty() {
            return None;
        }
        let mut cfg = Config::new();
        cfg.url = Some(url);
        let pool = cfg
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .context(CreatePostgresPoolSnafu)
            .ok()?;
        Some(pool)
    }

    /// Builds a plaintext pool for `GT_POSTGRES_ENDPOINTS`.
    ///
    /// Panics if the variable is unset; every caller invokes
    /// `maybe_skip_postgres_integration_test!()` first.
    fn pg_pool_from_env() -> Pool {
        let mut cfg = Config::new();
        cfg.url = Some(std::env::var("GT_POSTGRES_ENDPOINTS").unwrap());
        cfg.create_pool(Some(Runtime::Tokio1), NoTls)
            .context(CreatePostgresPoolSnafu)
            .unwrap()
    }

    /// Connects to `GT_POSTGRES_ENDPOINTS` through a TLS connector built from `tls_option`
    /// and runs `SELECT 1`, panicking on any failure.
    async fn assert_pg_tls_roundtrip(tls_option: TlsOption) {
        let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
        let tls_connector = create_postgres_tls_connector(&tls_option).unwrap();
        let mut cfg = Config::new();
        cfg.url = Some(endpoints);
        let pool = cfg
            .create_pool(Some(Runtime::Tokio1), tls_connector)
            .unwrap();
        let client = pool.get().await.unwrap();
        client.execute("SELECT 1", &[]).await.unwrap();
    }

    /// Best-effort `DROP SCHEMA IF EXISTS ... CASCADE`.
    ///
    /// Errors are ignored on purpose: this only resets or cleans up test state.
    async fn drop_schema_if_exists(client: &tokio_postgres::Client, schema_name: &str) {
        let _ = client
            .execute(
                &format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
                &[],
            )
            .await;
    }

    /// Returns `current_schema()`: the schema where unqualified `CREATE TABLE` statements
    /// issued by this connection's user end up.
    async fn query_current_schema(client: &tokio_postgres::Client) -> String {
        client
            .query_one("SELECT current_schema()", &[])
            .await
            .unwrap()
            .get(0)
    }

    /// Creating the table in `public` must be rejected for the PG15 test user.
    #[tokio::test]
    async fn test_pg15_create_table_in_public_should_fail() {
        maybe_skip_postgres15_integration_test!();
        let Some(pool) = build_pg15_pool().await else {
            return;
        };
        let res = PgStore::with_pg_pool(pool, None, "pg15_public_should_fail", 128, false).await;
        assert!(
            res.is_err(),
            "creating table in public should fail for test_user"
        );
    }

    /// With a schema the PG15 test user may write to, table creation and basic CRUD succeed.
    #[tokio::test]
    async fn test_pg15_create_table_in_test_schema_and_crud_should_succeed() {
        maybe_skip_postgres15_integration_test!();
        let Some(pool) = build_pg15_pool().await else {
            return;
        };
        let schema_name = std::env::var("GT_POSTGRES15_SCHEMA").unwrap();
        let client = pool.get().await.unwrap();
        let factory = PgSqlTemplateFactory::new(Some(&schema_name), "pg15_ok");
        let templates = factory.build();
        client
            .execute(&templates.create_table_statement, &[])
            .await
            .unwrap();
        let kv = PgStore {
            max_txn_ops: 128,
            sql_template_set: templates,
            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
            executor_factory: PgExecutorFactory { pool },
            _phantom: PhantomData,
        };
        let prefix = b"pg15_crud/";
        prepare_kv_with_prefix(&kv, prefix.to_vec()).await;
        test_kv_put_with_prefix(&kv, prefix.to_vec()).await;
        test_kv_batch_get_with_prefix(&kv, prefix.to_vec()).await;
        unprepare_kv(&kv, prefix).await;
    }

    /// Server-authenticated TLS without client certificates or a custom CA.
    #[tokio::test]
    async fn test_pg_with_tls() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        assert_pg_tls_roundtrip(TlsOption {
            mode: TlsMode::Require,
            cert_path: String::new(),
            key_path: String::new(),
            ca_cert_path: String::new(),
            watch: false,
        })
        .await;
    }

    /// Mutual TLS: the client presents `client.crt` / `client.key`.
    #[tokio::test]
    async fn test_pg_with_mtls() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        let certs_dir = test_certs_dir();
        assert_pg_tls_roundtrip(TlsOption {
            mode: TlsMode::Require,
            cert_path: certs_dir.join("client.crt").display().to_string(),
            key_path: certs_dir.join("client.key").display().to_string(),
            ca_cert_path: String::new(),
            watch: false,
        })
        .await;
    }

    /// `VerifyCa`: the server certificate must chain to `root.crt`.
    #[tokio::test]
    async fn test_pg_verify_ca() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        let certs_dir = test_certs_dir();
        assert_pg_tls_roundtrip(TlsOption {
            mode: TlsMode::VerifyCa,
            cert_path: certs_dir.join("client.crt").display().to_string(),
            key_path: certs_dir.join("client.key").display().to_string(),
            ca_cert_path: certs_dir.join("root.crt").display().to_string(),
            watch: false,
        })
        .await;
    }

    /// `VerifyFull`: CA verification plus hostname verification.
    #[tokio::test]
    async fn test_pg_verify_full() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        let certs_dir = test_certs_dir();
        assert_pg_tls_roundtrip(TlsOption {
            mode: TlsMode::VerifyFull,
            cert_path: certs_dir.join("client.crt").display().to_string(),
            key_path: certs_dir.join("client.key").display().to_string(),
            ca_cert_path: certs_dir.join("root.crt").display().to_string(),
            watch: false,
        })
        .await;
    }

    #[tokio::test]
    async fn test_pg_put() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("put-test").await.unwrap();
        let prefix = b"put/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_put_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("range-test").await.unwrap();
        let prefix = b"range/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_range_2() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("range2-test").await.unwrap();
        let prefix = b"range2/";
        test_kv_range_2_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_all_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("simple_range-test").await.unwrap();
        let prefix = b"";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_simple_kv_range(&kv_backend).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_batch_get() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_get-test").await.unwrap();
        let prefix = b"batch_get/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_batch_get_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    // NOTE(review): this runs the same shared helper as `test_pg_delete_range`; batch delete
    // itself is covered by `test_pg_batch_delete_with_prefix`. Confirm this is intended.
    #[tokio::test]
    async fn test_pg_batch_delete() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_delete-test").await.unwrap();
        let prefix = b"batch_delete/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_batch_delete_with_prefix() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_delete_with_prefix-test")
            .await
            .unwrap();
        let prefix = b"batch_delete/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_delete_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("delete_range-test").await.unwrap();
        let prefix = b"delete_range/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_compare_and_put() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("compare_and_put-test").await.unwrap();
        let prefix = b"compare_and_put/";
        let kv_backend = Arc::new(kv_backend);
        test_kv_compare_and_put_with_prefix(kv_backend.clone(), prefix.to_vec()).await;
    }

    #[tokio::test]
    async fn test_pg_txn() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("txn-test").await.unwrap();
        test_txn_one_compare_op(&kv_backend).await;
        text_txn_multi_compare_op(&kv_backend).await;
        test_txn_compare_equal(&kv_backend).await;
        test_txn_compare_greater(&kv_backend).await;
        test_txn_compare_less(&kv_backend).await;
        test_txn_compare_not_equal(&kv_backend).await;
    }

    /// Every generated statement must reference the schema-qualified, quoted table name.
    #[test]
    fn test_pg_template_with_schema() {
        let factory = PgSqlTemplateFactory::new(Some("test_schema"), "greptime_metakv");
        let t = factory.build();
        assert!(
            t.create_table_statement
                .contains("\"test_schema\".\"greptime_metakv\"")
        );
        let upsert = t.generate_batch_upsert_query(1);
        assert!(upsert.contains("\"test_schema\".\"greptime_metakv\""));
        let get = t.generate_batch_get_query(1);
        assert!(get.contains("\"test_schema\".\"greptime_metakv\""));
        let del = t.generate_batch_delete_query(1);
        assert!(del.contains("\"test_schema\".\"greptime_metakv\""));
    }

    /// A missing or empty schema yields a bare quoted table name.
    #[test]
    fn test_format_table_ident() {
        let t = PgSqlTemplateFactory::format_table_ident(None, "test_table");
        assert_eq!(t, "\"test_table\"");

        let t = PgSqlTemplateFactory::format_table_ident(Some("test_schema"), "test_table");
        assert_eq!(t, "\"test_schema\".\"test_table\"");

        let t = PgSqlTemplateFactory::format_table_ident(Some(""), "test_table");
        assert_eq!(t, "\"test_table\"");
    }

    /// With `auto_create_schema` enabled, a missing schema and the table inside it are created.
    #[tokio::test]
    async fn test_auto_create_schema_enabled() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        let pool = pg_pool_from_env();

        let schema_name = "test_auto_create_enabled";
        let table_name = "test_table";

        // Drop the schema if it exists to start clean
        let client = pool.get().await.unwrap();
        drop_schema_if_exists(&client, schema_name).await;

        // Create store with auto_create_schema enabled
        let _ = PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, true)
            .await
            .unwrap();

        // Verify schema was created
        let row = client
            .query_one(
                "SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1",
                &[&schema_name],
            )
            .await
            .unwrap();
        let created_schema: String = row.get(0);
        assert_eq!(created_schema, schema_name);

        // Verify table was created in the schema
        let row = client
            .query_one(
                "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
                &[&schema_name, &table_name],
            )
            .await
            .unwrap();
        let created_table_schema: String = row.get(0);
        let created_table_name: String = row.get(1);
        assert_eq!(created_table_schema, schema_name);
        assert_eq!(created_table_name, table_name);

        // Cleanup
        drop_schema_if_exists(&client, schema_name).await;
    }

    /// With `auto_create_schema` disabled, a missing schema makes store creation fail.
    #[tokio::test]
    async fn test_auto_create_schema_disabled() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        let pool = pg_pool_from_env();

        let schema_name = "test_auto_create_disabled";
        let table_name = "test_table";

        // Drop the schema if it exists to start clean
        let client = pool.get().await.unwrap();
        drop_schema_if_exists(&client, schema_name).await;

        // Try to create store with auto_create_schema disabled (should fail)
        let result =
            PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, false).await;

        // Verify it failed because schema doesn't exist
        assert!(
            result.is_err(),
            "Expected error when schema doesn't exist and auto_create_schema is disabled"
        );
    }

    /// `auto_create_schema` is idempotent when the schema already exists.
    #[tokio::test]
    async fn test_auto_create_schema_already_exists() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        let pool = pg_pool_from_env();

        let schema_name = "test_auto_create_existing";
        let table_name = "test_table";

        // Manually create the schema first
        let client = pool.get().await.unwrap();
        drop_schema_if_exists(&client, schema_name).await;
        client
            .execute(&format!("CREATE SCHEMA \"{}\"", schema_name), &[])
            .await
            .unwrap();

        // Create store with auto_create_schema enabled (should succeed idempotently)
        let _ = PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, true)
            .await
            .unwrap();

        // Verify schema still exists
        let row = client
            .query_one(
                "SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1",
                &[&schema_name],
            )
            .await
            .unwrap();
        let created_schema: String = row.get(0);
        assert_eq!(created_schema, schema_name);

        // Verify table was created in the schema
        let row = client
            .query_one(
                "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
                &[&schema_name, &table_name],
            )
            .await
            .unwrap();
        let created_table_schema: String = row.get(0);
        let created_table_name: String = row.get(1);
        assert_eq!(created_table_schema, schema_name);
        assert_eq!(created_table_name, table_name);

        // Cleanup
        drop_schema_if_exists(&client, schema_name).await;
    }

    /// With no schema name, `auto_create_schema` is a no-op and the table lands in the
    /// connection's default schema.
    #[tokio::test]
    async fn test_auto_create_schema_no_schema_name() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        let pool = pg_pool_from_env();

        let table_name = "test_table_no_schema";

        // Create store with auto_create_schema enabled but no schema name (should succeed)
        let _ = PgStore::with_pg_pool(pool.clone(), None, table_name, 128, true)
            .await
            .unwrap();

        // Verify the table was created in the default schema. Filtering by schema keeps
        // `query_one` from failing when a same-named table exists in another schema.
        let client = pool.get().await.unwrap();
        let default_schema = query_current_schema(&client).await;
        let row = client
            .query_one(
                "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
                &[&default_schema, &table_name],
            )
            .await
            .unwrap();
        let created_table_schema: String = row.get(0);
        let created_table_name: String = row.get(1);
        assert_eq!(created_table_schema, default_schema);
        assert_eq!(created_table_name, table_name);

        // Cleanup
        let _ = client
            .execute(&format!("DROP TABLE IF EXISTS \"{}\"", table_name), &[])
            .await;
    }

    /// An empty schema name behaves like no schema name: the table lands in the default schema.
    #[tokio::test]
    async fn test_auto_create_schema_with_empty_schema_name() {
        common_telemetry::init_default_ut_logging();
        maybe_skip_postgres_integration_test!();
        let pool = pg_pool_from_env();

        let table_name = "test_table_empty_schema";

        // Create store with auto_create_schema enabled but empty schema name (should succeed)
        let _ = PgStore::with_pg_pool(pool.clone(), Some(""), table_name, 128, true)
            .await
            .unwrap();

        // Verify the table was created in the default schema. Filtering by schema keeps
        // `query_one` from failing when a same-named table exists in another schema.
        let client = pool.get().await.unwrap();
        let default_schema = query_current_schema(&client).await;
        let row = client
            .query_one(
                "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
                &[&default_schema, &table_name],
            )
            .await
            .unwrap();
        let created_table_schema: String = row.get(0);
        let created_table_name: String = row.get(1);
        assert_eq!(created_table_schema, default_schema);
        assert_eq!(created_table_name, table_name);

        // Cleanup
        let _ = client
            .execute(&format!("DROP TABLE IF EXISTS \"{}\"", table_name), &[])
            .await;
    }
}