common_meta/cache/
registry.rs1use std::sync::Arc;
16
17use anymap2::SendSyncAnyMap;
18use futures::future::join_all;
19
20use crate::cache_invalidator::{CacheInvalidator, Context};
21use crate::error::Result;
22use crate::instruction::CacheIdent;
23
24pub type CacheRegistryRef = Arc<CacheRegistry>;
25pub type LayeredCacheRegistryRef = Arc<LayeredCacheRegistry>;
26
27#[derive(Default)]
29pub struct LayeredCacheRegistryBuilder {
30 registry: LayeredCacheRegistry,
31}
32
33impl LayeredCacheRegistryBuilder {
34 pub fn add_cache_registry(mut self, registry: CacheRegistry) -> Self {
39 self.registry.layers.push(registry);
40
41 self
42 }
43
44 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
46 self.registry.get()
47 }
48
49 pub fn build(self) -> LayeredCacheRegistry {
51 self.registry
52 }
53}
54
55#[derive(Default)]
57pub struct LayeredCacheRegistry {
58 layers: Vec<CacheRegistry>,
59}
60
61#[async_trait::async_trait]
62impl CacheInvalidator for LayeredCacheRegistry {
63 async fn invalidate(&self, ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
64 let mut results = Vec::with_capacity(self.layers.len());
65 for registry in &self.layers {
66 results.push(registry.invalidate(ctx, caches).await);
67 }
68 results.into_iter().collect::<Result<Vec<_>>>().map(|_| ())
69 }
70}
71
72impl LayeredCacheRegistry {
73 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
75 for registry in &self.layers {
76 if let Some(cache) = registry.get::<T>() {
77 return Some(cache);
78 }
79 }
80
81 None
82 }
83}
84
85#[derive(Default)]
89pub struct CacheRegistryBuilder {
90 registry: CacheRegistry,
91}
92
93impl CacheRegistryBuilder {
94 pub fn add_cache<T: CacheInvalidator + 'static>(mut self, cache: Arc<T>) -> Self {
96 self.registry.register(cache);
97 self
98 }
99
100 pub fn build(self) -> CacheRegistry {
102 self.registry
103 }
104}
105
106#[derive(Default)]
109pub struct CacheRegistry {
110 indexes: Vec<Arc<dyn CacheInvalidator>>,
111 registry: SendSyncAnyMap,
112}
113
114#[async_trait::async_trait]
115impl CacheInvalidator for CacheRegistry {
116 async fn invalidate(&self, ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
117 let tasks = self
118 .indexes
119 .iter()
120 .map(|invalidator| invalidator.invalidate(ctx, caches));
121 join_all(tasks)
122 .await
123 .into_iter()
124 .collect::<Result<Vec<_>>>()?;
125 Ok(())
126 }
127}
128
129impl CacheRegistry {
130 fn register<T: CacheInvalidator + 'static>(&mut self, cache: Arc<T>) -> bool {
133 self.indexes.push(cache.clone());
134 self.registry.insert(cache).is_some()
135 }
136
137 pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
139 self.registry.get().cloned()
140 }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
    use std::sync::Arc;

    use moka::future::{Cache, CacheBuilder};

    use crate::cache::registry::CacheRegistryBuilder;
    use crate::cache::*;
    use crate::cache_invalidator::{CacheInvalidator, Context};
    use crate::instruction::CacheIdent;

    /// Token filter that lets every cache ident through.
    fn always_true_filter(_: &CacheIdent) -> bool {
        true
    }

    /// Builds a small `String -> String` cache whose initializer always yields `"hi"`.
    fn test_cache(
        name: &str,
        invalidator: Invalidator<String, String, CacheIdent>,
    ) -> CacheContainer<String, String, CacheIdent> {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });

        CacheContainer::new(
            name.to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        )
    }

    /// Builds a small `i32 -> String` cache whose initializer always yields `"foo"`.
    fn test_i32_cache(
        name: &str,
        invalidator: Invalidator<i32, String, CacheIdent>,
    ) -> CacheContainer<i32, String, CacheIdent> {
        let cache: Cache<i32, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<i32, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("foo".to_string())) })
        });

        CacheContainer::new(
            name.to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        )
    }

    #[tokio::test]
    async fn test_register() {
        let builder = CacheRegistryBuilder::default();
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let invalidator: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let cache = Arc::new(test_cache("string_cache", invalidator));
        let registry = builder.add_cache(i32_cache).add_cache(cache).build();

        let cache = registry
            .get::<Arc<CacheContainer<i32, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "i32_cache");

        let cache = registry
            .get::<Arc<CacheContainer<String, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "string_cache");

        // Invalidating the registry fans out to every registered cache. Both
        // invalidators succeed, so the combined result is `Ok`.
        registry
            .invalidate(&Context::default(), &[CacheIdent::TableId(1024)])
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_layered_registry() {
        let builder = LayeredCacheRegistryBuilder::default();
        // Flipped by the first layer's invalidator; the second layer checks it.
        let counter = Arc::new(AtomicBool::new(false));
        let moved_counter = counter.clone();
        let invalidator: Invalidator<String, String, CacheIdent> = Box::new(move |_, _| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                assert!(!counter.load(Ordering::Relaxed));
                counter.store(true, Ordering::Relaxed);
                Ok(())
            })
        });
        let cache = Arc::new(test_cache("string_cache", invalidator));
        let builder =
            builder.add_cache_registry(CacheRegistryBuilder::default().add_cache(cache).build());
        let moved_counter = counter.clone();
        // The second layer must only run after the first layer has finished.
        let invalidator: Invalidator<i32, String, CacheIdent> = Box::new(move |_, _| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                assert!(counter.load(Ordering::Relaxed));
                Ok(())
            })
        });
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let builder = builder
            .add_cache_registry(CacheRegistryBuilder::default().add_cache(i32_cache).build());

        let registry = builder.build();
        let cache = registry
            .get::<Arc<CacheContainer<i32, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "i32_cache");
        let cache = registry
            .get::<Arc<CacheContainer<String, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(cache.name(), "string_cache");

        // Exercise the ordering assertions above. A single ident is used so that each
        // invalidator runs exactly once; otherwise the first layer's
        // `!counter` check would fail on a second call.
        registry
            .invalidate(&Context::default(), &[CacheIdent::TableId(1024)])
            .await
            .unwrap();
        assert!(counter.load(Ordering::Relaxed));
    }
}