common_procedure/
rwlock.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::hash::Hash;
17use std::sync::{Arc, Mutex};
18
19use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
20
21/// A guard that owns a read or write lock on a key.
22///
23/// This enum wraps either a read or write lock guard obtained from a `KeyRwLock`.
24/// The guard is automatically released when it is dropped.
25pub enum OwnedKeyRwLockGuard {
26    /// Represents a shared read lock on a key.
27    /// Multiple read locks can be held simultaneously for the same key.
28    Read { _guard: OwnedRwLockReadGuard<()> },
29
30    /// Represents an exclusive write lock on a key.
31    /// Only one write lock can be held at a time for a given key,
32    /// and no read locks can be held simultaneously with a write lock.
33    Write { _guard: OwnedRwLockWriteGuard<()> },
34}
35
36impl From<OwnedRwLockReadGuard<()>> for OwnedKeyRwLockGuard {
37    fn from(guard: OwnedRwLockReadGuard<()>) -> Self {
38        OwnedKeyRwLockGuard::Read { _guard: guard }
39    }
40}
41
42impl From<OwnedRwLockWriteGuard<()>> for OwnedKeyRwLockGuard {
43    fn from(guard: OwnedRwLockWriteGuard<()>) -> Self {
44        OwnedKeyRwLockGuard::Write { _guard: guard }
45    }
46}
47
/// Locks based on a key, allowing other keys to lock independently.
///
/// Each key gets its own `RwLock`, created on first use; entries are removed
/// from the map only by [`KeyRwLock::clean_keys`].
#[derive(Debug)]
pub struct KeyRwLock<K> {
    /// The inner map of locks for specific keys.
    inner: Mutex<HashMap<K, Arc<RwLock<()>>>>,
}

// Implemented by hand: `#[derive(Default)]` would add an unnecessary `K: Default`
// bound, making `default()` unavailable for key types that have no default value.
impl<K> Default for KeyRwLock<K> {
    fn default() -> Self {
        KeyRwLock {
            inner: Mutex::new(HashMap::new()),
        }
    }
}
54
55impl<K> KeyRwLock<K>
56where
57    K: Eq + Hash + Clone,
58{
59    pub fn new() -> Self {
60        KeyRwLock {
61            inner: Default::default(),
62        }
63    }
64
65    /// Locks the key with shared read access, returning a guard.
66    pub async fn read(&self, key: K) -> OwnedRwLockReadGuard<()> {
67        let lock = {
68            let mut locks = self.inner.lock().unwrap();
69            locks.entry(key).or_default().clone()
70        };
71
72        lock.read_owned().await
73    }
74
75    /// Locks the key with exclusive write access, returning a guard.
76    pub async fn write(&self, key: K) -> OwnedRwLockWriteGuard<()> {
77        let lock = {
78            let mut locks = self.inner.lock().unwrap();
79            locks.entry(key).or_default().clone()
80        };
81
82        lock.write_owned().await
83    }
84
85    /// Clean up stale locks.
86    ///
87    /// Note: It only cleans a lock if
88    /// - Its strong ref count equals one.
89    /// - Able to acquire the write lock.
90    pub fn clean_keys<'a>(&'a self, iter: impl IntoIterator<Item = &'a K>) {
91        let mut locks = self.inner.lock().unwrap();
92        let mut keys = Vec::new();
93        for key in iter {
94            if let Some(lock) = locks.get(key) {
95                if lock.try_write().is_ok() {
96                    debug_assert_eq!(Arc::weak_count(lock), 0);
97                    // Ensures nobody keeps this ref.
98                    if Arc::strong_count(lock) == 1 {
99                        keys.push(key);
100                    }
101                }
102            }
103        }
104
105        for key in keys {
106            locks.remove(key);
107        }
108    }
109}
110
#[cfg(test)]
impl<K> KeyRwLock<K>
where
    K: Eq + Hash + Clone,
{
    /// Attempts to take shared read access on `key` without waiting.
    ///
    /// A lock entry for `key` is created if missing; the call fails with
    /// `TryLockError` when read access cannot be granted right now.
    pub fn try_read(&self, key: K) -> Result<OwnedRwLockReadGuard<()>, tokio::sync::TryLockError> {
        let handle = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        handle.try_read_owned()
    }

    /// Attempts to take exclusive write access on `key` without waiting.
    ///
    /// A lock entry for `key` is created if missing; the call fails with
    /// `TryLockError` when write access cannot be granted right now.
    pub fn try_write(
        &self,
        key: K,
    ) -> Result<OwnedRwLockWriteGuard<()>, tokio::sync::TryLockError> {
        let handle = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        handle.try_write_owned()
    }

    /// Returns how many keys currently have an entry in the map.
    pub fn len(&self) -> usize {
        let locks = self.inner.lock().unwrap();
        locks.len()
    }

    /// Returns true if the inner map has no entries.
    pub fn is_empty(&self) -> bool {
        self.inner.lock().unwrap().is_empty()
    }
}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones the map's lock for `key` (creating it if absent), as `read`/`write` do.
    fn checkout<'k>(locks: &KeyRwLock<&'k str>, key: &'k str) -> Arc<RwLock<()>> {
        let mut inner = locks.inner.lock().unwrap();
        Arc::clone(inner.entry(key).or_default())
    }

    /// Returns the strong count of the lock stored in the map under `key`.
    fn stored_strong_count(locks: &KeyRwLock<&str>, key: &str) -> usize {
        let inner = locks.inner.lock().unwrap();
        let stored = inner.get(key).unwrap();
        Arc::strong_count(stored)
    }

    #[tokio::test]
    async fn test_naive() {
        let locks = KeyRwLock::new();

        // A held read lock admits more readers but no writer.
        {
            let _reader = locks.read("test1").await;
            assert_eq!(locks.len(), 1);
            assert!(locks.try_read("test1").is_ok());
            assert!(locks.try_write("test1").is_err());
        }

        // A held write lock admits neither readers nor writers.
        {
            let _writer2 = locks.write("test2").await;
            let _writer1 = locks.write("test1").await;
            assert_eq!(locks.len(), 2);
            assert!(locks.try_read("test1").is_err());
            assert!(locks.try_write("test1").is_err());
        }

        // Released locks keep their entries until they are cleaned.
        assert_eq!(locks.len(), 2);
        locks.clean_keys(&["test1", "test2"]);
        assert!(locks.is_empty());

        let mut readers = Vec::with_capacity(2);
        for key in ["test1", "test2"] {
            readers.push(locks.read(key).await);
        }
        // Release the read guards, most recently acquired first.
        while readers.pop().is_some() {}
        locks.clean_keys([&"test1", &"test2"]);
        assert_eq!(locks.len(), 0);
    }

    #[tokio::test]
    async fn test_clean_keys() {
        let locks = KeyRwLock::<&str>::new();

        // An outstanding read guard keeps a second strong reference alive.
        {
            let handle = checkout(&locks, "test");
            assert_eq!(Arc::strong_count(&handle), 2);
            let _guard = handle.read_owned().await;
            assert_eq!(stored_strong_count(&locks, "test"), 2);
        }

        // So does an outstanding write guard.
        {
            let handle = checkout(&locks, "test");
            assert_eq!(Arc::strong_count(&handle), 2);
            let _guard = handle.write_owned().await;
            assert_eq!(stored_strong_count(&locks, "test"), 2);
        }

        // With every guard dropped, the map holds the only reference.
        assert_eq!(stored_strong_count(&locks, "test"), 1);

        // Someone has the ref of the rwlock, but it waits to be granted the lock.
        let waiter = checkout(&locks, "test");
        assert_eq!(Arc::strong_count(&waiter), 2);
        // Cleaning the key must therefore leave its entry in place.
        locks.clean_keys([&"test"]);
        assert!(locks.inner.lock().unwrap().contains_key("test"));
    }
}