1use std::ops::Range;
16use std::sync::Arc;
17
18use snafu::ensure;
19use tokio::sync::Mutex;
20
21use crate::error::{self, Result};
22use crate::kv_backend::KvBackendRef;
23use crate::rpc::store::CompareAndPutRequest;
24
/// Shared, reference-counted handle to a [`Sequence`].
pub type SequenceRef = Arc<Sequence>;

/// Prefix of every sequence's storage key; the full key is `"{SEQ_PREFIX}-{name}"` (see `seq_name`).
pub(crate) const SEQ_PREFIX: &str = "__meta_seq";
28
29pub struct SequenceBuilder {
30 name: String,
31 initial: u64,
32 step: u64,
33 generator: KvBackendRef,
34 max: u64,
35}
36
37fn seq_name(name: impl AsRef<str>) -> String {
38 format!("{}-{}", SEQ_PREFIX, name.as_ref())
39}
40
41impl SequenceBuilder {
42 pub fn new(name: impl AsRef<str>, generator: KvBackendRef) -> Self {
43 Self {
44 name: seq_name(name),
45 initial: 0,
46 step: 1,
47 generator,
48 max: u64::MAX,
49 }
50 }
51
52 pub fn initial(self, initial: u64) -> Self {
53 Self { initial, ..self }
54 }
55
56 pub fn step(self, step: u64) -> Self {
57 Self { step, ..self }
58 }
59
60 pub fn max(self, max: u64) -> Self {
61 Self { max, ..self }
62 }
63
64 pub fn build(self) -> Sequence {
65 Sequence {
66 inner: Mutex::new(Inner {
67 name: self.name,
68 generator: self.generator,
69 initial: self.initial,
70 next: self.initial,
71 step: self.step,
72 range: None,
73 force_quit: 1024,
74 max: self.max,
75 }),
76 }
77 }
78}
79
/// Generator of `u64` values that reserves ranges from a key-value backend (via
/// compare-and-put) and hands them out one at a time.
///
/// All state sits behind an async [`Mutex`]; `next()` holds the lock for the whole call, so
/// concurrent callers sharing one instance (e.g. through a [`SequenceRef`]) are served one at a
/// time. Build one with [`SequenceBuilder`].
pub struct Sequence {
    inner: Mutex<Inner>,
}
83
84impl Sequence {
85 pub async fn next(&self) -> Result<u64> {
86 let mut inner = self.inner.lock().await;
87 inner.next().await
88 }
89
90 pub async fn min_max(&self) -> Range<u64> {
91 let inner = self.inner.lock().await;
92 inner.initial..inner.max
93 }
94}
95
96struct Inner {
97 name: String,
98 generator: KvBackendRef,
99 initial: u64,
101 next: u64,
104 step: u64,
106 range: Option<Range<u64>>,
108 force_quit: usize,
110 max: u64,
111}
112
113impl Inner {
114 pub async fn next(&mut self) -> Result<u64> {
118 for _ in 0..self.force_quit {
119 match &self.range {
120 Some(range) => {
121 if range.contains(&self.next) {
122 let res = Ok(self.next);
123 self.next += 1;
124 return res;
125 }
126 self.range = None;
127 }
128 None => {
129 let range = self.next_range().await?;
130 self.next = range.start;
131 self.range = Some(range);
132 }
133 }
134 }
135
136 error::NextSequenceSnafu {
137 err_msg: format!("{}.next()", &self.name),
138 }
139 .fail()
140 }
141
142 pub async fn next_range(&self) -> Result<Range<u64>> {
143 let key = self.name.as_bytes();
144 let mut start = self.next;
145
146 let mut expect = if start == self.initial {
147 vec![]
148 } else {
149 u64::to_le_bytes(start).to_vec()
150 };
151
152 for _ in 0..self.force_quit {
153 let step = self.step.min(self.max - start);
154
155 ensure!(
156 step > 0,
157 error::NextSequenceSnafu {
158 err_msg: format!("next sequence exhausted, max: {}", self.max)
159 }
160 );
161
162 let value = u64::to_le_bytes(start + step);
164
165 let req = CompareAndPutRequest {
166 key: key.to_vec(),
167 expect,
168 value: value.to_vec(),
169 };
170
171 let res = self.generator.compare_and_put(req).await?;
172
173 if !res.success {
174 if let Some(kv) = res.prev_kv {
175 let v: [u8; 8] = match kv.value.clone().try_into() {
176 Ok(a) => a,
177 Err(v) => {
178 return error::UnexpectedSequenceValueSnafu {
179 err_msg: format!("Not a valid u64 for '{}': {v:?}", self.name),
180 }
181 .fail()
182 }
183 };
184 let v = u64::from_le_bytes(v);
185 start = v.max(self.initial);
187 expect = kv.value;
188 } else {
189 start = self.initial;
190 expect = vec![];
191 }
192 continue;
193 }
194
195 return Ok(Range {
196 start,
197 end: start + step,
198 });
199 }
200
201 error::NextSequenceSnafu {
202 err_msg: format!("{}.next_range()", &self.name),
203 }
204 .fail()
205 }
206}
207
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::collections::HashSet;
    use std::sync::Arc;

    use itertools::{Itertools, MinMaxResult};
    use tokio::sync::mpsc;

    use super::*;
    use crate::error::Error;
    use crate::kv_backend::memory::MemoryKvBackend;
    use crate::kv_backend::{KvBackend, TxnService};
    use crate::rpc::store::{
        BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse,
        BatchPutRequest, BatchPutResponse, CompareAndPutResponse, DeleteRangeRequest,
        DeleteRangeResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
    };

    #[tokio::test]
    async fn test_sequence_with_existed_value() {
        // Seeds the backend with `exist`, builds a sequence with `initial = 100`, and checks the
        // first values it hands out.
        async fn test(exist: u64, expected: Vec<u64>) {
            let kv_backend = Arc::new(MemoryKvBackend::default());

            let exist = u64::to_le_bytes(exist);
            kv_backend
                .put(PutRequest::new().with_key(seq_name("s")).with_value(exist))
                .await
                .unwrap();

            let initial = 100;
            let seq = SequenceBuilder::new("s", kv_backend)
                .initial(initial)
                .build();

            let mut actual = Vec::with_capacity(expected.len());
            for _ in 0..expected.len() {
                actual.push(seq.next().await.unwrap());
            }
            assert_eq!(actual, expected);
        }

        // A stored value below `initial` is ignored: the sequence starts at `initial`.
        test(1, vec![100, 101, 102]).await;
        // A stored value equal to `initial` also starts at `initial`.
        test(100, vec![100, 101, 102]).await;
        // A stored value above `initial` is where the sequence continues.
        test(200, vec![200, 201, 202]).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_sequence_with_contention() {
        let seq = Arc::new(
            SequenceBuilder::new("s", Arc::new(MemoryKvBackend::default()))
                .initial(1024)
                .build(),
        );

        let (tx, mut rx) = mpsc::unbounded_channel();
        // 10 concurrent tasks, each taking 100 values from the shared sequence.
        for _ in 0..10 {
            tokio::spawn({
                let seq = seq.clone();
                let tx = tx.clone();
                async move {
                    for _ in 0..100 {
                        tx.send(seq.next().await.unwrap()).unwrap()
                    }
                }
            });
        }
        // Drop our own sender so `recv` yields `None` once every task has finished. Otherwise a
        // task that panics before sending all its values would make this test hang forever.
        drop(tx);

        // Every value must be unique.
        let mut nums = HashSet::new();
        while let Some(x) = rx.recv().await {
            nums.insert(x);
        }
        assert_eq!(nums.len(), 1000);
        let MinMaxResult::MinMax(min, max) = nums.iter().minmax() else {
            unreachable!("nums has more than one element");
        };
        assert_eq!(*min, 1024);
        assert_eq!(*max, 2023);
    }

    #[tokio::test]
    async fn test_sequence() {
        let kv_backend = Arc::new(MemoryKvBackend::default());
        let initial = 1024;
        let seq = SequenceBuilder::new("test_seq", kv_backend)
            .initial(initial)
            .build();

        for i in initial..initial + 100 {
            assert_eq!(i, seq.next().await.unwrap());
        }
    }

    #[tokio::test]
    async fn test_sequence_out_of_range() {
        let seq = SequenceBuilder::new("test_seq", Arc::new(MemoryKvBackend::default()))
            .initial(u64::MAX - 10)
            .step(10)
            .build();

        // Exactly 10 values fit below the default `max` of `u64::MAX`.
        for _ in 0..10 {
            let _ = seq.next().await.unwrap();
        }

        let res = seq.next().await;
        assert!(matches!(res, Err(Error::NextSequence { .. })));
    }

    #[tokio::test]
    async fn test_sequence_force_quit() {
        // A backend whose compare-and-put always reports failure without a previous value,
        // so every reservation attempt conflicts.
        struct Noop;

        impl TxnService for Noop {
            type Error = Error;
        }

        #[async_trait::async_trait]
        impl KvBackend for Noop {
            fn name(&self) -> &str {
                "Noop"
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            async fn range(&self, _: RangeRequest) -> Result<RangeResponse> {
                unreachable!()
            }

            async fn put(&self, _: PutRequest) -> Result<PutResponse> {
                unreachable!()
            }

            async fn batch_put(&self, _: BatchPutRequest) -> Result<BatchPutResponse> {
                unreachable!()
            }

            async fn batch_get(&self, _: BatchGetRequest) -> Result<BatchGetResponse> {
                unreachable!()
            }

            async fn compare_and_put(
                &self,
                _: CompareAndPutRequest,
            ) -> Result<CompareAndPutResponse> {
                Ok(CompareAndPutResponse::default())
            }

            async fn delete_range(&self, _: DeleteRangeRequest) -> Result<DeleteRangeResponse> {
                unreachable!()
            }

            async fn batch_delete(&self, _: BatchDeleteRequest) -> Result<BatchDeleteResponse> {
                unreachable!()
            }
        }

        let seq = SequenceBuilder::new("test_seq", Arc::new(Noop)).build();

        // The retry budget runs out and the sequence gives up with `NextSequence`.
        let next = seq.next().await;
        assert!(matches!(next, Err(Error::NextSequence { .. })));
    }
}