1use std::marker::PhantomData;
16use std::sync::Arc;
17
18use common_telemetry::debug;
19use deadpool_postgres::{Config, Pool, Runtime};
20use snafu::ResultExt;
21use strum::AsRefStr;
22use tokio_postgres::types::ToSql;
23use tokio_postgres::{IsolationLevel, NoTls, Row};
24
25use crate::error::{
26 CreatePostgresPoolSnafu, GetPostgresConnectionSnafu, PostgresExecutionSnafu,
27 PostgresTransactionSnafu, Result,
28};
29use crate::kv_backend::rds::{
30 Executor, ExecutorFactory, ExecutorImpl, KvQueryExecutor, RdsStore, Transaction,
31 RDS_STORE_OP_BATCH_DELETE, RDS_STORE_OP_BATCH_GET, RDS_STORE_OP_BATCH_PUT,
32 RDS_STORE_OP_RANGE_DELETE, RDS_STORE_OP_RANGE_QUERY, RDS_STORE_TXN_RETRY_COUNT,
33};
34use crate::kv_backend::KvBackendRef;
35use crate::rpc::store::{
36 BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
37 BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, RangeRequest, RangeResponse,
38};
39use crate::rpc::KeyValue;
40
/// Store name passed to `record_rds_sql_execute_elapsed!` for every statement issued by this
/// backend.
const PG_STORE_NAME: &str = "pg_store";
42
/// [`Executor`] implementation wrapping a connection object managed by a `deadpool_postgres`
/// pool.
pub struct PgClient(deadpool::managed::Object<deadpool_postgres::Manager>);
/// [`Transaction`] implementation wrapping an open `deadpool_postgres` transaction.
///
/// Values of this type are produced by `PgClient::txn_executor` and borrow the connection for
/// `'a`.
pub struct PgTxnClient<'a>(deadpool_postgres::Transaction<'a>);
45
46fn key_value_from_row(r: Row) -> KeyValue {
48 KeyValue {
49 key: r.get(0),
50 value: r.get(1),
51 }
52}
53
/// Single zero byte used as a sentinel by `range_template`: a `range_end` equal to this value
/// selects an unbounded upper end, and a `key` that is also equal to it selects the whole table.
const EMPTY: &[u8] = &[0];
55
/// Shape of a range request; selects the SQL template and the parameters bound to it.
///
/// The `AsRefStr` name of the variant is also passed to `record_rds_sql_execute_elapsed!` by the
/// range query and range delete paths.
#[derive(Debug, Clone, Copy, AsRefStr)]
enum RangeTemplateType {
    /// Exact match on one key (`k = $1`).
    Point,
    /// Half-open interval (`k >= $1 AND k < $2`).
    Range,
    /// Every row of the table; binds no parameters.
    Full,
    /// Every key greater than or equal to the start key (`k >= $1`).
    LeftBounded,
    /// Every key beginning with the start key (`k LIKE $1`).
    Prefix,
}
65
66impl RangeTemplateType {
68 fn build_params(&self, mut key: Vec<u8>, range_end: Vec<u8>) -> Vec<Vec<u8>> {
71 match self {
72 RangeTemplateType::Point => vec![key],
73 RangeTemplateType::Range => vec![key, range_end],
74 RangeTemplateType::Full => vec![],
75 RangeTemplateType::LeftBounded => vec![key],
76 RangeTemplateType::Prefix => {
77 key.push(b'%');
78 vec![key]
79 }
80 }
81 }
82}
83
/// One SQL statement per [`RangeTemplateType`] variant, looked up through `RangeTemplate::get`.
#[derive(Debug, Clone)]
struct RangeTemplate {
    /// Statement used for `RangeTemplateType::Point`.
    point: String,
    /// Statement used for `RangeTemplateType::Range`.
    range: String,
    /// Statement used for `RangeTemplateType::Full`.
    full: String,
    /// Statement used for `RangeTemplateType::LeftBounded`.
    left_bounded: String,
    /// Statement used for `RangeTemplateType::Prefix`.
    prefix: String,
}
93
94impl RangeTemplate {
95 fn get(&self, typ: RangeTemplateType) -> &str {
97 match typ {
98 RangeTemplateType::Point => &self.point,
99 RangeTemplateType::Range => &self.range,
100 RangeTemplateType::Full => &self.full,
101 RangeTemplateType::LeftBounded => &self.left_bounded,
102 RangeTemplateType::Prefix => &self.prefix,
103 }
104 }
105
106 fn with_limit(template: &str, limit: i64) -> String {
108 if limit == 0 {
109 return format!("{};", template);
110 }
111 format!("{} LIMIT {};", template, limit)
112 }
113}
114
/// Returns `true` when `[start, end)` is exactly the set of keys prefixed by `start`.
///
/// That is the case when both slices have the same non-zero length, agree on every byte but the
/// last, and the last byte of `end` is the last byte of `start` plus one.
///
/// Empty slices yield `false`, and a `start` ending in `0xFF` never qualifies because it has no
/// successor byte; the increment is checked so neither case can panic or wrap around.
fn is_prefix_range(start: &[u8], end: &[u8]) -> bool {
    match (start.split_last(), end.split_last()) {
        (Some((&start_last, start_head)), Some((&end_last, end_head))) => {
            // Slice equality also compares lengths, so differing lengths are rejected here.
            start_head == end_head && start_last.checked_add(1) == Some(end_last)
        }
        _ => false,
    }
}
126
127fn range_template(key: &[u8], range_end: &[u8]) -> RangeTemplateType {
129 match (key, range_end) {
130 (_, &[]) => RangeTemplateType::Point,
131 (EMPTY, EMPTY) => RangeTemplateType::Full,
132 (_, EMPTY) => RangeTemplateType::LeftBounded,
133 (start, end) => {
134 if is_prefix_range(start, end) {
135 RangeTemplateType::Prefix
136 } else {
137 RangeTemplateType::Range
138 }
139 }
140 }
141}
142
/// Produces the numbered placeholders `$from`, `$from+1`, ..., `$to` (both ends inclusive).
///
/// Returns an empty vector when `from > to`.
fn pg_generate_in_placeholders(from: usize, to: usize) -> Vec<String> {
    let mut placeholders = Vec::new();
    for index in from..=to {
        placeholders.push(format!("${index}"));
    }
    placeholders
}
147
/// Renders the SQL statements of a [`PgSqlTemplateSet`] for one table.
struct PgSqlTemplateFactory<'a> {
    /// Name of the table every generated statement targets; it is interpolated inside double
    /// quotes.
    table_name: &'a str,
}
152
153impl<'a> PgSqlTemplateFactory<'a> {
154 fn new(table_name: &'a str) -> Self {
156 Self { table_name }
157 }
158
159 fn build(&self) -> PgSqlTemplateSet {
161 let table_name = self.table_name;
162 PgSqlTemplateSet {
164 table_name: table_name.to_string(),
165 create_table_statement: format!(
166 "CREATE TABLE IF NOT EXISTS \"{table_name}\"(k bytea PRIMARY KEY, v bytea)",
167 ),
168 range_template: RangeTemplate {
169 point: format!("SELECT k, v FROM \"{table_name}\" WHERE k = $1"),
170 range: format!(
171 "SELECT k, v FROM \"{table_name}\" WHERE k >= $1 AND k < $2 ORDER BY k"
172 ),
173 full: format!("SELECT k, v FROM \"{table_name}\" ORDER BY k"),
174 left_bounded: format!("SELECT k, v FROM \"{table_name}\" WHERE k >= $1 ORDER BY k"),
175 prefix: format!("SELECT k, v FROM \"{table_name}\" WHERE k LIKE $1 ORDER BY k"),
176 },
177 delete_template: RangeTemplate {
178 point: format!("DELETE FROM \"{table_name}\" WHERE k = $1 RETURNING k,v;"),
179 range: format!(
180 "DELETE FROM \"{table_name}\" WHERE k >= $1 AND k < $2 RETURNING k,v;"
181 ),
182 full: format!("DELETE FROM \"{table_name}\" RETURNING k,v"),
183 left_bounded: format!("DELETE FROM \"{table_name}\" WHERE k >= $1 RETURNING k,v;"),
184 prefix: format!("DELETE FROM \"{table_name}\" WHERE k LIKE $1 RETURNING k,v;"),
185 },
186 }
187 }
188}
189
/// SQL statements rendered for one table, as produced by `PgSqlTemplateFactory::build`.
#[derive(Debug, Clone)]
pub struct PgSqlTemplateSet {
    /// Table the statements target; reused when generating the batch statements.
    table_name: String,
    /// `CREATE TABLE IF NOT EXISTS` statement executed when the store is constructed.
    create_table_statement: String,
    /// `SELECT` templates, one per range type.
    range_template: RangeTemplate,
    /// `DELETE ... RETURNING` templates, one per range type.
    delete_template: RangeTemplate,
}
198
199impl PgSqlTemplateSet {
200 fn generate_batch_get_query(&self, key_len: usize) -> String {
202 let table_name = &self.table_name;
203 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
204 format!(
205 "SELECT k, v FROM \"{table_name}\" WHERE k in ({});",
206 in_clause
207 )
208 }
209
210 fn generate_batch_delete_query(&self, key_len: usize) -> String {
212 let table_name = &self.table_name;
213 let in_clause = pg_generate_in_placeholders(1, key_len).join(", ");
214 format!(
215 "DELETE FROM \"{table_name}\" WHERE k in ({}) RETURNING k,v;",
216 in_clause
217 )
218 }
219
220 fn generate_batch_upsert_query(&self, kv_len: usize) -> String {
222 let table_name = &self.table_name;
223 let in_placeholders: Vec<String> = (1..=kv_len).map(|i| format!("${}", i)).collect();
224 let in_clause = in_placeholders.join(", ");
225 let mut param_index = kv_len + 1;
226 let mut values_placeholders = Vec::new();
227 for _ in 0..kv_len {
228 values_placeholders.push(format!("(${0}, ${1})", param_index, param_index + 1));
229 param_index += 2;
230 }
231 let values_clause = values_placeholders.join(", ");
232
233 format!(
234 r#"
235 WITH prev AS (
236 SELECT k,v FROM "{table_name}" WHERE k IN ({in_clause})
237 ), update AS (
238 INSERT INTO "{table_name}" (k, v) VALUES
239 {values_clause}
240 ON CONFLICT (
241 k
242 ) DO UPDATE SET
243 v = excluded.v
244 )
245
246 SELECT k, v FROM prev;
247 "#
248 )
249 }
250}
251
252#[async_trait::async_trait]
253impl Executor for PgClient {
254 type Transaction<'a>
255 = PgTxnClient<'a>
256 where
257 Self: 'a;
258
259 fn name() -> &'static str {
260 "Postgres"
261 }
262
263 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
264 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
265 let stmt = self
266 .0
267 .prepare_cached(query)
268 .await
269 .context(PostgresExecutionSnafu { sql: query })?;
270 let rows = self
271 .0
272 .query(&stmt, ¶ms)
273 .await
274 .context(PostgresExecutionSnafu { sql: query })?;
275 Ok(rows.into_iter().map(key_value_from_row).collect())
276 }
277
278 async fn txn_executor<'a>(&'a mut self) -> Result<Self::Transaction<'a>> {
279 let txn = self
280 .0
281 .build_transaction()
282 .isolation_level(IsolationLevel::Serializable)
283 .start()
284 .await
285 .context(PostgresTransactionSnafu {
286 operation: "begin".to_string(),
287 })?;
288 Ok(PgTxnClient(txn))
289 }
290}
291
292#[async_trait::async_trait]
293impl<'a> Transaction<'a> for PgTxnClient<'a> {
294 async fn query(&mut self, query: &str, params: &[&Vec<u8>]) -> Result<Vec<KeyValue>> {
295 let params: Vec<&(dyn ToSql + Sync)> = params.iter().map(|p| p as _).collect();
296 let stmt = self
297 .0
298 .prepare_cached(query)
299 .await
300 .context(PostgresExecutionSnafu { sql: query })?;
301 let rows = self
302 .0
303 .query(&stmt, ¶ms)
304 .await
305 .context(PostgresExecutionSnafu { sql: query })?;
306 Ok(rows.into_iter().map(key_value_from_row).collect())
307 }
308
309 async fn commit(self) -> Result<()> {
310 self.0.commit().await.context(PostgresTransactionSnafu {
311 operation: "commit",
312 })?;
313 Ok(())
314 }
315}
316
/// Factory that creates [`PgClient`] executors from a `deadpool_postgres` connection pool.
pub struct PgExecutorFactory {
    /// Pool every executor connection is checked out of.
    pool: Pool,
}
320
321impl PgExecutorFactory {
322 async fn client(&self) -> Result<PgClient> {
323 match self.pool.get().await {
324 Ok(client) => Ok(PgClient(client)),
325 Err(e) => GetPostgresConnectionSnafu {
326 reason: e.to_string(),
327 }
328 .fail(),
329 }
330 }
331}
332
#[async_trait::async_trait]
impl ExecutorFactory<PgClient> for PgExecutorFactory {
    /// Checks a connection out of the pool to serve as the default executor.
    async fn default_executor(&self) -> Result<PgClient> {
        self.client().await
    }

    /// Starts a transaction on `default_executor`; the returned handle borrows that connection
    /// for `'a`.
    async fn txn_executor<'a>(
        &self,
        default_executor: &'a mut PgClient,
    ) -> Result<PgTxnClient<'a>> {
        default_executor.txn_executor().await
    }
}
346
/// [`RdsStore`] specialised for PostgreSQL: executes through [`PgClient`], obtains executors from
/// [`PgExecutorFactory`] and renders SQL from [`PgSqlTemplateSet`].
pub type PgStore = RdsStore<PgClient, PgExecutorFactory, PgSqlTemplateSet>;
350
351#[async_trait::async_trait]
352impl KvQueryExecutor<PgClient> for PgStore {
353 async fn range_with_query_executor(
354 &self,
355 query_executor: &mut ExecutorImpl<'_, PgClient>,
356 req: RangeRequest,
357 ) -> Result<RangeResponse> {
358 let template_type = range_template(&req.key, &req.range_end);
359 let template = self.sql_template_set.range_template.get(template_type);
360 let params = template_type.build_params(req.key, req.range_end);
361 let params_ref = params.iter().collect::<Vec<_>>();
362 let query =
364 RangeTemplate::with_limit(template, if req.limit == 0 { 0 } else { req.limit + 1 });
365 let limit = req.limit as usize;
366 debug!("query: {:?}, params: {:?}", query, params);
367 let mut kvs = crate::record_rds_sql_execute_elapsed!(
368 query_executor.query(&query, ¶ms_ref).await,
369 PG_STORE_NAME,
370 RDS_STORE_OP_RANGE_QUERY,
371 template_type.as_ref()
372 )?;
373
374 if req.keys_only {
375 kvs.iter_mut().for_each(|kv| kv.value = vec![]);
376 }
377 if limit == 0 || kvs.len() <= limit {
379 return Ok(RangeResponse { kvs, more: false });
380 }
381 let removed = kvs.pop();
383 debug_assert!(removed.is_some());
384 Ok(RangeResponse { kvs, more: true })
385 }
386
387 async fn batch_put_with_query_executor(
388 &self,
389 query_executor: &mut ExecutorImpl<'_, PgClient>,
390 req: BatchPutRequest,
391 ) -> Result<BatchPutResponse> {
392 let mut in_params = Vec::with_capacity(req.kvs.len() * 3);
393 let mut values_params = Vec::with_capacity(req.kvs.len() * 2);
394
395 for kv in &req.kvs {
396 let processed_key = &kv.key;
397 in_params.push(processed_key);
398
399 let processed_value = &kv.value;
400 values_params.push(processed_key);
401 values_params.push(processed_value);
402 }
403 in_params.extend(values_params);
404 let params = in_params.iter().map(|x| x as _).collect::<Vec<_>>();
405 let query = self
406 .sql_template_set
407 .generate_batch_upsert_query(req.kvs.len());
408
409 let kvs = crate::record_rds_sql_execute_elapsed!(
410 query_executor.query(&query, ¶ms).await,
411 PG_STORE_NAME,
412 RDS_STORE_OP_BATCH_PUT,
413 ""
414 )?;
415 if req.prev_kv {
416 Ok(BatchPutResponse { prev_kvs: kvs })
417 } else {
418 Ok(BatchPutResponse::default())
419 }
420 }
421
422 async fn batch_get_with_query_executor(
424 &self,
425 query_executor: &mut ExecutorImpl<'_, PgClient>,
426 req: BatchGetRequest,
427 ) -> Result<BatchGetResponse> {
428 if req.keys.is_empty() {
429 return Ok(BatchGetResponse { kvs: vec![] });
430 }
431 let query = self
432 .sql_template_set
433 .generate_batch_get_query(req.keys.len());
434 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
435 let kvs = crate::record_rds_sql_execute_elapsed!(
436 query_executor.query(&query, ¶ms).await,
437 PG_STORE_NAME,
438 RDS_STORE_OP_BATCH_GET,
439 ""
440 )?;
441 Ok(BatchGetResponse { kvs })
442 }
443
444 async fn delete_range_with_query_executor(
445 &self,
446 query_executor: &mut ExecutorImpl<'_, PgClient>,
447 req: DeleteRangeRequest,
448 ) -> Result<DeleteRangeResponse> {
449 let template_type = range_template(&req.key, &req.range_end);
450 let template = self.sql_template_set.delete_template.get(template_type);
451 let params = template_type.build_params(req.key, req.range_end);
452 let params_ref = params.iter().map(|x| x as _).collect::<Vec<_>>();
453 let kvs = crate::record_rds_sql_execute_elapsed!(
454 query_executor.query(template, ¶ms_ref).await,
455 PG_STORE_NAME,
456 RDS_STORE_OP_RANGE_DELETE,
457 template_type.as_ref()
458 )?;
459 let mut resp = DeleteRangeResponse::new(kvs.len() as i64);
460 if req.prev_kv {
461 resp.with_prev_kvs(kvs);
462 }
463 Ok(resp)
464 }
465
466 async fn batch_delete_with_query_executor(
467 &self,
468 query_executor: &mut ExecutorImpl<'_, PgClient>,
469 req: BatchDeleteRequest,
470 ) -> Result<BatchDeleteResponse> {
471 if req.keys.is_empty() {
472 return Ok(BatchDeleteResponse::default());
473 }
474 let query = self
475 .sql_template_set
476 .generate_batch_delete_query(req.keys.len());
477 let params = req.keys.iter().map(|x| x as _).collect::<Vec<_>>();
478
479 let kvs = crate::record_rds_sql_execute_elapsed!(
480 query_executor.query(&query, ¶ms).await,
481 PG_STORE_NAME,
482 RDS_STORE_OP_BATCH_DELETE,
483 ""
484 )?;
485 if req.prev_kv {
486 Ok(BatchDeleteResponse { prev_kvs: kvs })
487 } else {
488 Ok(BatchDeleteResponse::default())
489 }
490 }
491}
492
493impl PgStore {
494 pub async fn with_url(url: &str, table_name: &str, max_txn_ops: usize) -> Result<KvBackendRef> {
496 let mut cfg = Config::new();
497 cfg.url = Some(url.to_string());
498 let pool = cfg
500 .create_pool(Some(Runtime::Tokio1), NoTls)
501 .context(CreatePostgresPoolSnafu)?;
502 Self::with_pg_pool(pool, table_name, max_txn_ops).await
503 }
504
505 pub async fn with_pg_pool(
507 pool: Pool,
508 table_name: &str,
509 max_txn_ops: usize,
510 ) -> Result<KvBackendRef> {
511 let client = match pool.get().await {
515 Ok(client) => client,
516 Err(e) => {
517 return GetPostgresConnectionSnafu {
518 reason: e.to_string(),
519 }
520 .fail();
521 }
522 };
523 let template_factory = PgSqlTemplateFactory::new(table_name);
524 let sql_template_set = template_factory.build();
525 client
526 .execute(&sql_template_set.create_table_statement, &[])
527 .await
528 .with_context(|_| PostgresExecutionSnafu {
529 sql: sql_template_set.create_table_statement.to_string(),
530 })?;
531 Ok(Arc::new(Self {
532 max_txn_ops,
533 sql_template_set,
534 txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
535 executor_factory: PgExecutorFactory { pool },
536 _phantom: PhantomData,
537 }))
538 }
539}
540
// Integration tests for the Postgres backend. `build_pg_kv_backend` reads the connection URL
// from the `GT_POSTGRES_ENDPOINTS` environment variable.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kv_backend::test::{
        prepare_kv_with_prefix, test_kv_batch_delete_with_prefix, test_kv_batch_get_with_prefix,
        test_kv_compare_and_put_with_prefix, test_kv_delete_range_with_prefix,
        test_kv_put_with_prefix, test_kv_range_2_with_prefix, test_kv_range_with_prefix,
        test_simple_kv_range, test_txn_compare_equal, test_txn_compare_greater,
        test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
        text_txn_multi_compare_op, unprepare_kv,
    };
    use crate::maybe_skip_postgres_integration_test;

    /// Builds a [`PgStore`] backed by `table_name`, creating the table if needed.
    ///
    /// Returns `None` when `GT_POSTGRES_ENDPOINTS` is unset or empty; panics on any pool,
    /// connection or table-creation failure.
    async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
        let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
        if endpoints.is_empty() {
            return None;
        }

        let mut cfg = Config::new();
        cfg.url = Some(endpoints);
        let pool = cfg
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .context(CreatePostgresPoolSnafu)
            .unwrap();
        let client = pool.get().await.unwrap();
        let template_factory = PgSqlTemplateFactory::new(table_name);
        let sql_templates = template_factory.build();
        client
            .execute(&sql_templates.create_table_statement, &[])
            .await
            .context(PostgresExecutionSnafu {
                sql: sql_templates.create_table_statement.to_string(),
            })
            .unwrap();
        Some(PgStore {
            max_txn_ops: 128,
            sql_template_set: sql_templates,
            txn_retry_count: RDS_STORE_TXN_RETRY_COUNT,
            executor_factory: PgExecutorFactory { pool },
            _phantom: PhantomData,
        })
    }

    // Each test below uses its own table and key prefix.
    // NOTE(review): `maybe_skip_postgres_integration_test!` presumably returns early when no
    // database is configured; confirm against the macro definition.

    #[tokio::test]
    async fn test_pg_put() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("put_test").await.unwrap();
        let prefix = b"put/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_put_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("range_test").await.unwrap();
        let prefix = b"range/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    // Unlike the neighbouring tests, this one does not call `prepare_kv_with_prefix` first.
    #[tokio::test]
    async fn test_pg_range_2() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("range2_test").await.unwrap();
        let prefix = b"range2/";
        test_kv_range_2_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    // Runs the simple range suite with an empty key prefix.
    #[tokio::test]
    async fn test_pg_all_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("simple_range_test").await.unwrap();
        let prefix = b"";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_simple_kv_range(&kv_backend).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_batch_get() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_get_test").await.unwrap();
        let prefix = b"batch_get/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_batch_get_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    // NOTE(review): despite its name this test runs `test_kv_delete_range_with_prefix`, the same
    // suite as `test_pg_delete_range`; confirm whether a batch-delete suite was intended.
    #[tokio::test]
    async fn test_pg_batch_delete() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_delete_test").await.unwrap();
        let prefix = b"batch_delete/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_batch_delete_with_prefix() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("batch_delete_with_prefix_test")
            .await
            .unwrap();
        let prefix = b"batch_delete/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_batch_delete_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    #[tokio::test]
    async fn test_pg_delete_range() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("delete_range_test").await.unwrap();
        let prefix = b"delete_range/";
        prepare_kv_with_prefix(&kv_backend, prefix.to_vec()).await;
        test_kv_delete_range_with_prefix(&kv_backend, prefix.to_vec()).await;
        unprepare_kv(&kv_backend, prefix).await;
    }

    // The compare-and-put suite takes the backend behind an `Arc`.
    #[tokio::test]
    async fn test_pg_compare_and_put() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("compare_and_put_test").await.unwrap();
        let prefix = b"compare_and_put/";
        let kv_backend = Arc::new(kv_backend);
        test_kv_compare_and_put_with_prefix(kv_backend.clone(), prefix.to_vec()).await;
    }

    // Runs all transaction compare suites in sequence against the same table.
    #[tokio::test]
    async fn test_pg_txn() {
        maybe_skip_postgres_integration_test!();
        let kv_backend = build_pg_kv_backend("txn_test").await.unwrap();
        test_txn_one_compare_op(&kv_backend).await;
        text_txn_multi_compare_op(&kv_backend).await;
        test_txn_compare_equal(&kv_backend).await;
        test_txn_compare_greater(&kv_backend).await;
        test_txn_compare_less(&kv_backend).await;
        test_txn_compare_not_equal(&kv_backend).await;
    }
}