common_meta/kv_backend/
memory.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::any::Any;
16use std::collections::BTreeMap;
17use std::marker::PhantomData;
18use std::sync::{Arc, RwLock};
19
20use async_trait::async_trait;
21use common_error::ext::ErrorExt;
22
23use crate::kv_backend::txn::{Txn, TxnOp, TxnOpResponse, TxnRequest, TxnResponse};
24use crate::kv_backend::{KvBackend, KvBackendRef, ResettableKvBackend, TxnService};
25use crate::metrics::METRIC_META_TXN_REQUEST;
26use crate::rpc::store::{
27    BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
28    BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, PutRequest, PutResponse,
29    RangeRequest, RangeResponse,
30};
31use crate::rpc::KeyValue;
32
/// An in-memory key-value backend built on an ordered map.
///
/// Keys and values are raw byte vectors stored in a [`BTreeMap`] behind a [`RwLock`],
/// so iteration (and therefore range scans) visits keys in ascending byte order.
/// No value of type `T` is ever stored; `T` is carried only as a [`PhantomData`]
/// marker and is used as the error type by the trait implementations for this type.
pub struct MemoryKvBackend<T> {
    kvs: RwLock<BTreeMap<Vec<u8>, Vec<u8>>>,
    _phantom: PhantomData<T>,
}

// Written by hand instead of `#[derive(Default)]`: deriving would add a
// `T: Default` bound even though no `T` value is ever held.
impl<T> Default for MemoryKvBackend<T> {
    fn default() -> Self {
        let empty_map = RwLock::new(BTreeMap::default());
        Self {
            _phantom: PhantomData,
            kvs: empty_map,
        }
    }
}

impl<T> MemoryKvBackend<T> {
    /// Creates a backend with no entries; same as [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Removes every stored key-value pair.
    ///
    /// Panics if the inner lock is poisoned.
    pub fn clear(&self) {
        self.kvs.write().unwrap().clear();
    }

    #[cfg(test)]
    /// Reports whether the backend currently holds no entries.
    pub fn is_empty(&self) -> bool {
        let map = self.kvs.read().unwrap();
        map.is_empty()
    }

    #[cfg(test)]
    /// Returns an owned copy of every stored key-value pair.
    pub fn dump(&self) -> BTreeMap<Vec<u8>, Vec<u8>> {
        let snapshot = self.kvs.read().unwrap();
        (*snapshot).clone()
    }

    #[cfg(test)]
    /// Returns how many key-value pairs are currently stored.
    pub fn len(&self) -> usize {
        let map = self.kvs.read().unwrap();
        map.len()
    }
}
76
77#[async_trait]
78impl<T: ErrorExt + Send + Sync + 'static> KvBackend for MemoryKvBackend<T> {
79    fn name(&self) -> &str {
80        "Memory"
81    }
82
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    async fn range(&self, req: RangeRequest) -> Result<RangeResponse, Self::Error> {
88        let range = req.range();
89        let RangeRequest {
90            limit, keys_only, ..
91        } = req;
92
93        let kvs = self.kvs.read().unwrap();
94        let values = kvs.range(range);
95
96        let mut more = false;
97        let mut iter: i64 = 0;
98
99        let kvs = values
100            .take_while(|_| {
101                let take = limit == 0 || iter != limit;
102                iter += 1;
103                more = limit > 0 && iter > limit;
104
105                take
106            })
107            .map(|(k, v)| {
108                let key = k.clone();
109                let value = if keys_only { vec![] } else { v.clone() };
110                KeyValue { key, value }
111            })
112            .collect::<Vec<_>>();
113
114        Ok(RangeResponse { kvs, more })
115    }
116
117    async fn put(&self, req: PutRequest) -> Result<PutResponse, Self::Error> {
118        let PutRequest {
119            key,
120            value,
121            prev_kv,
122        } = req;
123
124        let mut kvs = self.kvs.write().unwrap();
125
126        let prev_kv = if prev_kv {
127            kvs.insert(key.clone(), value)
128                .map(|value| KeyValue { key, value })
129        } else {
130            kvs.insert(key, value);
131            None
132        };
133
134        Ok(PutResponse { prev_kv })
135    }
136
137    async fn batch_put(&self, req: BatchPutRequest) -> Result<BatchPutResponse, Self::Error> {
138        let mut kvs = self.kvs.write().unwrap();
139
140        let mut prev_kvs = if req.prev_kv {
141            Vec::with_capacity(req.kvs.len())
142        } else {
143            vec![]
144        };
145
146        for kv in req.kvs {
147            if req.prev_kv {
148                if let Some(value) = kvs.insert(kv.key.clone(), kv.value) {
149                    prev_kvs.push(KeyValue { key: kv.key, value });
150                }
151            } else {
152                kvs.insert(kv.key, kv.value);
153            }
154        }
155
156        Ok(BatchPutResponse { prev_kvs })
157    }
158
159    async fn batch_get(&self, req: BatchGetRequest) -> Result<BatchGetResponse, Self::Error> {
160        let kvs = self.kvs.read().unwrap();
161
162        let kvs = req
163            .keys
164            .into_iter()
165            .filter_map(|key| {
166                kvs.get_key_value(&key).map(|(k, v)| KeyValue {
167                    key: k.clone(),
168                    value: v.clone(),
169                })
170            })
171            .collect::<Vec<_>>();
172
173        Ok(BatchGetResponse { kvs })
174    }
175
176    async fn delete_range(
177        &self,
178        req: DeleteRangeRequest,
179    ) -> Result<DeleteRangeResponse, Self::Error> {
180        let range = req.range();
181        let DeleteRangeRequest { prev_kv, .. } = req;
182
183        let mut kvs = self.kvs.write().unwrap();
184
185        let keys = kvs
186            .range(range)
187            .map(|(key, _)| key.clone())
188            .collect::<Vec<_>>();
189
190        let mut prev_kvs = if prev_kv {
191            Vec::with_capacity(keys.len())
192        } else {
193            vec![]
194        };
195        let deleted = keys.len() as i64;
196
197        for key in keys {
198            if let Some(value) = kvs.remove(&key) {
199                if prev_kv {
200                    prev_kvs.push((key.clone(), value).into())
201                }
202            }
203        }
204
205        Ok(DeleteRangeResponse { deleted, prev_kvs })
206    }
207
208    async fn batch_delete(
209        &self,
210        req: BatchDeleteRequest,
211    ) -> Result<BatchDeleteResponse, Self::Error> {
212        let mut kvs = self.kvs.write().unwrap();
213
214        let mut prev_kvs = if req.prev_kv {
215            Vec::with_capacity(req.keys.len())
216        } else {
217            vec![]
218        };
219
220        for key in req.keys {
221            if req.prev_kv {
222                if let Some(value) = kvs.remove(&key) {
223                    prev_kvs.push(KeyValue { key, value });
224                }
225            } else {
226                kvs.remove(&key);
227            }
228        }
229
230        Ok(BatchDeleteResponse { prev_kvs })
231    }
232}
233
234#[async_trait]
235impl<T: ErrorExt + Send + Sync> TxnService for MemoryKvBackend<T> {
236    type Error = T;
237
238    async fn txn(&self, txn: Txn) -> Result<TxnResponse, Self::Error> {
239        let _timer = METRIC_META_TXN_REQUEST
240            .with_label_values(&["memory", "txn"])
241            .start_timer();
242
243        let TxnRequest {
244            compare,
245            success,
246            failure,
247        } = txn.into();
248
249        let mut kvs = self.kvs.write().unwrap();
250
251        let succeeded = compare.iter().all(|x| x.compare_value(kvs.get(&x.key)));
252
253        let do_txn = |txn_op| match txn_op {
254            TxnOp::Put(key, value) => {
255                kvs.insert(key, value);
256                TxnOpResponse::ResponsePut(PutResponse { prev_kv: None })
257            }
258
259            TxnOp::Get(key) => {
260                let value = kvs.get(&key).cloned();
261                let kvs = value
262                    .map(|value| KeyValue { key, value })
263                    .into_iter()
264                    .collect();
265                TxnOpResponse::ResponseGet(RangeResponse { kvs, more: false })
266            }
267
268            TxnOp::Delete(key) => {
269                let prev_value = kvs.remove(&key);
270                let deleted = if prev_value.is_some() { 1 } else { 0 };
271                TxnOpResponse::ResponseDelete(DeleteRangeResponse {
272                    deleted,
273                    prev_kvs: vec![],
274                })
275            }
276        };
277
278        let responses: Vec<_> = if succeeded { success } else { failure }
279            .into_iter()
280            .map(do_txn)
281            .collect();
282
283        Ok(TxnResponse {
284            succeeded,
285            responses,
286        })
287    }
288
289    fn max_txn_ops(&self) -> usize {
290        usize::MAX
291    }
292}
293
294impl<T: ErrorExt + Send + Sync + 'static> ResettableKvBackend for MemoryKvBackend<T> {
295    fn reset(&self) {
296        self.clear();
297    }
298
299    fn as_kv_backend_ref(self: Arc<Self>) -> KvBackendRef<T> {
300        self
301    }
302}
303
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::error::Error;
    use crate::kv_backend::test::{
        prepare_kv, test_kv_batch_delete, test_kv_batch_get, test_kv_compare_and_put,
        test_kv_delete_range, test_kv_put, test_kv_range, test_kv_range_2, test_txn_compare_equal,
        test_txn_compare_greater, test_txn_compare_less, test_txn_compare_not_equal,
        test_txn_one_compare_op, text_txn_multi_compare_op,
    };

    /// Builds an `Error`-typed backend seeded by the shared `prepare_kv` fixture.
    async fn mock_mem_store_with_data() -> MemoryKvBackend<Error> {
        let seeded = MemoryKvBackend::<Error>::new();
        prepare_kv(&seeded).await;
        seeded
    }

    #[tokio::test]
    async fn test_put() {
        let backend = mock_mem_store_with_data().await;
        test_kv_put(&backend).await;
    }

    #[tokio::test]
    async fn test_range() {
        let backend = mock_mem_store_with_data().await;
        test_kv_range(&backend).await;
    }

    #[tokio::test]
    async fn test_range_2() {
        // This shared test seeds its own data, so it starts from an empty backend.
        let empty = MemoryKvBackend::<Error>::new();
        test_kv_range_2(&empty).await;
    }

    #[tokio::test]
    async fn test_batch_get() {
        let backend = mock_mem_store_with_data().await;
        test_kv_batch_get(&backend).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_compare_and_put() {
        let shared = Arc::new(MemoryKvBackend::<Error>::new());
        test_kv_compare_and_put(shared).await;
    }

    #[tokio::test]
    async fn test_delete_range() {
        let backend = mock_mem_store_with_data().await;
        test_kv_delete_range(&backend).await;
    }

    #[tokio::test]
    async fn test_batch_delete() {
        let backend = mock_mem_store_with_data().await;
        test_kv_batch_delete(&backend).await;
    }

    #[tokio::test]
    async fn test_memory_txn() {
        // The txn cases run sequentially against one backend, in the original order.
        let backend = MemoryKvBackend::<Error>::new();
        test_txn_one_compare_op(&backend).await;
        text_txn_multi_compare_op(&backend).await;
        test_txn_compare_equal(&backend).await;
        test_txn_compare_greater(&backend).await;
        test_txn_compare_less(&backend).await;
        test_txn_compare_not_equal(&backend).await;
    }
}