common_base/
plugins.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
16
17use anymap2::SendSyncAnyMap;
18
19/// [`Plugins`] is a wrapper of [anymap2](https://github.com/azriel91/anymap2) and provides a thread-safe way to store and retrieve plugins.
20/// Make it Cloneable and we can treat it like an Arc struct.
21#[derive(Default, Clone)]
22pub struct Plugins {
23    inner: Arc<RwLock<SendSyncAnyMap>>,
24}
25
26impl Plugins {
27    pub fn new() -> Self {
28        Self {
29            inner: Arc::new(RwLock::new(SendSyncAnyMap::new())),
30        }
31    }
32
33    pub fn insert<T: 'static + Send + Sync>(&self, value: T) {
34        let last = self.write().insert(value);
35        assert!(last.is_none(), "each type of plugins must be one and only");
36    }
37
38    pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
39        self.read().get::<T>().cloned()
40    }
41
42    pub fn get_or_insert<T, F>(&self, f: F) -> T
43    where
44        T: 'static + Send + Sync + Clone,
45        F: FnOnce() -> T,
46    {
47        let mut binding = self.write();
48        if !binding.contains::<T>() {
49            binding.insert(f());
50        }
51        binding.get::<T>().cloned().unwrap()
52    }
53
54    pub fn map_mut<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> R
55    where
56        F: FnOnce(Option<&mut T>) -> R,
57    {
58        let mut binding = self.write();
59        let opt = binding.get_mut::<T>();
60        mapper(opt)
61    }
62
63    pub fn map<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> Option<R>
64    where
65        F: FnOnce(&T) -> R,
66    {
67        self.read().get::<T>().map(mapper)
68    }
69
70    pub fn len(&self) -> usize {
71        self.read().len()
72    }
73
74    pub fn is_empty(&self) -> bool {
75        self.read().is_empty()
76    }
77
78    fn read(&self) -> RwLockReadGuard<SendSyncAnyMap> {
79        self.inner.read().unwrap()
80    }
81
82    fn write(&self) -> RwLockWriteGuard<SendSyncAnyMap> {
83        self.inner.write().unwrap()
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the registry from two threads that share clones of one
    /// [`Plugins`], each working with its own plugin type.
    #[test]
    fn test_plugins() {
        #[derive(Debug, Clone)]
        struct FooPlugin {
            x: i32,
        }

        #[derive(Debug, Clone)]
        struct BarPlugin {
            y: String,
        }

        let plugins = Plugins::new();

        let m = plugins.clone();
        let thread1 = std::thread::spawn(move || {
            m.insert(FooPlugin { x: 42 });

            // Assert presence as well as the value: a missing plugin must fail
            // the test instead of silently skipping the check.
            assert_eq!(m.get::<FooPlugin>().map(|foo| foo.x), Some(42));

            assert_eq!(m.map::<FooPlugin, _, _>(|foo| foo.x * 2), Some(84));
        });

        let m = plugins.clone();
        let thread2 = std::thread::spawn(move || {
            m.insert(BarPlugin {
                y: "hello".to_string(),
            });

            assert_eq!(m.get::<BarPlugin>().unwrap().y, "hello");

            // Report whether the plugin was found so that a missing plugin is
            // detected without panicking while the write lock is held.
            let updated = m.map_mut::<BarPlugin, _, _>(|bar| match bar {
                Some(bar) => {
                    bar.y = "world".to_string();
                    true
                }
                None => false,
            });
            assert!(updated, "BarPlugin must be visible to map_mut");

            assert_eq!(m.get::<BarPlugin>().unwrap().y, "world");
        });

        thread1.join().unwrap();
        thread2.join().unwrap();

        assert_eq!(plugins.len(), 2);
        assert!(!plugins.is_empty());
    }

    /// `get_or_insert` stores the initializer's value only when the type is
    /// absent and otherwise returns the value that is already registered.
    #[test]
    fn test_get_or_insert() {
        let plugins = Plugins::new();
        assert!(plugins.is_empty());
        assert_eq!(plugins.get::<i32>(), None);

        assert_eq!(plugins.get_or_insert(|| 1i32), 1);
        // Already registered: the existing value wins over the initializer.
        assert_eq!(plugins.get_or_insert(|| 2i32), 1);

        assert_eq!(plugins.get::<i32>(), Some(1));
        assert_eq!(plugins.len(), 1);
    }

    /// Registering the same type twice must panic.
    #[test]
    #[should_panic(expected = "each type of plugins must be one and only")]
    fn test_plugin_uniqueness() {
        let plugins = Plugins::new();
        plugins.insert(1i32);
        plugins.insert(2i32);
    }
}