common_meta/
sequence.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::ops::Range;
16use std::sync::Arc;
17
18use snafu::ensure;
19use tokio::sync::Mutex;
20
21use crate::error::{self, Result};
22use crate::kv_backend::KvBackendRef;
23use crate::rpc::store::CompareAndPutRequest;
24
/// Shared, reference-counted handle to a [`Sequence`].
pub type SequenceRef = Arc<Sequence>;

/// Prefix of the backend keys under which sequences are persisted (combined with the
/// sequence name by `seq_name`).
pub(crate) const SEQ_PREFIX: &str = "__meta_seq";
28
29pub struct SequenceBuilder {
30    name: String,
31    initial: u64,
32    step: u64,
33    generator: KvBackendRef,
34    max: u64,
35}
36
37fn seq_name(name: impl AsRef<str>) -> String {
38    format!("{}-{}", SEQ_PREFIX, name.as_ref())
39}
40
41impl SequenceBuilder {
42    pub fn new(name: impl AsRef<str>, generator: KvBackendRef) -> Self {
43        Self {
44            name: seq_name(name),
45            initial: 0,
46            step: 1,
47            generator,
48            max: u64::MAX,
49        }
50    }
51
52    pub fn initial(self, initial: u64) -> Self {
53        Self { initial, ..self }
54    }
55
56    pub fn step(self, step: u64) -> Self {
57        Self { step, ..self }
58    }
59
60    pub fn max(self, max: u64) -> Self {
61        Self { max, ..self }
62    }
63
64    pub fn build(self) -> Sequence {
65        Sequence {
66            inner: Mutex::new(Inner {
67                name: self.name,
68                generator: self.generator,
69                initial: self.initial,
70                next: self.initial,
71                step: self.step,
72                range: None,
73                force_quit: 1024,
74                max: self.max,
75            }),
76        }
77    }
78}
79
80pub struct Sequence {
81    inner: Mutex<Inner>,
82}
83
84impl Sequence {
85    pub async fn next(&self) -> Result<u64> {
86        let mut inner = self.inner.lock().await;
87        inner.next().await
88    }
89
90    pub async fn min_max(&self) -> Range<u64> {
91        let inner = self.inner.lock().await;
92        inner.initial..inner.max
93    }
94}
95
96struct Inner {
97    name: String,
98    generator: KvBackendRef,
99    // The initial(minimal) value of the sequence.
100    initial: u64,
101    // The next available sequences(if it is in the range,
102    // otherwise it need to fetch from generator again).
103    next: u64,
104    // Fetch several sequences at once: [start, start + step).
105    step: u64,
106    // The range of available sequences for the local cache.
107    range: Option<Range<u64>>,
108    // Used to avoid dead loops.
109    force_quit: usize,
110    max: u64,
111}
112
113impl Inner {
114    /// 1. returns the `next` value directly if it is in the `range` (local cache)
115    /// 2. fetch(CAS) next `range` from the `generator`
116    /// 3. jump to step 1
117    pub async fn next(&mut self) -> Result<u64> {
118        for _ in 0..self.force_quit {
119            match &self.range {
120                Some(range) => {
121                    if range.contains(&self.next) {
122                        let res = Ok(self.next);
123                        self.next += 1;
124                        return res;
125                    }
126                    self.range = None;
127                }
128                None => {
129                    let range = self.next_range().await?;
130                    self.next = range.start;
131                    self.range = Some(range);
132                }
133            }
134        }
135
136        error::NextSequenceSnafu {
137            err_msg: format!("{}.next()", &self.name),
138        }
139        .fail()
140    }
141
142    pub async fn next_range(&self) -> Result<Range<u64>> {
143        let key = self.name.as_bytes();
144        let mut start = self.next;
145
146        let mut expect = if start == self.initial {
147            vec![]
148        } else {
149            u64::to_le_bytes(start).to_vec()
150        };
151
152        for _ in 0..self.force_quit {
153            let step = self.step.min(self.max - start);
154
155            ensure!(
156                step > 0,
157                error::NextSequenceSnafu {
158                    err_msg: format!("next sequence exhausted, max: {}", self.max)
159                }
160            );
161
162            // No overflow: step <= self.max - start -> step + start <= self.max <= u64::MAX
163            let value = u64::to_le_bytes(start + step);
164
165            let req = CompareAndPutRequest {
166                key: key.to_vec(),
167                expect,
168                value: value.to_vec(),
169            };
170
171            let res = self.generator.compare_and_put(req).await?;
172
173            if !res.success {
174                if let Some(kv) = res.prev_kv {
175                    let v: [u8; 8] = match kv.value.clone().try_into() {
176                        Ok(a) => a,
177                        Err(v) => {
178                            return error::UnexpectedSequenceValueSnafu {
179                                err_msg: format!("Not a valid u64 for '{}': {v:?}", self.name),
180                            }
181                            .fail()
182                        }
183                    };
184                    let v = u64::from_le_bytes(v);
185                    // If the existed value is smaller than the initial, we should start from the initial.
186                    start = v.max(self.initial);
187                    expect = kv.value;
188                } else {
189                    start = self.initial;
190                    expect = vec![];
191                }
192                continue;
193            }
194
195            return Ok(Range {
196                start,
197                end: start + step,
198            });
199        }
200
201        error::NextSequenceSnafu {
202            err_msg: format!("{}.next_range()", &self.name),
203        }
204        .fail()
205    }
206}
207
208#[cfg(test)]
209mod tests {
210    use std::any::Any;
211    use std::collections::HashSet;
212    use std::sync::Arc;
213
214    use itertools::{Itertools, MinMaxResult};
215    use tokio::sync::mpsc;
216
217    use super::*;
218    use crate::error::Error;
219    use crate::kv_backend::memory::MemoryKvBackend;
220    use crate::kv_backend::{KvBackend, TxnService};
221    use crate::rpc::store::{
222        BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse,
223        BatchPutRequest, BatchPutResponse, CompareAndPutResponse, DeleteRangeRequest,
224        DeleteRangeResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
225    };
226
227    #[tokio::test]
228    async fn test_sequence_with_existed_value() {
229        async fn test(exist: u64, expected: Vec<u64>) {
230            let kv_backend = Arc::new(MemoryKvBackend::default());
231
232            let exist = u64::to_le_bytes(exist);
233            kv_backend
234                .put(PutRequest::new().with_key(seq_name("s")).with_value(exist))
235                .await
236                .unwrap();
237
238            let initial = 100;
239            let seq = SequenceBuilder::new("s", kv_backend)
240                .initial(initial)
241                .build();
242
243            let mut actual = Vec::with_capacity(expected.len());
244            for _ in 0..expected.len() {
245                actual.push(seq.next().await.unwrap());
246            }
247            assert_eq!(actual, expected);
248        }
249
250        // put a value not greater than the "initial", the sequence should start from "initial"
251        test(1, vec![100, 101, 102]).await;
252        test(100, vec![100, 101, 102]).await;
253
254        // put a value greater than the "initial", the sequence should start from the put value
255        test(200, vec![200, 201, 202]).await;
256    }
257
258    #[tokio::test(flavor = "multi_thread")]
259    async fn test_sequence_with_contention() {
260        let seq = Arc::new(
261            SequenceBuilder::new("s", Arc::new(MemoryKvBackend::default()))
262                .initial(1024)
263                .build(),
264        );
265
266        let (tx, mut rx) = mpsc::unbounded_channel();
267        // Spawn 10 tasks to concurrently get the next sequence. Each task will get 100 sequences.
268        for _ in 0..10 {
269            tokio::spawn({
270                let seq = seq.clone();
271                let tx = tx.clone();
272                async move {
273                    for _ in 0..100 {
274                        tx.send(seq.next().await.unwrap()).unwrap()
275                    }
276                }
277            });
278        }
279
280        // Test that we get 1000 unique sequences, and start from 1024 to 2023.
281        let mut nums = HashSet::new();
282        let mut c = 0;
283        while c < 1000
284            && let Some(x) = rx.recv().await
285        {
286            nums.insert(x);
287            c += 1;
288        }
289        assert_eq!(nums.len(), 1000);
290        let MinMaxResult::MinMax(min, max) = nums.iter().minmax() else {
291            unreachable!("nums has more than one elements");
292        };
293        assert_eq!(*min, 1024);
294        assert_eq!(*max, 2023);
295    }
296
297    #[tokio::test]
298    async fn test_sequence() {
299        let kv_backend = Arc::new(MemoryKvBackend::default());
300        let initial = 1024;
301        let seq = SequenceBuilder::new("test_seq", kv_backend)
302            .initial(initial)
303            .build();
304
305        for i in initial..initial + 100 {
306            assert_eq!(i, seq.next().await.unwrap());
307        }
308    }
309
310    #[tokio::test]
311    async fn test_sequence_out_of_rage() {
312        let seq = SequenceBuilder::new("test_seq", Arc::new(MemoryKvBackend::default()))
313            .initial(u64::MAX - 10)
314            .step(10)
315            .build();
316
317        for _ in 0..10 {
318            let _ = seq.next().await.unwrap();
319        }
320
321        let res = seq.next().await;
322        assert!(res.is_err());
323        assert!(matches!(res.unwrap_err(), Error::NextSequence { .. }))
324    }
325
326    #[tokio::test]
327    async fn test_sequence_force_quit() {
328        struct Noop;
329
330        impl TxnService for Noop {
331            type Error = Error;
332        }
333
334        #[async_trait::async_trait]
335        impl KvBackend for Noop {
336            fn name(&self) -> &str {
337                "Noop"
338            }
339
340            fn as_any(&self) -> &dyn Any {
341                self
342            }
343
344            async fn range(&self, _: RangeRequest) -> Result<RangeResponse> {
345                unreachable!()
346            }
347
348            async fn put(&self, _: PutRequest) -> Result<PutResponse> {
349                unreachable!()
350            }
351
352            async fn batch_put(&self, _: BatchPutRequest) -> Result<BatchPutResponse> {
353                unreachable!()
354            }
355
356            async fn batch_get(&self, _: BatchGetRequest) -> Result<BatchGetResponse> {
357                unreachable!()
358            }
359
360            async fn compare_and_put(
361                &self,
362                _: CompareAndPutRequest,
363            ) -> Result<CompareAndPutResponse> {
364                Ok(CompareAndPutResponse::default())
365            }
366
367            async fn delete_range(&self, _: DeleteRangeRequest) -> Result<DeleteRangeResponse> {
368                unreachable!()
369            }
370
371            async fn batch_delete(&self, _: BatchDeleteRequest) -> Result<BatchDeleteResponse> {
372                unreachable!()
373            }
374        }
375
376        let seq = SequenceBuilder::new("test_seq", Arc::new(Noop)).build();
377
378        let next = seq.next().await;
379        assert!(next.is_err());
380    }
381}