common_meta/kv_backend/rds/
postgres.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::BufReader;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
20use common_telemetry::debug;
21use deadpool_postgres::{Config, Pool, Runtime};
22// TLS-related imports (feature-gated)
23use rustls::ClientConfig;
24use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
25use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
26use rustls::server::ParsedCertificate;
27use rustls::{DigitallySignedStruct, Error as TlsError, SignatureScheme};
28use rustls_pemfile::{certs, private_key};
29use snafu::ResultExt;
30use strum::AsRefStr;
31use tokio_postgres::types::ToSql;
32use tokio_postgres::{IsolationLevel, NoTls, Row};
33use tokio_postgres_rustls::MakeRustlsConnect;
34
35use crate::error::{
36    CreatePostgresPoolSnafu, GetPostgresConnectionSnafu, LoadTlsCertificateSnafu,
37    PostgresExecutionSnafu, PostgresTlsConfigSnafu, PostgresTransactionSnafu, Result,
38};
39use crate::kv_backend::KvBackendRef;
40use crate::kv_backend::rds::{
41    Executor, ExecutorFactory, ExecutorImpl, KvQueryExecutor, RDS_STORE_OP_BATCH_DELETE,
42    RDS_STORE_OP_BATCH_GET, RDS_STORE_OP_BATCH_PUT, RDS_STORE_OP_RANGE_DELETE,
43    RDS_STORE_OP_RANGE_QUERY, RDS_STORE_TXN_RETRY_COUNT, RdsStore, Transaction,
44};
45use crate::rpc::KeyValue;
46use crate::rpc::store::{
47    BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
48    BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, RangeRequest, RangeResponse,
49};
50
/// TLS mode configuration for PostgreSQL connections.
/// This mirrors the TlsMode from servers::tls to avoid circular dependencies.
///
/// The default mode is [`TlsMode::Prefer`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum TlsMode {
    /// No TLS: the pool is created with `NoTls`, and building a TLS connector for this mode
    /// is rejected with an error.
    Disable,
    /// Use TLS and accept any server certificate; if the TLS connector itself cannot be
    /// built, the pool falls back to a plaintext (`NoTls`) connection.
    #[default]
    Prefer,
    /// Use TLS and accept any server certificate; a failure to build the TLS connector is
    /// returned to the caller.
    Require,
    /// Use TLS and verify that the server certificate chains to a trusted CA, without
    /// checking the server host name.
    VerifyCa,
    /// Use TLS with the trusted CA store installed as root certificates and no custom
    /// verifier, i.e. the TLS library's standard server verification.
    VerifyFull,
}
62
63/// TLS configuration for PostgreSQL connections.
64/// This mirrors the TlsOption from servers::tls to avoid circular dependencies.
65#[derive(Debug, Clone, PartialEq, Eq)]
66pub struct TlsOption {
67    pub mode: TlsMode,
68    pub cert_path: String,
69    pub key_path: String,
70    pub ca_cert_path: String,
71    pub watch: bool,
72}
73
74impl Default for TlsOption {
75    fn default() -> Self {
76        TlsOption {
77            mode: TlsMode::Prefer,
78            cert_path: String::new(),
79            key_path: String::new(),
80            ca_cert_path: String::new(),
81            watch: false,
82        }
83    }
84}
85
/// Store label passed to `record_rds_sql_execute_elapsed!` for every query issued by this
/// backend.
const PG_STORE_NAME: &str = "pg_store";

/// A connection checked out of the deadpool PostgreSQL pool; the non-transactional executor.
pub struct PgClient(deadpool::managed::Object<deadpool_postgres::Manager>);
/// A transaction opened on a [`PgClient`]; it borrows that client for `'a`.
pub struct PgTxnClient<'a>(deadpool_postgres::Transaction<'a>);
90
91/// Converts a row to a [`KeyValue`].
92fn key_value_from_row(r: Row) -> KeyValue {
93    KeyValue {
94        key: r.get(0),
95        value: r.get(1),
96    }
97}
98
/// The single zero byte (note: NOT an empty slice). [`range_template`] treats a `range_end`
/// equal to this as "no upper bound", and a `key` equal to this (together with such a
/// `range_end`) as "no lower bound".
const EMPTY: &[u8] = &[0];
100
/// Type of range template.
///
/// Selected by [`range_template`] from a request's `key` / `range_end`; the variant name
/// (via `AsRefStr`) is also passed to the SQL timing macro for range queries and deletes.
#[derive(Debug, Clone, Copy, AsRefStr)]
enum RangeTemplateType {
    /// Exactly one key (`k = $1`); chosen when `range_end` is empty.
    Point,
    /// Half-open interval `k >= $1 AND k < $2`.
    Range,
    /// Every row; chosen when both `key` and `range_end` are the single zero byte.
    Full,
    /// Everything from `key` upwards (`k >= $1`); chosen when only `range_end` is the single
    /// zero byte.
    LeftBounded,
    /// Keys starting with `key` (`k LIKE $1`); chosen when `range_end` is `key` with its last
    /// byte incremented by one.
    Prefix,
}
110
111/// Builds params for the given range template type.
112impl RangeTemplateType {
113    /// Builds the parameters for the given range template type.
114    /// You can check out the conventions at [RangeRequest]
115    fn build_params(&self, mut key: Vec<u8>, range_end: Vec<u8>) -> Vec<Vec<u8>> {
116        match self {
117            RangeTemplateType::Point => vec![key],
118            RangeTemplateType::Range => vec![key, range_end],
119            RangeTemplateType::Full => vec![],
120            RangeTemplateType::LeftBounded => vec![key],
121            RangeTemplateType::Prefix => {
122                key.push(b'%');
123                vec![key]
124            }
125        }
126    }
127}
128
129/// Templates for range request.
130#[derive(Debug, Clone)]
131struct RangeTemplate {
132    point: String,
133    range: String,
134    full: String,
135    left_bounded: String,
136    prefix: String,
137}
138
139impl RangeTemplate {
140    /// Gets the template for the given type.
141    fn get(&self, typ: RangeTemplateType) -> &str {
142        match typ {
143            RangeTemplateType::Point => &self.point,
144            RangeTemplateType::Range => &self.range,
145            RangeTemplateType::Full => &self.full,
146            RangeTemplateType::LeftBounded => &self.left_bounded,
147            RangeTemplateType::Prefix => &self.prefix,
148        }
149    }
150
151    /// Adds limit to the template.
152    fn with_limit(template: &str, limit: i64) -> String {
153        if limit == 0 {
154            return format!("{};", template);
155        }
156        format!("{} LIMIT {};", template, limit)
157    }
158}
159
/// Returns `true` when `[start, end)` is exactly the set of keys prefixed by `start`, i.e.
/// `end` equals `start` with its last byte incremented by one.
///
/// Returns `false` when either slice is empty, when the lengths differ, or when the last
/// byte of `start` is `0xFF` (it has no successor, so no such `end` exists).
fn is_prefix_range(start: &[u8], end: &[u8]) -> bool {
    match (start.split_last(), end.split_last()) {
        (Some((start_last, start_head)), Some((end_last, end_head))) => {
            // Comparing the heads as slices also enforces equal lengths. `checked_add`
            // avoids the u8 overflow on 0xFF.
            start_head == end_head && start_last.checked_add(1) == Some(*end_last)
        }
        // At least one side is empty: there is no last byte to compare.
        _ => false,
    }
}
171
172/// Determine the template type for range request.
173fn range_template(key: &[u8], range_end: &[u8]) -> RangeTemplateType {
174    match (key, range_end) {
175        (_, &[]) => RangeTemplateType::Point,
176        (EMPTY, EMPTY) => RangeTemplateType::Full,
177        (_, EMPTY) => RangeTemplateType::LeftBounded,
178        (start, end) => {
179            if is_prefix_range(start, end) {
180                RangeTemplateType::Prefix
181            } else {
182                RangeTemplateType::Range
183            }
184        }
185    }
186}
187
/// Produces the PostgreSQL positional placeholders `$from`, ..., `$to` (both inclusive).
/// The result is empty when `from > to`.
fn pg_generate_in_placeholders(from: usize, to: usize) -> Vec<String> {
    let mut placeholders = Vec::new();
    for index in from..=to {
        placeholders.push(format!("${index}"));
    }
    placeholders
}
192
193/// Factory for building sql templates.
194struct PgSqlTemplateFactory<'a> {
195    schema_name: Option<&'a str>,
196    table_name: &'a str,
197}
198
199impl<'a> PgSqlTemplateFactory<'a> {
200    /// Creates a new factory with optional schema.
201    fn new(schema_name: Option<&'a str>, table_name: &'a str) -> Self {
202        Self {
203            schema_name,
204            table_name,
205        }
206    }
207
208    /// Builds the template set for the given table name.
209    fn build(&self) -> PgSqlTemplateSet {
210        let table_ident = Self::format_table_ident(self.schema_name, self.table_name);
211        // Some of queries don't end with `;`, because we need to add `LIMIT` clause.
212        PgSqlTemplateSet {
213            table_ident: table_ident.clone(),
214            // Do not attempt to create schema implicitly to avoid extra privileges requirement.
215            create_table_statement: format!(
216                "CREATE TABLE IF NOT EXISTS {table_ident}(k bytea PRIMARY KEY, v bytea)",
217            ),
218            range_template: RangeTemplate {
219                point: format!("SELECT k, v FROM {table_ident} WHERE k = $1"),
220                range: format!(
221                    "SELECT k, v FROM {table_ident} WHERE k >= $1 AND k < $2 ORDER BY k"
222                ),
223                full: format!("SELECT k, v FROM {table_ident} ORDER BY k"),
224                left_bounded: format!("SELECT k, v FROM {table_ident} WHERE k >= $1 ORDER BY k"),
225                prefix: format!("SELECT k, v FROM {table_ident} WHERE k LIKE $1 ORDER BY k"),
226            },
227            delete_template: RangeTemplate {
228                point: format!("DELETE FROM {table_ident} WHERE k = $1 RETURNING k,v;"),
229                range: format!("DELETE FROM {table_ident} WHERE k >= $1 AND k < $2 RETURNING k,v;"),
230                full: format!("DELETE FROM {table_ident} RETURNING k,v"),
231                left_bounded: format!("DELETE FROM {table_ident} WHERE k >= $1 RETURNING k,v;"),
232                prefix: format!("DELETE FROM {table_ident} WHERE k LIKE $1 RETURNING k,v;"),
233            },
234        }
235    }
236
237    /// Formats the table reference with schema if provided.
238    fn format_table_ident(schema_name: Option<&str>, table_name: &str) -> String {
239        match schema_name {
240            Some(s) if !s.is_empty() => format!("\"{}\".\"{}\"", s, table_name),
241            _ => format!("\"{}\"", table_name),
242        }
243    }
244}
245
246/// Templates for the given table name.
247#[derive(Debug, Clone)]
248pub struct PgSqlTemplateSet {
249    table_ident: String,
250    create_table_statement: String,
251    range_template: RangeTemplate,
252    delete_template: RangeTemplate,
253}
254
255impl PgSqlTemplateSet {
256    /// Generates the sql for batch get.
257    fn generate_batch_get_query(&self, key_len: usize) -> String {
258        let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
259        format!(
260            "SELECT k, v FROM {} WHERE k in ({});",
261            self.table_ident, in_clause
262        )
263    }
264
265    /// Generates the sql for batch delete.
266    fn generate_batch_delete_query(&self, key_len: usize) -> String {
267        let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
268        format!(
269            "DELETE FROM {} WHERE k in ({}) RETURNING k,v;",
270            self.table_ident, in_clause
271        )
272    }
273
274    /// Generates the sql for batch upsert.
275    fn generate_batch_upsert_query(&self, kv_len: usize) -> String {
276        let in_placeholders: Vec<String> = (1..=kv_len).map(|i| format!("${}", i)).collect();
277        let in_clause = in_placeholders.join(", ");
278        let mut param_index = kv_len + 1;
279        let mut values_placeholders = Vec::new();
280        for _ in 0..kv_len {
281            values_placeholders.push(format!("(${0}, ${1})", param_index, param_index + 1));
282            param_index += 2;
283        }
284        let values_clause = values_placeholders.join(", ");
285
286        format!(
287            r#"
288    WITH prev AS (
289        SELECT k,v FROM {table} WHERE k IN ({in_clause})
290    ), update AS (
291    INSERT INTO {table} (k, v) VALUES
292        {values_clause}
293    ON CONFLICT (
294        k
295    ) DO UPDATE SET
296        v = excluded.v
297    )
298
299    SELECT k, v FROM prev;
300    "#,
301            table = self.table_ident,
302            in_clause = in_clause,
303            values_clause = values_clause
304        )
305    }
306}
307
308#[async_trait::async_trait]
309impl Executor for PgClient {
310    type Transaction<'a>
311        = PgTxnClient<'a>
312    where
313        Self: 'a;
314
315    fn name() -> &'static str {
316        "Postgres"
317    }
318
319    async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
320        let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
321        let stmt = self
322            .0
323            .prepare_cached(query)
324            .await
325            .context(PostgresExecutionSnafu { sql: query })?;
326        let rows = self
327            .0
328            .query(&stmt, &params)
329            .await
330            .context(PostgresExecutionSnafu { sql: query })?;
331        Ok(rows.into_iter().map(key_value_from_row).collect())
332    }
333
334    async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>> {
335        let txn = self
336            .0
337            .build_transaction()
338            .isolation_level(IsolationLevel::Serializable)
339            .start()
340            .await
341            .context(PostgresTransactionSnafu {
342                operation: "begin".to_string(),
343            })?;
344        Ok(PgTxnClient(txn))
345    }
346}
347
348#[async_trait::async_trait]
349impl<'a> Transaction<'a> for PgTxnClient<'a> {
350    async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
351        let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
352        let stmt = self
353            .0
354            .prepare_cached(query)
355            .await
356            .context(PostgresExecutionSnafu { sql: query })?;
357        let rows = self
358            .0
359            .query(&stmt, &params)
360            .await
361            .context(PostgresExecutionSnafu { sql: query })?;
362        Ok(rows.into_iter().map(key_value_from_row).collect())
363    }
364
365    async fn commit(self) -> Result<()> {
366        self.0.commit().await.context(PostgresTransactionSnafu {
367            operation: "commit",
368        })?;
369        Ok(())
370    }
371}
372
373pub struct PgExecutorFactory {
374    pool: Pool,
375}
376
377impl PgExecutorFactory {
378    async fn client(&self) -> Result<PgClient> {
379        match self.pool.get().await {
380            Ok(client) => Ok(PgClient(client)),
381            Err(e) => GetPostgresConnectionSnafu {
382                reason: e.to_string(),
383            }
384            .fail(),
385        }
386    }
387}
388
389#[async_trait::async_trait]
390impl ExecutorFactory<PgClient> for PgExecutorFactory {
391    async fn default_executor(&self) -> Result<PgClient> {
392        self.client().await
393    }
394
395    async fn txn_executor<'a>(
396        &self,
397        default_executor: &'a mut PgClient,
398    ) -> Result<PgTxnClient<'a>> {
399        default_executor.txn_executor().await
400    }
401}
402
/// A PostgreSQL-backed key-value store for metasrv.
/// It uses [deadpool_postgres::Pool] as the connection pool for [RdsStore].
///
/// Concretely this is [`RdsStore`] specialised with [`PgClient`] as the executor,
/// [`PgExecutorFactory`] as the executor factory and [`PgSqlTemplateSet`] as the SQL
/// template set.
pub type PgStore = RdsStore<PgClient, PgExecutorFactory, PgSqlTemplateSet>;
406
407/// Creates a PostgreSQL TLS connector based on the provided configuration.
408///
409/// This function creates a rustls-based TLS connector for PostgreSQL connections,
410/// following PostgreSQL's TLS mode specifications exactly:
411///
412/// # TLS Modes (PostgreSQL Specification)
413///
414/// - `Disable`: No TLS connection attempted
415/// - `Prefer`: Try TLS first, fallback to plaintext if TLS fails (handled by connection logic)
416/// - `Require`: Only TLS connections, but NO certificate verification (accept any cert)
417/// - `VerifyCa`: TLS + verify certificate is signed by trusted CA (no hostname verification)
418/// - `VerifyFull`: TLS + verify CA + verify hostname matches certificate SAN
419///
420pub fn create_postgres_tls_connector(tls_config: &TlsOption) -> Result<MakeRustlsConnect> {
421    common_telemetry::info!(
422        "Creating PostgreSQL TLS connector with mode: {:?}",
423        tls_config.mode
424    );
425
426    let config_builder = match tls_config.mode {
427        TlsMode::Disable => {
428            return PostgresTlsConfigSnafu {
429                reason: "Cannot create TLS connector for Disable mode".to_string(),
430            }
431            .fail();
432        }
433        TlsMode::Prefer | TlsMode::Require => {
434            // For Prefer/Require: Accept any certificate (no verification)
435            let verifier = Arc::new(AcceptAnyVerifier);
436            ClientConfig::builder()
437                .dangerous()
438                .with_custom_certificate_verifier(verifier)
439        }
440        TlsMode::VerifyCa => {
441            // For VerifyCa: Verify server cert against CA store, but skip hostname verification
442            let ca_store = load_ca(&tls_config.ca_cert_path)?;
443            let verifier = Arc::new(NoHostnameVerification { roots: ca_store });
444            ClientConfig::builder()
445                .dangerous()
446                .with_custom_certificate_verifier(verifier)
447        }
448        TlsMode::VerifyFull => {
449            let ca_store = load_ca(&tls_config.ca_cert_path)?;
450            ClientConfig::builder().with_root_certificates(ca_store)
451        }
452    };
453
454    // Create the TLS client configuration based on the mode and client cert requirements
455    let client_config = if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
456        // Client certificate authentication required
457        common_telemetry::info!("Loading client certificate for mutual TLS");
458        let cert_chain = load_certs(&tls_config.cert_path)?;
459        let private_key = load_private_key(&tls_config.key_path)?;
460
461        config_builder
462            .with_client_auth_cert(cert_chain, private_key)
463            .map_err(|e| {
464                PostgresTlsConfigSnafu {
465                    reason: format!("Failed to configure client authentication: {}", e),
466                }
467                .build()
468            })?
469    } else {
470        common_telemetry::info!("No client certificate provided, skip client authentication");
471        config_builder.with_no_client_auth()
472    };
473
474    common_telemetry::info!("Successfully created PostgreSQL TLS connector");
475    Ok(MakeRustlsConnect::new(client_config))
476}
477
478/// For Prefer/Require mode, we accept any server certificate without verification.
479#[derive(Debug)]
480struct AcceptAnyVerifier;
481
482impl ServerCertVerifier for AcceptAnyVerifier {
483    fn verify_server_cert(
484        &self,
485        _end_entity: &CertificateDer<'_>,
486        _intermediates: &[CertificateDer<'_>],
487        _server_name: &ServerName<'_>,
488        _ocsp_response: &[u8],
489        _now: UnixTime,
490    ) -> std::result::Result<ServerCertVerified, TlsError> {
491        common_telemetry::debug!(
492            "Accepting server certificate without verification (Prefer/Require mode)"
493        );
494        Ok(ServerCertVerified::assertion())
495    }
496
497    fn verify_tls12_signature(
498        &self,
499        _message: &[u8],
500        _cert: &CertificateDer<'_>,
501        _dss: &DigitallySignedStruct,
502    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
503        // Accept any signature without verification
504        Ok(HandshakeSignatureValid::assertion())
505    }
506
507    fn verify_tls13_signature(
508        &self,
509        _message: &[u8],
510        _cert: &CertificateDer<'_>,
511        _dss: &DigitallySignedStruct,
512    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
513        // Accept any signature without verification
514        Ok(HandshakeSignatureValid::assertion())
515    }
516
517    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
518        // Support all signature schemes
519        rustls::crypto::ring::default_provider()
520            .signature_verification_algorithms
521            .supported_schemes()
522    }
523}
524
525/// For VerifyCa mode, we verify the server certificate against our CA store
526/// and skip verify server's HostName.
527#[derive(Debug)]
528struct NoHostnameVerification {
529    roots: Arc<rustls::RootCertStore>,
530}
531
532impl ServerCertVerifier for NoHostnameVerification {
533    fn verify_server_cert(
534        &self,
535        end_entity: &CertificateDer<'_>,
536        intermediates: &[CertificateDer<'_>],
537        _server_name: &ServerName<'_>,
538        _ocsp_response: &[u8],
539        now: UnixTime,
540    ) -> std::result::Result<ServerCertVerified, TlsError> {
541        let cert = ParsedCertificate::try_from(end_entity)?;
542        rustls::client::verify_server_cert_signed_by_trust_anchor(
543            &cert,
544            &self.roots,
545            intermediates,
546            now,
547            rustls::crypto::ring::default_provider()
548                .signature_verification_algorithms
549                .all,
550        )?;
551
552        Ok(ServerCertVerified::assertion())
553    }
554
555    fn verify_tls12_signature(
556        &self,
557        message: &[u8],
558        cert: &CertificateDer<'_>,
559        dss: &DigitallySignedStruct,
560    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
561        rustls::crypto::verify_tls12_signature(
562            message,
563            cert,
564            dss,
565            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
566        )
567    }
568
569    fn verify_tls13_signature(
570        &self,
571        message: &[u8],
572        cert: &CertificateDer<'_>,
573        dss: &DigitallySignedStruct,
574    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
575        rustls::crypto::verify_tls13_signature(
576            message,
577            cert,
578            dss,
579            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
580        )
581    }
582
583    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
584        // Support all signature schemes
585        rustls::crypto::ring::default_provider()
586            .signature_verification_algorithms
587            .supported_schemes()
588    }
589}
590
591fn load_certs(path: &str) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
592    let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
593    let mut reader = BufReader::new(file);
594    let certs = certs(&mut reader)
595        .collect::<std::result::Result<Vec<_>, _>>()
596        .map_err(|e| {
597            PostgresTlsConfigSnafu {
598                reason: format!("Failed to parse certificates from {}: {}", path, e),
599            }
600            .build()
601        })?;
602    Ok(certs)
603}
604
605fn load_private_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
606    let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
607    let mut reader = BufReader::new(file);
608    let key = private_key(&mut reader)
609        .map_err(|e| {
610            PostgresTlsConfigSnafu {
611                reason: format!("Failed to parse private key from {}: {}", path, e),
612            }
613            .build()
614        })?
615        .ok_or_else(|| {
616            PostgresTlsConfigSnafu {
617                reason: format!("No private key found in {}", path),
618            }
619            .build()
620        })?;
621    Ok(key)
622}
623
624fn load_ca(path: &str) -> Result<Arc<rustls::RootCertStore>> {
625    let mut root_store = rustls::RootCertStore::empty();
626
627    // Add system root certificates
628    match rustls_native_certs::load_native_certs() {
629        Ok(certs) => {
630            let num_certs = certs.len();
631            for cert in certs {
632                if let Err(e) = root_store.add(cert) {
633                    return PostgresTlsConfigSnafu {
634                        reason: format!("Failed to add root certificate: {}", e),
635                    }
636                    .fail();
637                }
638            }
639            common_telemetry::info!("Loaded {num_certs} system root certificates successfully");
640        }
641        Err(e) => {
642            return PostgresTlsConfigSnafu {
643                reason: format!("Failed to load system root certificates: {}", e),
644            }
645            .fail();
646        }
647    }
648
649    // Try add custom CA certificate if provided
650    if !path.is_empty() {
651        let ca_certs = load_certs(path)?;
652        for cert in ca_certs {
653            if let Err(e) = root_store.add(cert) {
654                return PostgresTlsConfigSnafu {
655                    reason: format!("Failed to add custom CA certificate: {}", e),
656                }
657                .fail();
658            }
659        }
660        common_telemetry::info!("Added custom CA certificate from {}", path);
661    }
662
663    Ok(Arc::new(root_store))
664}
665
666#[async_trait::async_trait]
667impl KvQueryExecutor<PgClient> for PgStore {
668    async fn range_with_query_executor(
669        &self,
670        query_executor: &mut ExecutorImpl<'_, PgClient>,
671        req: RangeRequest,
672    ) -> Result<RangeResponse> {
673        let template_type = range_template(&req.key, &req.range_end);
674        let template = self.sql_template_set.range_template.get(template_type);
675        let params = template_type.build_params(req.key, req.range_end);
676        let params_ref = params.iter().collect::<Vec<_>>();
677        // Always add 1 to limit to check if there is more data
678        let query =
679            RangeTemplate::with_limit(template, if req.limit == 0 { 0 } else { req.limit + 1 });
680        let limit = req.limit as usize;
681        debug!("query: {:?}, params: {:?}", query, params);
682        let mut kvs = crate::record_rds_sql_execute_elapsed!(
683            query_executor.query(&query, &params_ref).await,
684            PG_STORE_NAME,
685            RDS_STORE_OP_RANGE_QUERY,
686            template_type.as_ref()
687        )?;
688
689        if req.keys_only {
690            kvs.iter_mut().for_each(|kv| kv.value = vec![]);
691        }
692        // If limit is 0, we always return all data
693        if limit == 0 || kvs.len() <= limit {
694            return Ok(RangeResponse { kvs, more: false });
695        }
696        // If limit is greater than the number of rows, we remove the last row and set more to true
697        let removed = kvs.pop();
698        debug_assert!(removed.is_some());
699        Ok(RangeResponse { kvs, more: true })
700    }
701
702    async fn batch_put_with_query_executor(
703        &self,
704        query_executor: &mut ExecutorImpl<'_, PgClient>,
705        req: BatchPutRequest,
706    ) -> Result<BatchPutResponse> {
707        let mut in_params = Vec::with_capacity(req.kvs.len() * 3);
708        let mut values_params = Vec::with_capacity(req.kvs.len() * 2);
709
710        for kv in &req.kvs {
711            let processed_key = &kv.key;
712            in_params.push(processed_key);
713
714            let processed_value = &kv.value;
715            values_params.push(processed_key);
716            values_params.push(processed_value);
717        }
718        in_params.extend(values_params);
719        let params = in_params.iter().map(|x| x as _).collect::<Vec<_>>();
720        let query = self
721            .sql_template_set
722            .generate_batch_upsert_query(req.kvs.len());
723
724        let kvs = crate::record_rds_sql_execute_elapsed!(
725            query_executor.query(&query, &params).await,
726            PG_STORE_NAME,
727            RDS_STORE_OP_BATCH_PUT,
728            ""
729        )?;
730        if req.prev_kv {
731            Ok(BatchPutResponse { prev_kvs: kvs })
732        } else {
733            Ok(BatchPutResponse::default())
734        }
735    }
736
737    /// Batch get with certain client. It's needed for a client with transaction.
738    async fn batch_get_with_query_executor(
739        &self,
740        query_executor: &mut ExecutorImpl<'_, PgClient>,
741        req: BatchGetRequest,
742    ) -> Result<BatchGetResponse> {
743        if req.keys.is_empty() {
744            return Ok(BatchGetResponse { kvs: vec![] });
745        }
746        let query = self
747            .sql_template_set
748            .generate_batch_get_query(req.keys.len());
749        let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
750        let kvs = crate::record_rds_sql_execute_elapsed!(
751            query_executor.query(&query, &params).await,
752            PG_STORE_NAME,
753            RDS_STORE_OP_BATCH_GET,
754            ""
755        )?;
756        Ok(BatchGetResponse { kvs })
757    }
758
759    async fn delete_range_with_query_executor(
760        &self,
761        query_executor: &mut ExecutorImpl<'_, PgClient>,
762        req: DeleteRangeRequest,
763    ) -> Result<DeleteRangeResponse> {
764        let template_type = range_template(&req.key, &req.range_end);
765        let template = self.sql_template_set.delete_template.get(template_type);
766        let params = template_type.build_params(req.key, req.range_end);
767        let params_ref = params.iter().map(|x| x as _).collect::<Vec<_>>();
768        let kvs = crate::record_rds_sql_execute_elapsed!(
769            query_executor.query(template, &params_ref).await,
770            PG_STORE_NAME,
771            RDS_STORE_OP_RANGE_DELETE,
772            template_type.as_ref()
773        )?;
774        let mut resp = DeleteRangeResponse::new(kvs.len() as i64);
775        if req.prev_kv {
776            resp.with_prev_kvs(kvs);
777        }
778        Ok(resp)
779    }
780
781    async fn batch_delete_with_query_executor(
782        &self,
783        query_executor: &mut ExecutorImpl<'_, PgClient>,
784        req: BatchDeleteRequest,
785    ) -> Result<BatchDeleteResponse> {
786        if req.keys.is_empty() {
787            return Ok(BatchDeleteResponse::default());
788        }
789        let query = self
790            .sql_template_set
791            .generate_batch_delete_query(req.keys.len());
792        let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
793
794        let kvs = crate::record_rds_sql_execute_elapsed!(
795            query_executor.query(&query, &params).await,
796            PG_STORE_NAME,
797            RDS_STORE_OP_BATCH_DELETE,
798            ""
799        )?;
800        if req.prev_kv {
801            Ok(BatchDeleteResponse { prev_kvs: kvs })
802        } else {
803            Ok(BatchDeleteResponse::default())
804        }
805    }
806}
807
808impl PgStore {
809    /// Create [PgStore] impl of [KvBackendRef] from url with optional TLS support.
810    ///
811    /// # Arguments
812    ///
813    /// * `url` - PostgreSQL connection URL
814    /// * `table_name` - Name of the table to use for key-value storage
815    /// * `max_txn_ops` - Maximum number of operations per transaction
816    /// * `tls_config` - Optional TLS configuration. If None, uses plaintext connection.
817    pub async fn with_url_and_tls(
818        url: &str,
819        table_name: &str,
820        max_txn_ops: usize,
821        tls_config: Option<TlsOption>,
822    ) -> Result<KvBackendRef> {
823        let mut cfg = Config::new();
824        cfg.url = Some(url.to_string());
825
826        let pool = match tls_config {
827            Some(tls_config) if tls_config.mode != TlsMode::Disable => {
828                match create_postgres_tls_connector(&tls_config) {
829                    Ok(tls_connector) => cfg
830                        .create_pool(Some(Runtime::Tokio1), tls_connector)
831                        .context(CreatePostgresPoolSnafu)?,
832                    Err(e) => {
833                        if tls_config.mode == TlsMode::Prefer {
834                            // Fallback to insecure connection if TLS fails
835                            common_telemetry::info!(
836                                "Failed to create TLS connector, falling back to insecure connection"
837                            );
838                            cfg.create_pool(Some(Runtime::Tokio1), NoTls)
839                                .context(CreatePostgresPoolSnafu)?
840                        } else {
841                            return Err(e);
842                        }
843                    }
844                }
845            }
846            _ => cfg
847                .create_pool(Some(Runtime::Tokio1), NoTls)
848                .context(CreatePostgresPoolSnafu)?,
849        };
850
851        Self::with_pg_pool(pool, None, table_name, max_txn_ops).await
852    }
853
854    /// Create [PgStore] impl of [KvBackendRef] from url (backward compatibility).
855    pub async fn with_url(url: &str, table_name: &str, max_txn_ops: usize) -> Result<KvBackendRef> {
856        Self::with_url_and_tls(url, table_name, max_txn_ops, None).await
857    }
858
859    /// Create [PgStore] impl of [KvBackendRef] from [deadpool_postgres::Pool] with optional schema.
860    pub async fn with_pg_pool(
861        pool: Pool,
862        schema_name: Option<&str>,
863        table_name: &str,
864        max_txn_ops: usize,
865    ) -> Result<KvBackendRef> {
866        // Ensure the postgres metadata backend is ready to use.
867        let client = match pool.get().await {
868            Ok(client) => client,
869            Err(e) => {
870                return GetPostgresConnectionSnafu {
871                    reason: e.to_string(),
872                }
873                .fail();
874            }
875        };
876        let template_factory = PgSqlTemplateFactory::new(schema_name, table_name);
877        let sql_template_set = template_factory.build();
878        // Do not attempt to create schema implicitly.
879        client
880            .execute(&sql_template_set.create_table_statement, &[])
881            .await
882            .with_context(|_| PostgresExecutionSnafu {
883                sql: sql_template_set.create_table_statement.to_string(),
884            })?;
885        Ok(Arc::new(Self {
886            max_txn_ops,
887            sql_template_set,
888            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
889            executor_factory: PgExecutorFactory { pool },
890            _phantom: PhantomData,
891        }))
892    }
893}
894
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kv_backend::test::{
        prepare_kv_with_prefix, test_kv_batch_delete_with_prefix, test_kv_batch_get_with_prefix,
        test_kv_compare_and_put_with_prefix, test_kv_delete_range_with_prefix,
        test_kv_put_with_prefix, test_kv_range_2_with_prefix, test_kv_range_with_prefix,
        test_simple_kv_range, test_txn_compare_equal, test_txn_compare_greater,
        test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
        text_txn_multi_compare_op, unprepare_kv,
    };
    use crate::{maybe_skip_postgres_integration_test, maybe_skip_postgres15_integration_test};

    /// Builds a [PgStore] against the server named by `GT_POSTGRES_ENDPOINTS`, creating
    /// `table_name` in the default schema.
    ///
    /// Returns `None` when the environment variable is unset or empty; panics if the pool,
    /// the connection or the create-table statement fails.
    async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
        let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
        if endpoints.is_empty() {
            return None;
        }

        let mut cfg = Config::new();
        cfg.url = Some(endpoints);
        let pool = cfg
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .context(CreatePostgresPoolSnafu)
            .unwrap();
        let client = pool.get().await.unwrap();
        // use the default schema (i.e., public)
        let template_factory = PgSqlTemplateFactory::new(None, table_name);
        let sql_templates = template_factory.build();
        // Do not attempt to create schema implicitly.
        client
            .execute(&sql_templates.create_table_statement, &[])
            .await
            .context(PostgresExecutionSnafu {
                sql: sql_templates.create_table_statement.to_string(),
            })
            .unwrap();
        Some(PgStore {
            max_txn_ops: 128,
            sql_template_set: sql_templates,
            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
            executor_factory: PgExecutorFactory { pool },
            _phantom: PhantomData,
        })
    }

    /// Builds a plaintext connection pool for the server named by `GT_POSTGRES15_ENDPOINTS`.
    ///
    /// Returns `None` when the environment variable is unset or empty, or when the pool
    /// cannot be created.
    async fn build_pg15_pool() -> Option<Pool> {
        let url = std::env::var("GT_POSTGRES15_ENDPOINTS").unwrap_or_default();
        if url.is_empty() {
            return None;
        }
        let mut cfg = Config::new();
        cfg.url = Some(url);
        // The error is discarded on purpose, so there is no point in attaching context to it.
        cfg.create_pool(Some(Runtime::Tokio1), NoTls).ok()
    }

    /// `with_pg_pool` without a schema must fail on the PostgreSQL 15 test server.
    #[tokio::test]
    async fn test_pg15_create_table_in_public_should_fail() {
        maybe_skip_postgres15_integration_test!();
        let Some(pool) = build_pg15_pool().await else {
            return;
        };
        let res = PgStore::with_pg_pool(pool, None, "pg15_public_should_fail", 128).await;
        assert!(
            res.is_err(),
            "creating table in public should fail for test_user"
        );
    }

    /// Creating the table inside the schema named by `GT_POSTGRES15_SCHEMA` must succeed, and
    /// put / batch-get must work against it.
    #[tokio::test]
    async fn test_pg15_create_table_in_test_schema_and_crud_should_succeed() {
        maybe_skip_postgres15_integration_test!();
        let Some(pool) = build_pg15_pool().await else {
            return;
        };
        let schema_name = std::env::var("GT_POSTGRES15_SCHEMA")
            .expect("GT_POSTGRES15_SCHEMA must be set to run this test");
        let client = pool.get().await.unwrap();
        let factory = PgSqlTemplateFactory::new(Some(&schema_name), "pg15_ok");
        let templates = factory.build();
        client
            .execute(&templates.create_table_statement, &[])
            .await
            .unwrap();
        let kv = PgStore {
            max_txn_ops: 128,
            sql_template_set: templates,
            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
            executor_factory: PgExecutorFactory { pool },
            _phantom: PhantomData,
        };
        let prefix = b"pg15_crud/";
        prepare_kv_with_prefix(&kv, prefix.to_vec()).await;
        test_kv_put_with_prefix(&kv, prefix.to_vec()).await;
        test_kv_batch_get_with_prefix(&kv, prefix.to_vec()).await;
        unprepare_kv(&kv, prefix).await;
    }

    /// Runs the shared put scenario against the postgres backend.
    #[tokio::test]
    async fn test_pg_put() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("put_test").await.unwrap();
        let prefix = b"put/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_put_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    /// Runs the shared range scenario against the postgres backend.
    #[tokio::test]
    async fn test_pg_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("range_test").await.unwrap();
        let prefix = b"range/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    /// Runs the second shared range scenario; unlike its siblings it is not preceded by
    /// `prepare_kv_with_prefix`.
    #[tokio::test]
    async fn test_pg_range_2() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("range2_test").await.unwrap();
        let prefix = b"range2/";
        test_kv_range_2_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    /// Runs the shared simple range scenario with an empty key prefix.
    #[tokio::test]
    async fn test_pg_all_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("simple_range_test").await.unwrap();
        let prefix = b"";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_simple_kv_range(&kv_backend).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    /// Runs the shared batch-get scenario against the postgres backend.
    #[tokio::test]
    async fn test_pg_batch_get() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_get_test").await.unwrap();
        let prefix = b"batch_get/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_batch_get_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    /// Runs the shared batch-delete scenario against the postgres backend.
    #[tokio::test]
    async fn test_pg_batch_delete() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_delete_test").await.unwrap();
        let prefix = b"batch_delete/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        // Exercise batch delete here; delete-range is covered by `test_pg_delete_range`.
        test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    /// Runs the shared batch-delete scenario on a dedicated table.
    #[tokio::test]
    async fn test_pg_batch_delete_with_prefix() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_delete_with_prefix_test")
            .await
            .unwrap();
        let prefix = b"batch_delete/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    /// Runs the shared delete-range scenario against the postgres backend.
    #[tokio::test]
    async fn test_pg_delete_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("delete_range_test").await.unwrap();
        let prefix = b"delete_range/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    /// Runs the shared compare-and-put scenario; the helper takes the backend behind an `Arc`.
    #[tokio::test]
    async fn test_pg_compare_and_put() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("compare_and_put_test").await.unwrap();
        let prefix = b"compare_and_put/";
        let kv_backend = Arc::new(kv_backend);
        test_kv_compare_and_put_with_prefix(kv_backend.clone(), prefix.to_vec()).await;
    }

    /// Runs every shared transaction scenario against a single table.
    #[tokio::test]
    async fn test_pg_txn() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("txn_test").await.unwrap();
        test_txn_one_compare_op(&kv_backend).await;
        text_txn_multi_compare_op(&kv_backend).await;
        test_txn_compare_equal(&kv_backend).await;
        test_txn_compare_greater(&kv_backend).await;
        test_txn_compare_less(&kv_backend).await;
        test_txn_compare_not_equal(&kv_backend).await;
    }

    /// The schema-qualified, quoted table identifier must appear in the create-table
    /// statement and in the generated upsert / get / delete queries.
    #[test]
    fn test_pg_template_with_schema() {
        let factory = PgSqlTemplateFactory::new(Some("test_schema"), "greptime_metakv");
        let t = factory.build();
        assert!(
            t.create_table_statement
                .contains("\"test_schema\".\"greptime_metakv\"")
        );
        let upsert = t.generate_batch_upsert_query(1);
        assert!(upsert.contains("\"test_schema\".\"greptime_metakv\""));
        let get = t.generate_batch_get_query(1);
        assert!(get.contains("\"test_schema\".\"greptime_metakv\""));
        let del = t.generate_batch_delete_query(1);
        assert!(del.contains("\"test_schema\".\"greptime_metakv\""));
    }

    /// `format_table_ident` quotes both parts and treats an empty schema like a missing one.
    #[test]
    fn test_format_table_ident() {
        let t = PgSqlTemplateFactory::format_table_ident(None, "test_table");
        assert_eq!(t, "\"test_table\"");

        let t = PgSqlTemplateFactory::format_table_ident(Some("test_schema"), "test_table");
        assert_eq!(t, "\"test_schema\".\"test_table\"");

        let t = PgSqlTemplateFactory::format_table_ident(Some(""), "test_table");
        assert_eq!(t, "\"test_table\"");
    }
}