common_procedure/store/
util.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16
17use async_stream::try_stream;
18use futures::{Stream, TryStreamExt};
19use snafu::{ensure, ResultExt};
20
21use crate::error;
22use crate::error::Result;
23use crate::store::state_store::KeySet;
24
/// Accumulates the key-value pairs that are being grouped under one key.
struct CollectingState {
    /// Collected pairs in arrival order; the first entry is the pair the
    /// state was created with.
    pairs: Vec<(String, Vec<u8>)>,
}
28
29fn parse_segments(segments: Vec<(String, Vec<u8>)>, prefix: &str) -> Result<Vec<(usize, Vec<u8>)>> {
30    segments
31        .into_iter()
32        .map(|(key, value)| {
33            let suffix = key.trim_start_matches(prefix);
34            let index = suffix
35                .parse::<usize>()
36                .context(error::ParseSegmentKeySnafu { key })?;
37
38            Ok((index, value))
39        })
40        .collect::<Result<Vec<_>>>()
41}
42
43/// Merges multiple values into a single key-value pair.
44/// Returns an error if:
45/// - Part values are lost.
46/// - Failed to parse the key of segment.
47fn merge_multiple_values(
48    CollectingState { mut pairs }: CollectingState,
49) -> Result<(KeySet, Vec<u8>)> {
50    if pairs.len() == 1 {
51        // Safety: must exist.
52        let (key, value) = pairs.into_iter().next().unwrap();
53        Ok((KeySet::new(key, 0), value))
54    } else {
55        let segments = pairs.split_off(1);
56        // Safety: must exist.
57        let (key, value) = pairs.into_iter().next().unwrap();
58        let prefix = KeySet::with_prefix(&key);
59        let mut parsed_segments = parse_segments(segments, &prefix)?;
60        parsed_segments.sort_unstable_by(|a, b| a.0.cmp(&b.0));
61
62        // Safety: `parsed_segments` must larger than 0.
63        let segment_num = parsed_segments.last().unwrap().0;
64        ensure!(
65            // The segment index start from 1.
66            parsed_segments.len() == segment_num,
67            error::UnexpectedSnafu {
68                err_msg: format!(
69                    "Corrupted segment keys, parsed segment indexes: {:?}",
70                    parsed_segments
71                        .into_iter()
72                        .map(|(key, _)| key)
73                        .collect::<Vec<_>>()
74                )
75            }
76        );
77
78        let segment_values = parsed_segments.into_iter().map(|(_, value)| value);
79        let mut values = Vec::with_capacity(segment_values.len() + 1);
80        values.push(value);
81        values.extend(segment_values);
82
83        Ok((KeySet::new(key, segment_num), values.concat()))
84    }
85}
86
87impl CollectingState {
88    fn new(key: String, value: Vec<u8>) -> CollectingState {
89        Self {
90            pairs: vec![(key, value)],
91        }
92    }
93
94    fn push(&mut self, key: String, value: Vec<u8>) {
95        self.pairs.push((key, value));
96    }
97
98    fn key(&self) -> &str {
99        self.pairs[0].0.as_str()
100    }
101}
102
103type Upstream = dyn Stream<Item = Result<(String, Vec<u8>)>> + Send;
104
105/// Merges multiple values that have the same prefix of the key
106/// from `upstream` into a single value.
107pub fn multiple_value_stream(
108    mut upstream: Pin<Box<Upstream>>,
109) -> impl Stream<Item = Result<(KeySet, Vec<u8>)>> {
110    try_stream! {
111        let mut collecting: Option<CollectingState> = None;
112        while let Some((key, value)) = upstream.try_next().await? {
113            match collecting.take() {
114                Some(mut current) => {
115                    if key.starts_with(current.key()) {
116                        // Pushes the key value pair into `collecting`.
117                        current.push(key, value);
118                        collecting = Some(current);
119                    } else {
120                        // Starts to collect next key value pair.
121                        collecting = Some(CollectingState::new(key, value));
122                        yield merge_multiple_values(current)?;
123                    }
124                }
125                None => collecting = Some(CollectingState::new(key, value)),
126            }
127        }
128        if let Some(current) = collecting.take() {
129            yield merge_multiple_values(current)?
130        }
131    }
132}
133
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use futures::stream::{self};
    use futures::TryStreamExt;

    use super::*;
    use crate::error::{self};

    /// Builds a successful upstream item from a key and its raw value.
    fn item(key: &str, value: &[u8]) -> Result<(String, Vec<u8>)> {
        Ok((key.to_string(), value.to_vec()))
    }

    /// Builds the mocked upstream failure.
    fn mock_err() -> Result<(String, Vec<u8>)> {
        Err(error::UnexpectedSnafu { err_msg: "mock" }.build())
    }

    /// Asserts that the very first poll of the merged stream fails with
    /// `Error::Unexpected`.
    async fn assert_first_poll_unexpected(items: Vec<Result<(String, Vec<u8>)>>) {
        let upstream = stream::iter(items);
        let mut merged = Box::pin(multiple_value_stream(Box::pin(upstream)));
        let err = merged.try_next().await.unwrap_err();
        assert_matches!(err, error::Error::Unexpected { .. });
    }

    #[test]
    fn test_key_set_keys() {
        let key_set = KeySet::new("baz".to_string(), 3);
        let all_keys = key_set.keys();
        // One base key plus three segment keys.
        assert_eq!(all_keys.len(), 4);
        assert_eq!(&all_keys[0], "baz");
        assert_eq!(&all_keys[1], &KeySet::with_segment_suffix("baz", 1));
    }

    #[tokio::test]
    async fn test_merge_multiple_values() {
        // Segments deliberately arrive out of order.
        let upstream = stream::iter(vec![
            item("foo", &[0, 1, 2, 3]),
            item("foo/0002", &[6, 7]),
            item("foo/0003", &[8]),
            item("foo/0001", &[4, 5]),
            item("bar", &[0, 1, 2, 3]),
            item("baz", &[0, 1, 2, 3]),
            item("baz/0003", &[8]),
            item("baz/0001", &[4, 5]),
            item("baz/0002", &[6, 7]),
        ]);
        let mut merged = Box::pin(multiple_value_stream(Box::pin(upstream)));

        // (base key, number of keys in the key set, merged value)
        let expectations = [
            ("foo", 4, vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
            ("bar", 1, vec![0, 1, 2, 3]),
            ("baz", 4, vec![0, 1, 2, 3, 4, 5, 6, 7, 8]),
        ];
        for (base_key, key_count, expected_value) in expectations {
            let (key_set, value) = merged.try_next().await.unwrap().unwrap();
            let all_keys = key_set.keys();
            assert_eq!(all_keys[0], base_key);
            assert_eq!(all_keys.len(), key_count);
            assert_eq!(value, expected_value);
        }

        assert!(merged.try_next().await.unwrap().is_none());
        // Call again
        assert!(merged.try_next().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_empty_upstream() {
        let upstream = stream::iter(vec![]);
        let mut merged = Box::pin(multiple_value_stream(Box::pin(upstream)));
        assert!(merged.try_next().await.unwrap().is_none());
        // Call again
        assert!(merged.try_next().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_multiple_values_stream_err() {
        // The error is the first upstream item.
        assert_first_poll_unexpected(vec![
            mock_err(),
            item("foo", &[0, 1, 2, 3]),
            item("foo/0001", &[4, 5]),
        ])
        .await;

        // The error arrives while a group is still being collected.
        assert_first_poll_unexpected(vec![
            item("foo", &[0, 1, 2, 3]),
            item("foo/0001", &[4, 5]),
            mock_err(),
        ])
        .await;
    }
}