use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use anymap2::SendSyncAnyMap;
/// A thread-safe registry that stores at most one value per concrete type `T`.
///
/// Cloning a `Plugins` is cheap: every clone shares the same underlying map
/// through an [`Arc`]. A value inserted through one handle is visible through
/// all of them.
///
/// All access is serialized by a single [`RwLock`]. The closures passed to
/// [`Plugins::get_or_insert`], [`Plugins::map`] and [`Plugins::map_mut`] run
/// while that lock is held. They must not call back into the same registry,
/// because `std`'s `RwLock` is not reentrant and doing so may deadlock or panic.
///
/// # Panics
///
/// Every method panics if the lock is poisoned. That happens when a thread
/// panicked while holding the write lock, for example inside a `map_mut`
/// closure or a `get_or_insert` initializer.
#[derive(Default, Clone)]
pub struct Plugins {
    inner: Arc<RwLock<SendSyncAnyMap>>,
}

impl Plugins {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RwLock::new(SendSyncAnyMap::new())),
        }
    }

    /// Stores `value`, replacing any previously stored value of the same type `T`.
    ///
    /// The replaced value, if any, is dropped only after the write lock has been
    /// released. A `Drop` impl that touches this registry therefore cannot deadlock.
    pub fn insert<T: 'static + Send + Sync>(&self, value: T) {
        // Binding the old value moves it out of the statement's temporaries, so
        // the temporary write guard is released before `previous` is dropped.
        let previous = self.write().insert(value);
        drop(previous);
    }

    /// Returns a clone of the stored `T`, or `None` if no `T` is registered.
    pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
        self.read().get::<T>().cloned()
    }

    /// Returns a clone of the stored `T`, inserting `f()` first if no `T` is present.
    ///
    /// `f` is called at most once, and only while holding the write lock after
    /// confirming that no `T` is stored. Concurrent callers therefore never
    /// initialize the same type twice. When the value already exists, only the
    /// shared read lock is taken.
    pub fn get_or_insert<T, F>(&self, f: F) -> T
    where
        T: 'static + Send + Sync + Clone,
        F: FnOnce() -> T,
    {
        // Fast path: avoid exclusive-lock contention for the common "already registered" case.
        if let Some(existing) = self.get::<T>() {
            return existing;
        }
        // Slow path: the entry API re-checks under the write lock. If another thread
        // inserted a `T` between the two lock acquisitions, that value is kept and
        // returned, and `f` is not called.
        self.write().entry::<T>().or_insert_with(f).clone()
    }

    /// Calls `mapper` with mutable access to the stored `T` (`None` if absent) and
    /// returns its result.
    ///
    /// The write lock is held for the whole duration of `mapper`.
    pub fn map_mut<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> R
    where
        F: FnOnce(Option<&mut T>) -> R,
    {
        let mut guard = self.write();
        mapper(guard.get_mut::<T>())
    }

    /// Calls `mapper` with a shared reference to the stored `T` and returns its result.
    ///
    /// Returns `None` without calling `mapper` when no `T` is registered. The read
    /// lock is held for the whole duration of `mapper`.
    pub fn map<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> Option<R>
    where
        F: FnOnce(&T) -> R,
    {
        self.read().get::<T>().map(mapper)
    }

    /// Returns the number of values currently stored in the registry.
    pub fn len(&self) -> usize {
        self.read().len()
    }

    /// Returns `true` if the registry holds no values.
    pub fn is_empty(&self) -> bool {
        self.read().is_empty()
    }

    /// Acquires the shared read lock on the underlying map.
    fn read(&self) -> RwLockReadGuard<'_, SendSyncAnyMap> {
        self.inner
            .read()
            .expect("plugin registry lock poisoned: a thread panicked while holding it")
    }

    /// Acquires the exclusive write lock on the underlying map.
    fn write(&self) -> RwLockWriteGuard<'_, SendSyncAnyMap> {
        self.inner
            .write()
            .expect("plugin registry lock poisoned: a thread panicked while holding it")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq)]
    struct FooPlugin {
        x: i32,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct BarPlugin {
        y: String,
    }

    /// Values of different types inserted concurrently through cloned handles
    /// must all end up in the same shared registry.
    #[test]
    fn test_plugins() {
        let plugins = Plugins::new();

        let m = plugins.clone();
        let thread1 = std::thread::spawn(move || {
            m.insert(FooPlugin { x: 42 });
            // Strict check: a missing value must fail the test, not be silently skipped.
            let foo = m.get::<FooPlugin>().expect("FooPlugin was just inserted");
            assert_eq!(foo.x, 42);
            assert_eq!(m.map::<FooPlugin, _, _>(|foo| foo.x * 2), Some(84));
        });

        let m = plugins.clone();
        let thread2 = std::thread::spawn(move || {
            m.insert(BarPlugin {
                y: "hello".to_string(),
            });
            let bar = m.get::<BarPlugin>().expect("BarPlugin was just inserted");
            assert_eq!(bar.y, "hello");
            let updated = m.map_mut::<BarPlugin, _, _>(|bar| match bar {
                Some(bar) => {
                    bar.y = "world".to_string();
                    true
                }
                None => false,
            });
            assert!(updated, "map_mut should see the inserted BarPlugin");
            assert_eq!(m.get::<BarPlugin>().unwrap().y, "world");
        });

        thread1.join().unwrap();
        thread2.join().unwrap();
        assert_eq!(plugins.len(), 2);
        assert!(!plugins.is_empty());
    }

    /// Lookups for an unregistered type report absence instead of panicking.
    #[test]
    fn missing_plugin_is_none() {
        let plugins = Plugins::new();
        assert!(plugins.is_empty());
        assert_eq!(plugins.len(), 0);
        assert_eq!(plugins.get::<FooPlugin>(), None);
        assert_eq!(plugins.map::<FooPlugin, _, _>(|foo| foo.x), None);
        assert!(plugins.map_mut::<FooPlugin, _, _>(|foo| foo.is_none()));
    }

    /// Inserting a second value of the same type replaces the first one.
    #[test]
    fn insert_replaces_value_of_same_type() {
        let plugins = Plugins::new();
        plugins.insert(FooPlugin { x: 1 });
        plugins.insert(FooPlugin { x: 2 });
        assert_eq!(plugins.len(), 1);
        assert_eq!(plugins.get::<FooPlugin>(), Some(FooPlugin { x: 2 }));
    }

    /// `get_or_insert` runs the initializer only when the type is absent.
    #[test]
    fn get_or_insert_initializes_only_once() {
        let plugins = Plugins::new();
        assert_eq!(
            plugins.get_or_insert(|| FooPlugin { x: 7 }),
            FooPlugin { x: 7 }
        );
        let again = plugins.get_or_insert::<FooPlugin, _>(|| {
            unreachable!("initializer must not run when FooPlugin is already present")
        });
        assert_eq!(again, FooPlugin { x: 7 });
        assert_eq!(plugins.len(), 1);
    }
}