// common_procedure/local/rwlock.rs
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, Mutex};
use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
/// An owned tokio `RwLock<()>` guard held in either shared (read) or
/// exclusive (write) mode, e.g. one returned by [`KeyRwLock::read`] or
/// [`KeyRwLock::write`].
///
/// The wrapped guard is never read; it is stored only so that dropping this
/// value releases the underlying lock.
#[derive(Debug)]
pub enum OwnedKeyRwLockGuard {
    /// Holds a shared (read) guard.
    Read { _guard: OwnedRwLockReadGuard<()> },
    /// Holds an exclusive (write) guard.
    Write { _guard: OwnedRwLockWriteGuard<()> },
}
impl From<OwnedRwLockReadGuard<()>> for OwnedKeyRwLockGuard {
fn from(guard: OwnedRwLockReadGuard<()>) -> Self {
OwnedKeyRwLockGuard::Read { _guard: guard }
}
}
impl From<OwnedRwLockWriteGuard<()>> for OwnedKeyRwLockGuard {
fn from(guard: OwnedRwLockWriteGuard<()>) -> Self {
OwnedKeyRwLockGuard::Write { _guard: guard }
}
}
/// Per-key table of shared async reader-writer locks.
type LockTable<K> = HashMap<K, Arc<RwLock<()>>>;

/// A set of async reader-writer locks addressed by key.
///
/// Each key gets its own `RwLock<()>`, created lazily on first use.
#[derive(Debug)]
pub struct KeyRwLock<K> {
    /// Key -> lock table. The synchronous mutex only guards lookups and
    /// insert/remove of entries; the per-key locks are awaited after it has
    /// been released.
    inner: Mutex<LockTable<K>>,
}
impl<K> KeyRwLock<K>
where
K: Eq + Hash + Clone,
{
pub fn new() -> Self {
KeyRwLock {
inner: Default::default(),
}
}
pub async fn read(&self, key: K) -> OwnedRwLockReadGuard<()> {
let lock = {
let mut locks = self.inner.lock().unwrap();
locks.entry(key).or_default().clone()
};
lock.read_owned().await
}
pub async fn write(&self, key: K) -> OwnedRwLockWriteGuard<()> {
let lock = {
let mut locks = self.inner.lock().unwrap();
locks.entry(key).or_default().clone()
};
lock.write_owned().await
}
pub fn clean_keys<'a>(&'a self, iter: impl IntoIterator<Item = &'a K>) {
let mut locks = self.inner.lock().unwrap();
let mut keys = Vec::new();
for key in iter {
if let Some(lock) = locks.get(key) {
if lock.try_write().is_ok() {
debug_assert_eq!(Arc::weak_count(lock), 0);
if Arc::strong_count(lock) == 1 {
keys.push(key);
}
}
}
}
for key in keys {
locks.remove(key);
}
}
}
#[cfg(test)]
impl<K> KeyRwLock<K>
where
    K: Eq + Hash + Clone,
{
    /// Test-only, non-waiting variant of `read`: fails instead of waiting when
    /// `key` cannot be read-locked immediately.
    ///
    /// Like `read`, this inserts an entry for `key` if none exists, even when
    /// the lock attempt fails.
    pub fn try_read(
        &self,
        key: K,
    ) -> Result<OwnedRwLockReadGuard<()>, tokio::sync::TryLockError> {
        // The map guard is a temporary of this statement, so the map mutex is
        // released before the per-key lock is attempted.
        let shared = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        shared.try_read_owned()
    }

    /// Test-only, non-waiting variant of `write`: fails instead of waiting when
    /// `key` cannot be write-locked immediately.
    ///
    /// Like `write`, this inserts an entry for `key` if none exists, even when
    /// the lock attempt fails.
    pub fn try_write(
        &self,
        key: K,
    ) -> Result<OwnedRwLockWriteGuard<()>, tokio::sync::TryLockError> {
        let shared = Arc::clone(self.inner.lock().unwrap().entry(key).or_default());
        shared.try_write_owned()
    }

    /// Number of keys that currently have an entry in the map.
    pub fn len(&self) -> usize {
        let table = self.inner.lock().unwrap();
        table.len()
    }

    /// Returns `true` when no key has an entry in the map.
    pub fn is_empty(&self) -> bool {
        self.inner.lock().unwrap().is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Clones the map's `Arc` for `key`, inserting a fresh lock if absent.
    fn entry_arc(lock_key: &KeyRwLock<&'static str>, key: &'static str) -> Arc<RwLock<()>> {
        lock_key.inner.lock().unwrap().entry(key).or_default().clone()
    }

    /// Strong count of the map's own `Arc` for `key`. Panics if `key` is absent.
    fn map_strong_count(lock_key: &KeyRwLock<&'static str>, key: &str) -> usize {
        Arc::strong_count(lock_key.inner.lock().unwrap().get(key).unwrap())
    }

    /// Per-key read/write exclusion, plus cleanup of keys whose guards are gone.
    #[tokio::test]
    async fn test_naive() {
        let lock_key = KeyRwLock::new();
        {
            // A held read guard admits further readers but no writer.
            let _guard = lock_key.read("test1").await;
            assert_eq!(lock_key.len(), 1);
            assert!(lock_key.try_read("test1").is_ok());
            assert!(lock_key.try_write("test1").is_err());
        }
        {
            // A held write guard excludes both readers and writers.
            let _guard0 = lock_key.write("test2").await;
            let _guard = lock_key.write("test1").await;
            assert_eq!(lock_key.len(), 2);
            assert!(lock_key.try_read("test1").is_err());
            assert!(lock_key.try_write("test1").is_err());
        }
        // Entries outlive their guards until they are explicitly cleaned.
        assert_eq!(lock_key.len(), 2);
        lock_key.clean_keys(&["test1", "test2"]);
        assert!(lock_key.is_empty());

        let mut guards = Vec::new();
        for key in ["test1", "test2"] {
            guards.push(lock_key.read(key).await);
        }
        // Release every guard so both keys become idle again.
        drop(guards);
        lock_key.clean_keys([&"test1", &"test2"]);
        assert!(lock_key.is_empty());
    }

    /// `clean_keys` must keep entries that are still referenced outside the map.
    #[tokio::test]
    async fn test_clean_keys() {
        let lock_key = KeyRwLock::<&str>::new();
        {
            let rwlock = entry_arc(&lock_key, "test");
            assert_eq!(Arc::strong_count(&rwlock), 2);
            // The owned guard takes over `rwlock`'s reference.
            let _guard = rwlock.read_owned().await;
            assert_eq!(map_strong_count(&lock_key, "test"), 2);
        }
        {
            let rwlock = entry_arc(&lock_key, "test");
            assert_eq!(Arc::strong_count(&rwlock), 2);
            let _guard = rwlock.write_owned().await;
            assert_eq!(map_strong_count(&lock_key, "test"), 2);
        }
        // With every guard dropped, only the map's own reference remains.
        assert_eq!(map_strong_count(&lock_key, "test"), 1);

        let rwlock = entry_arc(&lock_key, "test");
        assert_eq!(Arc::strong_count(&rwlock), 2);
        // An outstanding clone (like a caller between the map lookup and
        // `read_owned().await`) must keep the entry alive.
        lock_key.clean_keys([&"test"]);
        assert!(lock_key.inner.lock().unwrap().contains_key("test"));

        // Once that clone is gone the entry is idle and can be removed.
        drop(rwlock);
        lock_key.clean_keys([&"test"]);
        assert!(lock_key.is_empty());
    }
}