use std::any::Any;
use std::collections::BTreeMap;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};

use async_trait::async_trait;
use common_error::ext::ErrorExt;

use crate::kv_backend::txn::{Txn, TxnOp, TxnOpResponse, TxnRequest, TxnResponse};
use crate::kv_backend::{KvBackend, KvBackendRef, ResettableKvBackend, TxnService};
use crate::metrics::METRIC_META_TXN_REQUEST;
use crate::rpc::store::{
    BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
    BatchPutResponse, DeleteRangeRequest, DeleteRangeResponse, PutRequest, PutResponse,
    RangeRequest, RangeResponse,
};
use crate::rpc::KeyValue;

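/// An in-memory, non-persistent [`KvBackend`] that stores key-value pairs in a
/// [`BTreeMap`] guarded by an [`RwLock`], so it can be shared across threads
/// (e.g. behind an [`Arc`]). The type parameter `T` only selects the error type
/// reported by the trait methods; no values of `T` are ever stored.
///
/// A minimal usage sketch (illustrative only; it assumes `crate::error::Error`
/// as the error type, as the tests below do):
///
/// ```ignore
/// let kv_backend = MemoryKvBackend::<Error>::new();
/// kv_backend
///     .put(PutRequest {
///         key: b"key".to_vec(),
///         value: b"value".to_vec(),
///         prev_kv: false,
///     })
///     .await?;
/// ```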
pub struct MemoryKvBackend<T> {
    kvs: RwLock<BTreeMap<Vec<u8>, Vec<u8>>>,
    _phantom: PhantomData<T>,
}

impl<T> Default for MemoryKvBackend<T> {
    fn default() -> Self {
        Self {
            kvs: RwLock::new(BTreeMap::new()),
            _phantom: PhantomData,
        }
    }
}

impl<T> MemoryKvBackend<T> {
    pub fn new() -> Self {
        Self::default()
    }

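    /// Removes every key-value pair from the backend.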
    pub fn clear(&self) {
        let mut kvs = self.kvs.write().unwrap();
        kvs.clear();
    }

    #[cfg(test)]
    pub fn is_empty(&self) -> bool {
        self.kvs.read().unwrap().is_empty()
    }

    #[cfg(test)]
    pub fn dump(&self) -> BTreeMap<Vec<u8>, Vec<u8>> {
        let kvs = self.kvs.read().unwrap();
        kvs.clone()
    }

    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.kvs.read().unwrap().len()
    }
}

#[async_trait]
impl<T: ErrorExt + Send + Sync + 'static> KvBackend for MemoryKvBackend<T> {
    fn name(&self) -> &str {
        "Memory"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

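    // Scans the requested key range under a read lock. A `limit` of 0 means
    // "no limit", and `keys_only` returns each match with an empty value.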
    async fn range(&self, req: RangeRequest) -> Result<RangeResponse, Self::Error> {
        let range = req.range();
        let RangeRequest {
            limit, keys_only, ..
        } = req;

        let kvs = self.kvs.read().unwrap();
        let values = kvs.range(range);

        let mut more = false;
        let mut iter: i64 = 0;

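        // Take at most `limit` entries; `more` becomes true only when at least
        // one entry beyond the limit remains in the range.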
        let kvs = values
            .take_while(|_| {
                let take = limit == 0 || iter != limit;
                iter += 1;
                more = limit > 0 && iter > limit;

                take
            })
            .map(|(k, v)| {
                let key = k.clone();
                let value = if keys_only { vec![] } else { v.clone() };
                KeyValue { key, value }
            })
            .collect::<Vec<_>>();

        Ok(RangeResponse { kvs, more })
    }

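    // Inserts the key-value pair and, when `prev_kv` is requested, returns the
    // entry that was replaced (if any).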
    async fn put(&self, req: PutRequest) -> Result<PutResponse, Self::Error> {
        let PutRequest {
            key,
            value,
            prev_kv,
        } = req;

        let mut kvs = self.kvs.write().unwrap();

        let prev_kv = if prev_kv {
            kvs.insert(key.clone(), value)
                .map(|value| KeyValue { key, value })
        } else {
            kvs.insert(key, value);
            None
        };

        Ok(PutResponse { prev_kv })
    }

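    // Inserts all pairs under a single write lock; previous values are only
    // collected for keys that already existed when `prev_kv` is requested.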
    async fn batch_put(&self, req: BatchPutRequest) -> Result<BatchPutResponse, Self::Error> {
        let mut kvs = self.kvs.write().unwrap();

        let mut prev_kvs = if req.prev_kv {
            Vec::with_capacity(req.kvs.len())
        } else {
            vec![]
        };

        for kv in req.kvs {
            if req.prev_kv {
                if let Some(value) = kvs.insert(kv.key.clone(), kv.value) {
                    prev_kvs.push(KeyValue { key: kv.key, value });
                }
            } else {
                kvs.insert(kv.key, kv.value);
            }
        }

        Ok(BatchPutResponse { prev_kvs })
    }

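    // Looks up each requested key; keys that are absent from the map are
    // simply skipped rather than reported as errors.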
    async fn batch_get(&self, req: BatchGetRequest) -> Result<BatchGetResponse, Self::Error> {
        let kvs = self.kvs.read().unwrap();

        let kvs = req
            .keys
            .into_iter()
            .filter_map(|key| {
                kvs.get_key_value(&key).map(|(k, v)| KeyValue {
                    key: k.clone(),
                    value: v.clone(),
                })
            })
            .collect::<Vec<_>>();

        Ok(BatchGetResponse { kvs })
    }

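    // Deletes every key in the requested range. The matching keys are
    // collected first so the map is not mutated while the range iterator
    // still borrows it.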
    async fn delete_range(
        &self,
        req: DeleteRangeRequest,
    ) -> Result<DeleteRangeResponse, Self::Error> {
        let range = req.range();
        let DeleteRangeRequest { prev_kv, .. } = req;

        let mut kvs = self.kvs.write().unwrap();

        let keys = kvs
            .range(range)
            .map(|(key, _)| key.clone())
            .collect::<Vec<_>>();

        let mut prev_kvs = if prev_kv {
            Vec::with_capacity(keys.len())
        } else {
            vec![]
        };
        let deleted = keys.len() as i64;

        for key in keys {
            if let Some(value) = kvs.remove(&key) {
                if prev_kv {
                    prev_kvs.push((key.clone(), value).into())
                }
            }
        }

        Ok(DeleteRangeResponse { deleted, prev_kvs })
    }

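    // Removes the given keys, optionally collecting the deleted entries when
    // `prev_kv` is requested.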
    async fn batch_delete(
        &self,
        req: BatchDeleteRequest,
    ) -> Result<BatchDeleteResponse, Self::Error> {
        let mut kvs = self.kvs.write().unwrap();

        let mut prev_kvs = if req.prev_kv {
            Vec::with_capacity(req.keys.len())
        } else {
            vec![]
        };

        for key in req.keys {
            if req.prev_kv {
                if let Some(value) = kvs.remove(&key) {
                    prev_kvs.push(KeyValue { key, value });
                }
            } else {
                kvs.remove(&key);
            }
        }

        Ok(BatchDeleteResponse { prev_kvs })
    }
}

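// The whole transaction runs under a single write lock, so evaluating the
// compare conditions and applying the chosen operation list is atomic with
// respect to the other methods. Note that the per-operation responses leave
// `prev_kv`/`prev_kvs` empty.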
#[async_trait]
impl<T: ErrorExt + Send + Sync> TxnService for MemoryKvBackend<T> {
    type Error = T;

    async fn txn(&self, txn: Txn) -> Result<TxnResponse, Self::Error> {
        let _timer = METRIC_META_TXN_REQUEST
            .with_label_values(&["memory", "txn"])
            .start_timer();

        let TxnRequest {
            compare,
            success,
            failure,
        } = txn.into();

        let mut kvs = self.kvs.write().unwrap();

        let succeeded = compare.iter().all(|x| x.compare_value(kvs.get(&x.key)));

        let do_txn = |txn_op| match txn_op {
            TxnOp::Put(key, value) => {
                kvs.insert(key, value);
                TxnOpResponse::ResponsePut(PutResponse { prev_kv: None })
            }

            TxnOp::Get(key) => {
                let value = kvs.get(&key).cloned();
                let kvs = value
                    .map(|value| KeyValue { key, value })
                    .into_iter()
                    .collect();
                TxnOpResponse::ResponseGet(RangeResponse { kvs, more: false })
            }

            TxnOp::Delete(key) => {
                let prev_value = kvs.remove(&key);
                let deleted = if prev_value.is_some() { 1 } else { 0 };
                TxnOpResponse::ResponseDelete(DeleteRangeResponse {
                    deleted,
                    prev_kvs: vec![],
                })
            }
        };

        let responses: Vec<_> = if succeeded { success } else { failure }
            .into_iter()
            .map(do_txn)
            .collect();

        Ok(TxnResponse {
            succeeded,
            responses,
        })
    }

    fn max_txn_ops(&self) -> usize {
        usize::MAX
    }
}

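// Resetting the backend simply clears the in-memory map.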
impl<T: ErrorExt + Send + Sync + 'static> ResettableKvBackend for MemoryKvBackend<T> {
    fn reset(&self) {
        self.clear();
    }

    fn as_kv_backend_ref(self: Arc<Self>) -> KvBackendRef<T> {
        self
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::error::Error;
    use crate::kv_backend::test::{
        prepare_kv, test_kv_batch_delete, test_kv_batch_get, test_kv_compare_and_put,
        test_kv_delete_range, test_kv_put, test_kv_range, test_kv_range_2, test_txn_compare_equal,
        test_txn_compare_greater, test_txn_compare_less, test_txn_compare_not_equal,
        test_txn_one_compare_op, text_txn_multi_compare_op,
    };

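    // Builds an in-memory backend pre-populated by `prepare_kv` for the tests
    // below.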
    async fn mock_mem_store_with_data() -> MemoryKvBackend<Error> {
        let kv_backend = MemoryKvBackend::<Error>::new();
        prepare_kv(&kv_backend).await;

        kv_backend
    }

    #[tokio::test]
    async fn test_put() {
        let kv_backend = mock_mem_store_with_data().await;

        test_kv_put(&kv_backend).await;
    }

    #[tokio::test]
    async fn test_range() {
        let kv_backend = mock_mem_store_with_data().await;

        test_kv_range(&kv_backend).await;
    }

    #[tokio::test]
    async fn test_range_2() {
        let kv = MemoryKvBackend::<Error>::new();

        test_kv_range_2(&kv).await;
    }

    #[tokio::test]
    async fn test_batch_get() {
        let kv_backend = mock_mem_store_with_data().await;

        test_kv_batch_get(&kv_backend).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_compare_and_put() {
        let kv_backend = Arc::new(MemoryKvBackend::<Error>::new());

        test_kv_compare_and_put(kv_backend).await;
    }

    #[tokio::test]
    async fn test_delete_range() {
        let kv_backend = mock_mem_store_with_data().await;

        test_kv_delete_range(&kv_backend).await;
    }

    #[tokio::test]
    async fn test_batch_delete() {
        let kv_backend = mock_mem_store_with_data().await;

        test_kv_batch_delete(&kv_backend).await;
    }

    #[tokio::test]
    async fn test_memory_txn() {
        let kv_backend = MemoryKvBackend::<Error>::new();
        test_txn_one_compare_op(&kv_backend).await;
        text_txn_multi_compare_op(&kv_backend).await;
        test_txn_compare_equal(&kv_backend).await;
        test_txn_compare_greater(&kv_backend).await;
        test_txn_compare_less(&kv_backend).await;
        test_txn_compare_not_equal(&kv_backend).await;
    }
}
383}