common_base/
plugins.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
16
17use anymap2::SendSyncAnyMap;
18
19/// [`Plugins`] is a wrapper of [anymap2](https://github.com/azriel91/anymap2) and provides a thread-safe way to store and retrieve plugins.
20/// Make it Cloneable and we can treat it like an Arc struct.
21#[derive(Default, Clone)]
22pub struct Plugins {
23    inner: Arc<RwLock<SendSyncAnyMap>>,
24}
25
26impl Plugins {
27    pub fn new() -> Self {
28        Self {
29            inner: Arc::new(RwLock::new(SendSyncAnyMap::new())),
30        }
31    }
32
33    pub fn insert<T: 'static + Send + Sync>(&self, value: T) {
34        let last = self.write().insert(value);
35        if last.is_some() {
36            panic!(
37                "Plugin of type {} already exists",
38                std::any::type_name::<T>()
39            );
40        }
41    }
42
43    pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
44        self.read().get::<T>().cloned()
45    }
46
47    pub fn get_or_insert<T, F>(&self, f: F) -> T
48    where
49        T: 'static + Send + Sync + Clone,
50        F: FnOnce() -> T,
51    {
52        let mut binding = self.write();
53        if !binding.contains::<T>() {
54            binding.insert(f());
55        }
56        binding.get::<T>().cloned().unwrap()
57    }
58
59    pub fn map_mut<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> R
60    where
61        F: FnOnce(Option<&mut T>) -> R,
62    {
63        let mut binding = self.write();
64        let opt = binding.get_mut::<T>();
65        mapper(opt)
66    }
67
68    pub fn map<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> Option<R>
69    where
70        F: FnOnce(&T) -> R,
71    {
72        self.read().get::<T>().map(mapper)
73    }
74
75    pub fn len(&self) -> usize {
76        self.read().len()
77    }
78
79    pub fn is_empty(&self) -> bool {
80        self.read().is_empty()
81    }
82
83    fn read(&self) -> RwLockReadGuard<'_, SendSyncAnyMap> {
84        self.inner.read().unwrap()
85    }
86
87    fn write(&self) -> RwLockWriteGuard<'_, SendSyncAnyMap> {
88        self.inner.write().unwrap()
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plugins() {
        #[derive(Debug, Clone)]
        struct FooPlugin {
            x: i32,
        }

        #[derive(Debug, Clone)]
        struct BarPlugin {
            y: String,
        }

        let plugins = Plugins::new();

        let m = plugins.clone();
        let thread1 = std::thread::spawn(move || {
            m.insert(FooPlugin { x: 42 });

            // This thread just inserted the plugin, so it must be present.
            let foo = m
                .get::<FooPlugin>()
                .expect("FooPlugin should be registered");
            assert_eq!(foo.x, 42);

            assert_eq!(m.map::<FooPlugin, _, _>(|foo| foo.x * 2), Some(84));
        });

        let m = plugins.clone();
        let thread2 = std::thread::spawn(move || {
            m.insert(BarPlugin {
                y: "hello".to_string(),
            });

            let bar = m
                .get::<BarPlugin>()
                .expect("BarPlugin should be registered");
            assert_eq!(bar.y, "hello");

            // `map_mut` must receive `Some` for a registered plugin, and the
            // mutation must be visible to later reads.
            let updated = m.map_mut::<BarPlugin, _, _>(|bar| match bar {
                Some(bar) => {
                    bar.y = "world".to_string();
                    true
                }
                None => false,
            });
            assert!(updated);

            assert_eq!(m.get::<BarPlugin>().unwrap().y, "world");
        });

        thread1.join().unwrap();
        thread2.join().unwrap();

        assert_eq!(plugins.len(), 2);
        assert!(!plugins.is_empty());
    }

    #[test]
    fn test_get_or_insert() {
        let plugins = Plugins::new();
        assert_eq!(plugins.get_or_insert(|| 7u32), 7);
        // Once registered, the stored value is returned and the factory is not run.
        assert_eq!(
            plugins.get_or_insert(|| -> u32 { unreachable!("factory must not run twice") }),
            7
        );
        assert_eq!(plugins.len(), 1);
    }

    #[test]
    fn test_missing_plugin() {
        let plugins = Plugins::new();
        assert!(plugins.is_empty());
        assert_eq!(plugins.get::<u64>(), None);
        assert_eq!(plugins.map::<u64, _, _>(|v| *v), None);
        assert!(plugins.map_mut::<u64, _, _>(|v| v.is_none()));
    }

    #[test]
    #[should_panic(expected = "Plugin of type i32 already exists")]
    fn test_plugin_uniqueness() {
        let plugins = Plugins::new();
        plugins.insert(1i32);
        plugins.insert(2i32);
    }
}