common_meta/kv_backend/rds/
postgres.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fs::File;
16use std::io::BufReader;
17use std::marker::PhantomData;
18use std::sync::Arc;
19
20use common_telemetry::debug;
21use deadpool_postgres::{Config, Pool, Runtime};
22use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
23use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
24use rustls::server::ParsedCertificate;
25// TLS-related imports (feature-gated)
26use rustls::ClientConfig;
27use rustls::{DigitallySignedStruct, Error as TlsError, SignatureScheme};
28use rustls_pemfile::{certs, private_key};
29use snafu::ResultExt;
30use strum::AsRefStr;
31use tokio_postgres::types::ToSql;
32use tokio_postgres::{IsolationLevel, NoTls, Row};
33use tokio_postgres_rustls::MakeRustlsConnect;
34
35use crate::error::{
36    CreatePostgresPoolSnafu, GetPostgresConnectionSnafu, LoadTlsCertificateSnafu,
37    PostgresExecutionSnafu, PostgresTlsConfigSnafu, PostgresTransactionSnafu, Result,
38};
39use crate::kv_backend::rds::{
40    Executor, ExecutorFactory, ExecutorImpl, KvQueryExecutor, RdsStore, Transaction,
41    RDS_STORE_OP_BATCH_DELETE, RDS_STORE_OP_BATCH_GET, RDS_STORE_OP_BATCH_PUT,
42    RDS_STORE_OP_RANGE_DELETE, RDS_STORE_OP_RANGE_QUERY, RDS_STORE_TXN_RETRY_COUNT,
43};
44use crate::kv_backend::KvBackendRef;
45use crate::rpc::store::{
46    BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
47    BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, RangeRequest, RangeResponse,
48};
49use crate::rpc::KeyValue;
50
/// TLS mode configuration for PostgreSQL connections.
/// This mirrors the TlsMode from servers::tls to avoid circular dependencies.
///
/// The default mode is [`TlsMode::Prefer`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum TlsMode {
    /// No TLS: the pool is created with `NoTls`, and building a TLS connector
    /// for this mode is rejected with an error.
    Disable,
    /// TLS connector that accepts any server certificate; if the connector
    /// itself cannot be built, the pool falls back to `NoTls`.
    #[default]
    Prefer,
    /// TLS connector that accepts any server certificate (no verification).
    Require,
    /// Server certificate chain is verified against the CA store, but the
    /// server hostname is not checked.
    VerifyCa,
    /// Server certificate is verified with the root store through the regular
    /// rustls verifier (CA chain and hostname).
    VerifyFull,
}
62
63/// TLS configuration for PostgreSQL connections.
64/// This mirrors the TlsOption from servers::tls to avoid circular dependencies.
65#[derive(Debug, Clone, PartialEq, Eq)]
66pub struct TlsOption {
67    pub mode: TlsMode,
68    pub cert_path: String,
69    pub key_path: String,
70    pub ca_cert_path: String,
71    pub watch: bool,
72}
73
74impl Default for TlsOption {
75    fn default() -> Self {
76        TlsOption {
77            mode: TlsMode::Prefer,
78            cert_path: String::new(),
79            key_path: String::new(),
80            ca_cert_path: String::new(),
81            watch: false,
82        }
83    }
84}
85
/// Store name passed to `record_rds_sql_execute_elapsed!` for every query
/// issued by this backend.
const PG_STORE_NAME: &str = "pg_store";
87
/// Executor wrapping a connection object checked out from the deadpool
/// PostgreSQL pool.
pub struct PgClient(deadpool::managed::Object<deadpool_postgres::Manager>);
/// Executor wrapping an open `deadpool_postgres` transaction; it is created by
/// [`PgClient`]'s `txn_executor` and borrows that client for `'a`.
pub struct PgTxnClient<'a>(deadpool_postgres::Transaction<'a>);
90
91/// Converts a row to a [`KeyValue`].
92fn key_value_from_row(r: Row) -> KeyValue {
93    KeyValue {
94        key: r.get(0),
95        value: r.get(1),
96    }
97}
98
/// Sentinel consisting of a single zero byte. `range_template` compares the
/// request's `key` / `range_end` against it to detect unbounded ranges.
const EMPTY: &[u8] = &[0];
100
/// Type of range template.
///
/// The variant is chosen by `range_template` from the request's `key` and
/// `range_end`; its `AsRef<str>` form is passed as a label when recording
/// range query/delete executions.
#[derive(Debug, Clone, Copy, AsRefStr)]
enum RangeTemplateType {
    /// `range_end` is empty: exact match on `key`.
    Point,
    /// Half-open interval `[key, range_end)`.
    Range,
    /// Both `key` and `range_end` are the single zero byte: every row.
    Full,
    /// `range_end` is the single zero byte: every row with `k >= key`.
    LeftBounded,
    /// `range_end` equals `key` with its last byte incremented: rows whose key
    /// starts with `key`, served through `LIKE`.
    Prefix,
}
110
111/// Builds params for the given range template type.
112impl RangeTemplateType {
113    /// Builds the parameters for the given range template type.
114    /// You can check out the conventions at [RangeRequest]
115    fn build_params(&self, mut key: Vec<u8>, range_end: Vec<u8>) -> Vec<Vec<u8>> {
116        match self {
117            RangeTemplateType::Point => vec![key],
118            RangeTemplateType::Range => vec![key, range_end],
119            RangeTemplateType::Full => vec![],
120            RangeTemplateType::LeftBounded => vec![key],
121            RangeTemplateType::Prefix => {
122                key.push(b'%');
123                vec![key]
124            }
125        }
126    }
127}
128
/// Templates for range request.
///
/// Holds one SQL statement per [`RangeTemplateType`]; the same shape is used
/// for both the `SELECT` and the `DELETE` template sets.
#[derive(Debug, Clone)]
struct RangeTemplate {
    /// Statement for [`RangeTemplateType::Point`].
    point: String,
    /// Statement for [`RangeTemplateType::Range`].
    range: String,
    /// Statement for [`RangeTemplateType::Full`].
    full: String,
    /// Statement for [`RangeTemplateType::LeftBounded`].
    left_bounded: String,
    /// Statement for [`RangeTemplateType::Prefix`].
    prefix: String,
}
138
139impl RangeTemplate {
140    /// Gets the template for the given type.
141    fn get(&self, typ: RangeTemplateType) -> &str {
142        match typ {
143            RangeTemplateType::Point => &self.point,
144            RangeTemplateType::Range => &self.range,
145            RangeTemplateType::Full => &self.full,
146            RangeTemplateType::LeftBounded => &self.left_bounded,
147            RangeTemplateType::Prefix => &self.prefix,
148        }
149    }
150
151    /// Adds limit to the template.
152    fn with_limit(template: &str, limit: i64) -> String {
153        if limit == 0 {
154            return format!("{};", template);
155        }
156        format!("{} LIMIT {};", template, limit)
157    }
158}
159
/// Returns `true` when `end` is exactly `start` with its last byte incremented
/// by one, i.e. `[start, end)` covers precisely the keys prefixed by `start`.
///
/// Returns `false` for empty inputs, for slices of different length, and when
/// the last byte of `start` is `0xFF` (it has no single-byte successor).
fn is_prefix_range(start: &[u8], end: &[u8]) -> bool {
    match (start.split_last(), end.split_last()) {
        (Some((&start_last, start_head)), Some((&end_last, end_head))) => {
            // Slice equality also enforces equal lengths; `checked_add` avoids
            // the overflow of `0xFF + 1`.
            start_head == end_head && start_last.checked_add(1) == Some(end_last)
        }
        // At least one side is empty: there is no last byte to compare.
        _ => false,
    }
}
171
172/// Determine the template type for range request.
173fn range_template(key: &[u8], range_end: &[u8]) -> RangeTemplateType {
174    match (key, range_end) {
175        (_, &[]) => RangeTemplateType::Point,
176        (EMPTY, EMPTY) => RangeTemplateType::Full,
177        (_, EMPTY) => RangeTemplateType::LeftBounded,
178        (start, end) => {
179            if is_prefix_range(start, end) {
180                RangeTemplateType::Prefix
181            } else {
182                RangeTemplateType::Range
183            }
184        }
185    }
186}
187
/// Produces the PostgreSQL positional placeholders `$from`, ..., `$to`
/// (inclusive); the result is empty when `from > to`.
fn pg_generate_in_placeholders(from: usize, to: usize) -> Vec<String> {
    let mut placeholders = Vec::new();
    for index in from..=to {
        placeholders.push(format!("${index}"));
    }
    placeholders
}
192
193/// Factory for building sql templates.
194struct PgSqlTemplateFactory<'a> {
195    table_name: &'a str,
196}
197
198impl<'a> PgSqlTemplateFactory<'a> {
199    /// Creates a new [`SqlTemplateFactory`] with the given table name.
200    fn new(table_name: &'a str) -> Self {
201        Self { table_name }
202    }
203
204    /// Builds the template set for the given table name.
205    fn build(&self) -> PgSqlTemplateSet {
206        let table_name = self.table_name;
207        // Some of queries don't end with `;`, because we need to add `LIMIT` clause.
208        PgSqlTemplateSet {
209            table_name: table_name.to_string(),
210            create_table_statement: format!(
211                "CREATE TABLE IF NOT EXISTS \"{table_name}\"(k bytea PRIMARY KEY, v bytea)",
212            ),
213            range_template: RangeTemplate {
214                point: format!("SELECT k, v FROM \"{table_name}\" WHERE k = $1"),
215                range: format!(
216                    "SELECT k, v FROM \"{table_name}\" WHERE k >= $1 AND k < $2 ORDER BY k"
217                ),
218                full: format!("SELECT k, v FROM \"{table_name}\" ORDER BY k"),
219                left_bounded: format!("SELECT k, v FROM \"{table_name}\" WHERE k >= $1 ORDER BY k"),
220                prefix: format!("SELECT k, v FROM \"{table_name}\" WHERE k LIKE $1 ORDER BY k"),
221            },
222            delete_template: RangeTemplate {
223                point: format!("DELETE FROM \"{table_name}\" WHERE k = $1 RETURNING k,v;"),
224                range: format!(
225                    "DELETE FROM \"{table_name}\" WHERE k >= $1 AND k < $2 RETURNING k,v;"
226                ),
227                full: format!("DELETE FROM \"{table_name}\" RETURNING k,v"),
228                left_bounded: format!("DELETE FROM \"{table_name}\" WHERE k >= $1 RETURNING k,v;"),
229                prefix: format!("DELETE FROM \"{table_name}\" WHERE k LIKE $1 RETURNING k,v;"),
230            },
231        }
232    }
233}
234
/// Templates for the given table name.
///
/// Built once by `PgSqlTemplateFactory::build`; the batch statements are
/// generated on demand because their placeholder count depends on the request.
#[derive(Debug, Clone)]
pub struct PgSqlTemplateSet {
    /// Table the statements operate on (interpolated inside double quotes).
    table_name: String,
    /// `CREATE TABLE IF NOT EXISTS` statement for the key-value table.
    create_table_statement: String,
    /// `SELECT` statements, one per range template type.
    range_template: RangeTemplate,
    /// `DELETE ... RETURNING` statements, one per range template type.
    delete_template: RangeTemplate,
}
243
244impl PgSqlTemplateSet {
245    /// Generates the sql for batch get.
246    fn generate_batch_get_query(&self, key_len: usize) -> String {
247        let table_name = &self.table_name;
248        let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
249        format!(
250            "SELECT k, v FROM \"{table_name}\" WHERE k in ({});",
251            in_clause
252        )
253    }
254
255    /// Generates the sql for batch delete.
256    fn generate_batch_delete_query(&self, key_len: usize) -> String {
257        let table_name = &self.table_name;
258        let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
259        format!(
260            "DELETE FROM \"{table_name}\" WHERE k in ({}) RETURNING k,v;",
261            in_clause
262        )
263    }
264
265    /// Generates the sql for batch upsert.
266    fn generate_batch_upsert_query(&self, kv_len: usize) -> String {
267        let table_name = &self.table_name;
268        let in_placeholders: Vec<String> = (1..=kv_len).map(|i| format!("${}", i)).collect();
269        let in_clause = in_placeholders.join(", ");
270        let mut param_index = kv_len + 1;
271        let mut values_placeholders = Vec::new();
272        for _ in 0..kv_len {
273            values_placeholders.push(format!("(${0}, ${1})", param_index, param_index + 1));
274            param_index += 2;
275        }
276        let values_clause = values_placeholders.join(", ");
277
278        format!(
279            r#"
280    WITH prev AS (
281        SELECT k,v FROM "{table_name}" WHERE k IN ({in_clause})
282    ), update AS (
283    INSERT INTO "{table_name}" (k, v) VALUES
284        {values_clause}
285    ON CONFLICT (
286        k
287    ) DO UPDATE SET
288        v = excluded.v
289    )
290
291    SELECT k, v FROM prev;
292    "#
293        )
294    }
295}
296
297#[async_trait::async_trait]
298impl Executor for PgClient {
299    type Transaction<'a>
300        = PgTxnClient<'a>
301    where
302        Self: 'a;
303
304    fn name() -> &'static str {
305        "Postgres"
306    }
307
308    async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
309        let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
310        let stmt = self
311            .0
312            .prepare_cached(query)
313            .await
314            .context(PostgresExecutionSnafu { sql: query })?;
315        let rows = self
316            .0
317            .query(&stmt, &params)
318            .await
319            .context(PostgresExecutionSnafu { sql: query })?;
320        Ok(rows.into_iter().map(key_value_from_row).collect())
321    }
322
323    async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>> {
324        let txn = self
325            .0
326            .build_transaction()
327            .isolation_level(IsolationLevel::Serializable)
328            .start()
329            .await
330            .context(PostgresTransactionSnafu {
331                operation: "begin".to_string(),
332            })?;
333        Ok(PgTxnClient(txn))
334    }
335}
336
337#[async_trait::async_trait]
338impl<'a> Transaction<'a> for PgTxnClient<'a> {
339    async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
340        let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
341        let stmt = self
342            .0
343            .prepare_cached(query)
344            .await
345            .context(PostgresExecutionSnafu { sql: query })?;
346        let rows = self
347            .0
348            .query(&stmt, &params)
349            .await
350            .context(PostgresExecutionSnafu { sql: query })?;
351        Ok(rows.into_iter().map(key_value_from_row).collect())
352    }
353
354    async fn commit(self) -> Result<()> {
355        self.0.commit().await.context(PostgresTransactionSnafu {
356            operation: "commit",
357        })?;
358        Ok(())
359    }
360}
361
/// Hands out [`PgClient`] executors backed by a deadpool PostgreSQL pool.
pub struct PgExecutorFactory {
    /// Pool from which every executor's connection is taken.
    pool: Pool,
}
365
366impl PgExecutorFactory {
367    async fn client(&self) -> Result<PgClient> {
368        match self.pool.get().await {
369            Ok(client) => Ok(PgClient(client)),
370            Err(e) => GetPostgresConnectionSnafu {
371                reason: e.to_string(),
372            }
373            .fail(),
374        }
375    }
376}
377
378#[async_trait::async_trait]
379impl ExecutorFactory<PgClient> for PgExecutorFactory {
380    async fn default_executor(&self) -> Result<PgClient> {
381        self.client().await
382    }
383
384    async fn txn_executor<'a>(
385        &self,
386        default_executor: &'a mut PgClient,
387    ) -> Result<PgTxnClient<'a>> {
388        default_executor.txn_executor().await
389    }
390}
391
/// A PostgreSQL-backed key-value store for metasrv.
/// It uses [deadpool_postgres::Pool] as the connection pool for [RdsStore].
///
/// Executors are [`PgClient`]s produced by [`PgExecutorFactory`], and SQL text
/// comes from a [`PgSqlTemplateSet`].
pub type PgStore = RdsStore<PgClient, PgExecutorFactory, PgSqlTemplateSet>;
395
396/// Creates a PostgreSQL TLS connector based on the provided configuration.
397///
398/// This function creates a rustls-based TLS connector for PostgreSQL connections,
399/// following PostgreSQL's TLS mode specifications exactly:
400///
401/// # TLS Modes (PostgreSQL Specification)
402///
403/// - `Disable`: No TLS connection attempted
404/// - `Prefer`: Try TLS first, fallback to plaintext if TLS fails (handled by connection logic)
405/// - `Require`: Only TLS connections, but NO certificate verification (accept any cert)
406/// - `VerifyCa`: TLS + verify certificate is signed by trusted CA (no hostname verification)
407/// - `VerifyFull`: TLS + verify CA + verify hostname matches certificate SAN
408///
409pub fn create_postgres_tls_connector(tls_config: &TlsOption) -> Result<MakeRustlsConnect> {
410    common_telemetry::info!(
411        "Creating PostgreSQL TLS connector with mode: {:?}",
412        tls_config.mode
413    );
414
415    let config_builder = match tls_config.mode {
416        TlsMode::Disable => {
417            return PostgresTlsConfigSnafu {
418                reason: "Cannot create TLS connector for Disable mode".to_string(),
419            }
420            .fail();
421        }
422        TlsMode::Prefer | TlsMode::Require => {
423            // For Prefer/Require: Accept any certificate (no verification)
424            let verifier = Arc::new(AcceptAnyVerifier);
425            ClientConfig::builder()
426                .dangerous()
427                .with_custom_certificate_verifier(verifier)
428        }
429        TlsMode::VerifyCa => {
430            // For VerifyCa: Verify server cert against CA store, but skip hostname verification
431            let ca_store = load_ca(&tls_config.ca_cert_path)?;
432            let verifier = Arc::new(NoHostnameVerification { roots: ca_store });
433            ClientConfig::builder()
434                .dangerous()
435                .with_custom_certificate_verifier(verifier)
436        }
437        TlsMode::VerifyFull => {
438            let ca_store = load_ca(&tls_config.ca_cert_path)?;
439            ClientConfig::builder().with_root_certificates(ca_store)
440        }
441    };
442
443    // Create the TLS client configuration based on the mode and client cert requirements
444    let client_config = if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
445        // Client certificate authentication required
446        common_telemetry::info!("Loading client certificate for mutual TLS");
447        let cert_chain = load_certs(&tls_config.cert_path)?;
448        let private_key = load_private_key(&tls_config.key_path)?;
449
450        config_builder
451            .with_client_auth_cert(cert_chain, private_key)
452            .map_err(|e| {
453                PostgresTlsConfigSnafu {
454                    reason: format!("Failed to configure client authentication: {}", e),
455                }
456                .build()
457            })?
458    } else {
459        common_telemetry::info!("No client certificate provided, skip client authentication");
460        config_builder.with_no_client_auth()
461    };
462
463    common_telemetry::info!("Successfully created PostgreSQL TLS connector");
464    Ok(MakeRustlsConnect::new(client_config))
465}
466
/// For Prefer/Require mode, we accept any server certificate without verification.
///
/// Every check below succeeds unconditionally, so a connection using this
/// verifier does not authenticate the server at all.
#[derive(Debug)]
struct AcceptAnyVerifier;

impl ServerCertVerifier for AcceptAnyVerifier {
    /// Accepts the presented certificate chain without inspecting any argument.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, TlsError> {
        common_telemetry::debug!(
            "Accepting server certificate without verification (Prefer/Require mode)"
        );
        Ok(ServerCertVerified::assertion())
    }

    /// Reports the TLS 1.2 handshake signature as valid without checking it.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        // Accept any signature without verification
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Reports the TLS 1.3 handshake signature as valid without checking it.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        // Accept any signature without verification
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Advertises the signature schemes of the `ring` crypto provider.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        // Support all signature schemes
        rustls::crypto::ring::default_provider()
            .signature_verification_algorithms
            .supported_schemes()
    }
}
513
/// For VerifyCa mode, we verify the server certificate against our CA store
/// and skip verifying the server's hostname.
#[derive(Debug)]
struct NoHostnameVerification {
    /// Trust anchors the server certificate chain must lead to.
    roots: Arc<rustls::RootCertStore>,
}

impl ServerCertVerifier for NoHostnameVerification {
    /// Checks that `end_entity` chains to one of `roots` at time `now`.
    /// `_server_name` and `_ocsp_response` are deliberately ignored.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, TlsError> {
        let cert = ParsedCertificate::try_from(end_entity)?;
        rustls::client::verify_server_cert_signed_by_trust_anchor(
            &cert,
            &self.roots,
            intermediates,
            now,
            rustls::crypto::ring::default_provider()
                .signature_verification_algorithms
                .all,
        )?;

        // Chain validation passed; no hostname check follows.
        Ok(ServerCertVerified::assertion())
    }

    /// Delegates TLS 1.2 handshake signature verification to rustls, using the
    /// `ring` provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        rustls::crypto::verify_tls12_signature(
            message,
            cert,
            dss,
            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
        )
    }

    /// Delegates TLS 1.3 handshake signature verification to rustls, using the
    /// `ring` provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, TlsError> {
        rustls::crypto::verify_tls13_signature(
            message,
            cert,
            dss,
            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
        )
    }

    /// Advertises the signature schemes of the `ring` crypto provider.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        // Support all signature schemes
        rustls::crypto::ring::default_provider()
            .signature_verification_algorithms
            .supported_schemes()
    }
}
579
580fn load_certs(path: &str) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>> {
581    let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
582    let mut reader = BufReader::new(file);
583    let certs = certs(&mut reader)
584        .collect::<std::result::Result<Vec<_>, _>>()
585        .map_err(|e| {
586            PostgresTlsConfigSnafu {
587                reason: format!("Failed to parse certificates from {}: {}", path, e),
588            }
589            .build()
590        })?;
591    Ok(certs)
592}
593
594fn load_private_key(path: &str) -> Result<rustls::pki_types::PrivateKeyDer<'static>> {
595    let file = File::open(path).context(LoadTlsCertificateSnafu { path })?;
596    let mut reader = BufReader::new(file);
597    let key = private_key(&mut reader)
598        .map_err(|e| {
599            PostgresTlsConfigSnafu {
600                reason: format!("Failed to parse private key from {}: {}", path, e),
601            }
602            .build()
603        })?
604        .ok_or_else(|| {
605            PostgresTlsConfigSnafu {
606                reason: format!("No private key found in {}", path),
607            }
608            .build()
609        })?;
610    Ok(key)
611}
612
613fn load_ca(path: &str) -> Result<Arc<rustls::RootCertStore>> {
614    let mut root_store = rustls::RootCertStore::empty();
615
616    // Add system root certificates
617    match rustls_native_certs::load_native_certs() {
618        Ok(certs) => {
619            let num_certs = certs.len();
620            for cert in certs {
621                if let Err(e) = root_store.add(cert) {
622                    return PostgresTlsConfigSnafu {
623                        reason: format!("Failed to add root certificate: {}", e),
624                    }
625                    .fail();
626                }
627            }
628            common_telemetry::info!("Loaded {num_certs} system root certificates successfully");
629        }
630        Err(e) => {
631            return PostgresTlsConfigSnafu {
632                reason: format!("Failed to load system root certificates: {}", e),
633            }
634            .fail();
635        }
636    }
637
638    // Try add custom CA certificate if provided
639    if !path.is_empty() {
640        let ca_certs = load_certs(path)?;
641        for cert in ca_certs {
642            if let Err(e) = root_store.add(cert) {
643                return PostgresTlsConfigSnafu {
644                    reason: format!("Failed to add custom CA certificate: {}", e),
645                }
646                .fail();
647            }
648        }
649        common_telemetry::info!("Added custom CA certificate from {}", path);
650    }
651
652    Ok(Arc::new(root_store))
653}
654
655#[async_trait::async_trait]
656impl KvQueryExecutor<PgClient> for PgStore {
657    async fn range_with_query_executor(
658        &self,
659        query_executor: &mut ExecutorImpl<'_, PgClient>,
660        req: RangeRequest,
661    ) -> Result<RangeResponse> {
662        let template_type = range_template(&req.key, &req.range_end);
663        let template = self.sql_template_set.range_template.get(template_type);
664        let params = template_type.build_params(req.key, req.range_end);
665        let params_ref = params.iter().collect::<Vec<_>>();
666        // Always add 1 to limit to check if there is more data
667        let query =
668            RangeTemplate::with_limit(template, if req.limit == 0 { 0 } else { req.limit + 1 });
669        let limit = req.limit as usize;
670        debug!("query: {:?}, params: {:?}", query, params);
671        let mut kvs = crate::record_rds_sql_execute_elapsed!(
672            query_executor.query(&query, &params_ref).await,
673            PG_STORE_NAME,
674            RDS_STORE_OP_RANGE_QUERY,
675            template_type.as_ref()
676        )?;
677
678        if req.keys_only {
679            kvs.iter_mut().for_each(|kv| kv.value = vec![]);
680        }
681        // If limit is 0, we always return all data
682        if limit == 0 || kvs.len() <= limit {
683            return Ok(RangeResponse { kvs, more: false });
684        }
685        // If limit is greater than the number of rows, we remove the last row and set more to true
686        let removed = kvs.pop();
687        debug_assert!(removed.is_some());
688        Ok(RangeResponse { kvs, more: true })
689    }
690
691    async fn batch_put_with_query_executor(
692        &self,
693        query_executor: &mut ExecutorImpl<'_, PgClient>,
694        req: BatchPutRequest,
695    ) -> Result<BatchPutResponse> {
696        let mut in_params = Vec::with_capacity(req.kvs.len() * 3);
697        let mut values_params = Vec::with_capacity(req.kvs.len() * 2);
698
699        for kv in &req.kvs {
700            let processed_key = &kv.key;
701            in_params.push(processed_key);
702
703            let processed_value = &kv.value;
704            values_params.push(processed_key);
705            values_params.push(processed_value);
706        }
707        in_params.extend(values_params);
708        let params = in_params.iter().map(|x| x as _).collect::<Vec<_>>();
709        let query = self
710            .sql_template_set
711            .generate_batch_upsert_query(req.kvs.len());
712
713        let kvs = crate::record_rds_sql_execute_elapsed!(
714            query_executor.query(&query, &params).await,
715            PG_STORE_NAME,
716            RDS_STORE_OP_BATCH_PUT,
717            ""
718        )?;
719        if req.prev_kv {
720            Ok(BatchPutResponse { prev_kvs: kvs })
721        } else {
722            Ok(BatchPutResponse::default())
723        }
724    }
725
726    /// Batch get with certain client. It's needed for a client with transaction.
727    async fn batch_get_with_query_executor(
728        &self,
729        query_executor: &mut ExecutorImpl<'_, PgClient>,
730        req: BatchGetRequest,
731    ) -> Result<BatchGetResponse> {
732        if req.keys.is_empty() {
733            return Ok(BatchGetResponse { kvs: vec![] });
734        }
735        let query = self
736            .sql_template_set
737            .generate_batch_get_query(req.keys.len());
738        let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
739        let kvs = crate::record_rds_sql_execute_elapsed!(
740            query_executor.query(&query, &params).await,
741            PG_STORE_NAME,
742            RDS_STORE_OP_BATCH_GET,
743            ""
744        )?;
745        Ok(BatchGetResponse { kvs })
746    }
747
748    async fn delete_range_with_query_executor(
749        &self,
750        query_executor: &mut ExecutorImpl<'_, PgClient>,
751        req: DeleteRangeRequest,
752    ) -> Result<DeleteRangeResponse> {
753        let template_type = range_template(&req.key, &req.range_end);
754        let template = self.sql_template_set.delete_template.get(template_type);
755        let params = template_type.build_params(req.key, req.range_end);
756        let params_ref = params.iter().map(|x| x as _).collect::<Vec<_>>();
757        let kvs = crate::record_rds_sql_execute_elapsed!(
758            query_executor.query(template, &params_ref).await,
759            PG_STORE_NAME,
760            RDS_STORE_OP_RANGE_DELETE,
761            template_type.as_ref()
762        )?;
763        let mut resp = DeleteRangeResponse::new(kvs.len() as i64);
764        if req.prev_kv {
765            resp.with_prev_kvs(kvs);
766        }
767        Ok(resp)
768    }
769
770    async fn batch_delete_with_query_executor(
771        &self,
772        query_executor: &mut ExecutorImpl<'_, PgClient>,
773        req: BatchDeleteRequest,
774    ) -> Result<BatchDeleteResponse> {
775        if req.keys.is_empty() {
776            return Ok(BatchDeleteResponse::default());
777        }
778        let query = self
779            .sql_template_set
780            .generate_batch_delete_query(req.keys.len());
781        let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
782
783        let kvs = crate::record_rds_sql_execute_elapsed!(
784            query_executor.query(&query, &params).await,
785            PG_STORE_NAME,
786            RDS_STORE_OP_BATCH_DELETE,
787            ""
788        )?;
789        if req.prev_kv {
790            Ok(BatchDeleteResponse { prev_kvs: kvs })
791        } else {
792            Ok(BatchDeleteResponse::default())
793        }
794    }
795}
796
797impl PgStore {
798    /// Create [PgStore] impl of [KvBackendRef] from url with optional TLS support.
799    ///
800    /// # Arguments
801    ///
802    /// * `url` - PostgreSQL connection URL
803    /// * `table_name` - Name of the table to use for key-value storage
804    /// * `max_txn_ops` - Maximum number of operations per transaction
805    /// * `tls_config` - Optional TLS configuration. If None, uses plaintext connection.
806    pub async fn with_url_and_tls(
807        url: &str,
808        table_name: &str,
809        max_txn_ops: usize,
810        tls_config: Option<TlsOption>,
811    ) -> Result<KvBackendRef> {
812        let mut cfg = Config::new();
813        cfg.url = Some(url.to_string());
814
815        let pool = match tls_config {
816            Some(tls_config) if tls_config.mode != TlsMode::Disable => {
817                match create_postgres_tls_connector(&tls_config) {
818                    Ok(tls_connector) => cfg
819                        .create_pool(Some(Runtime::Tokio1), tls_connector)
820                        .context(CreatePostgresPoolSnafu)?,
821                    Err(e) => {
822                        if tls_config.mode == TlsMode::Prefer {
823                            // Fallback to insecure connection if TLS fails
824                            common_telemetry::info!("Failed to create TLS connector, falling back to insecure connection");
825                            cfg.create_pool(Some(Runtime::Tokio1), NoTls)
826                                .context(CreatePostgresPoolSnafu)?
827                        } else {
828                            return Err(e);
829                        }
830                    }
831                }
832            }
833            _ => cfg
834                .create_pool(Some(Runtime::Tokio1), NoTls)
835                .context(CreatePostgresPoolSnafu)?,
836        };
837
838        Self::with_pg_pool(pool, table_name, max_txn_ops).await
839    }
840
841    /// Create [PgStore] impl of [KvBackendRef] from url (backward compatibility).
842    pub async fn with_url(url: &str, table_name: &str, max_txn_ops: usize) -> Result<KvBackendRef> {
843        Self::with_url_and_tls(url, table_name, max_txn_ops, None).await
844    }
845
846    /// Create [PgStore] impl of [KvBackendRef] from [deadpool_postgres::Pool].
847    pub async fn with_pg_pool(
848        pool: Pool,
849        table_name: &str,
850        max_txn_ops: usize,
851    ) -> Result<KvBackendRef> {
852        // This step ensures the postgres metadata backend is ready to use.
853        // We check if greptime_metakv table exists, and we will create a new table
854        // if it does not exist.
855        let client = match pool.get().await {
856            Ok(client) => client,
857            Err(e) => {
858                return GetPostgresConnectionSnafu {
859                    reason: e.to_string(),
860                }
861                .fail();
862            }
863        };
864        let template_factory = PgSqlTemplateFactory::new(table_name);
865        let sql_template_set = template_factory.build();
866        client
867            .execute(&sql_template_set.create_table_statement, &[])
868            .await
869            .with_context(|_| PostgresExecutionSnafu {
870                sql: sql_template_set.create_table_statement.to_string(),
871            })?;
872        Ok(Arc::new(Self {
873            max_txn_ops,
874            sql_template_set,
875            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
876            executor_factory: PgExecutorFactory { pool },
877            _phantom: PhantomData,
878        }))
879    }
880}
881
882#[cfg(test)]
883mod tests {
884    use super::*;
885    use crate::kv_backend::test::{
886        prepare_kv_with_prefix, test_kv_batch_delete_with_prefix, test_kv_batch_get_with_prefix,
887        test_kv_compare_and_put_with_prefix, test_kv_delete_range_with_prefix,
888        test_kv_put_with_prefix, test_kv_range_2_with_prefix, test_kv_range_with_prefix,
889        test_simple_kv_range, test_txn_compare_equal, test_txn_compare_greater,
890        test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
891        text_txn_multi_compare_op, unprepare_kv,
892    };
893    use crate::maybe_skip_postgres_integration_test;
894
895    async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
896        let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
897        if endpoints.is_empty() {
898            return None;
899        }
900
901        let mut cfg = Config::new();
902        cfg.url = Some(endpoints);
903        let pool = cfg
904            .create_pool(Some(Runtime::Tokio1), NoTls)
905            .context(CreatePostgresPoolSnafu)
906            .unwrap();
907        let client = pool.get().await.unwrap();
908        let template_factory = PgSqlTemplateFactory::new(table_name);
909        let sql_templates = template_factory.build();
910        client
911            .execute(&sql_templates.create_table_statement, &[])
912            .await
913            .context(PostgresExecutionSnafu {
914                sql: sql_templates.create_table_statement.to_string(),
915            })
916            .unwrap();
917        Some(PgStore {
918            max_txn_ops: 128,
919            sql_template_set: sql_templates,
920            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
921            executor_factory: PgExecutorFactory { pool },
922            _phantom: PhantomData,
923        })
924    }
925
926    #[tokio::test]
927    async fn test_pg_put() {
928        maybe_skip_postgres_integration_test!();
929        let kv_backend = build_pg_kv_backend("put_test").await.unwrap();
930        let prefix = b"put/";
931        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
932        test_kv_put_with_prefix(&kv_backend, prefix.to_vec()).await;
933        unprepare_kv(&kv_backend, prefix).await;
934    }
935
936    #[tokio::test]
937    async fn test_pg_range() {
938        maybe_skip_postgres_integration_test!();
939        let kv_backend = build_pg_kv_backend("range_test").await.unwrap();
940        let prefix = b"range/";
941        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
942        test_kv_range_with_prefix(&kv_backend, prefix.to_vec()).await;
943        unprepare_kv(&kv_backend, prefix).await;
944    }
945
946    #[tokio::test]
947    async fn test_pg_range_2() {
948        maybe_skip_postgres_integration_test!();
949        let kv_backend = build_pg_kv_backend("range2_test").await.unwrap();
950        let prefix = b"range2/";
951        test_kv_range_2_with_prefix(&kv_backend, prefix.to_vec()).await;
952        unprepare_kv(&kv_backend, prefix).await;
953    }
954
955    #[tokio::test]
956    async fn test_pg_all_range() {
957        maybe_skip_postgres_integration_test!();
958        let kv_backend = build_pg_kv_backend("simple_range_test").await.unwrap();
959        let prefix = b"";
960        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
961        test_simple_kv_range(&kv_backend).await;
962        unprepare_kv(&kv_backend, prefix).await;
963    }
964
965    #[tokio::test]
966    async fn test_pg_batch_get() {
967        maybe_skip_postgres_integration_test!();
968        let kv_backend = build_pg_kv_backend("batch_get_test").await.unwrap();
969        let prefix = b"batch_get/";
970        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
971        test_kv_batch_get_with_prefix(&kv_backend, prefix.to_vec()).await;
972        unprepare_kv(&kv_backend, prefix).await;
973    }
974
975    #[tokio::test]
976    async fn test_pg_batch_delete() {
977        maybe_skip_postgres_integration_test!();
978        let kv_backend = build_pg_kv_backend("batch_delete_test").await.unwrap();
979        let prefix = b"batch_delete/";
980        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
981        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
982        unprepare_kv(&kv_backend, prefix).await;
983    }
984
985    #[tokio::test]
986    async fn test_pg_batch_delete_with_prefix() {
987        maybe_skip_postgres_integration_test!();
988        let kv_backend = build_pg_kv_backend("batch_delete_with_prefix_test")
989            .await
990            .unwrap();
991        let prefix = b"batch_delete/";
992        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
993        test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
994        unprepare_kv(&kv_backend, prefix).await;
995    }
996
997    #[tokio::test]
998    async fn test_pg_delete_range() {
999        maybe_skip_postgres_integration_test!();
1000        let kv_backend = build_pg_kv_backend("delete_range_test").await.unwrap();
1001        let prefix = b"delete_range/";
1002        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
1003        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
1004        unprepare_kv(&kv_backend, prefix).await;
1005    }
1006
1007    #[tokio::test]
1008    async fn test_pg_compare_and_put() {
1009        maybe_skip_postgres_integration_test!();
1010        let kv_backend = build_pg_kv_backend("compare_and_put_test").await.unwrap();
1011        let prefix = b"compare_and_put/";
1012        let kv_backend = Arc::new(kv_backend);
1013        test_kv_compare_and_put_with_prefix(kv_backend.clone(), prefix.to_vec()).await;
1014    }
1015
1016    #[tokio::test]
1017    async fn test_pg_txn() {
1018        maybe_skip_postgres_integration_test!();
1019        let kv_backend = build_pg_kv_backend("txn_test").await.unwrap();
1020        test_txn_one_compare_op(&kv_backend).await;
1021        text_txn_multi_compare_op(&kv_backend).await;
1022        test_txn_compare_equal(&kv_backend).await;
1023        test_txn_compare_greater(&kv_backend).await;
1024        test_txn_compare_less(&kv_backend).await;
1025        test_txn_compare_not_equal(&kv_backend).await;
1026    }
1027}