common_meta/cache/
registry.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use anymap2::SendSyncAnyMap;
18use futures::future::join_all;
19
20use crate::cache_invalidator::{CacheInvalidator, Context};
21use crate::error::Result;
22use crate::instruction::CacheIdent;
23
24pub type CacheRegistryRef = Arc<CacheRegistry>;
25pub type LayeredCacheRegistryRef = Arc<LayeredCacheRegistry>;
26
27/// [LayeredCacheRegistry] Builder.
28#[derive(Default)]
29pub struct LayeredCacheRegistryBuilder {
30    registry: LayeredCacheRegistry,
31}
32
33impl LayeredCacheRegistryBuilder {
34    /// Adds [CacheRegistry] into the next layer.
35    ///
36    /// During cache invalidation, [LayeredCacheRegistry] ensures sequential invalidation
37    /// of each layer (after the previous layer).
38    pub fn add_cache_registry(mut self, registry: CacheRegistry) -> Self {
39        self.registry.layers.push(registry);
40
41        self
42    }
43
44    /// Returns __cloned__ the value stored in the collection for the type `T`, if it exists.
45    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
46        self.registry.get()
47    }
48
49    /// Builds the [LayeredCacheRegistry]
50    pub fn build(self) -> LayeredCacheRegistry {
51        self.registry
52    }
53}
54
55/// [LayeredCacheRegistry] invalidate caches sequentially from the first layer.
56#[derive(Default)]
57pub struct LayeredCacheRegistry {
58    layers: Vec<CacheRegistry>,
59}
60
61#[async_trait::async_trait]
62impl CacheInvalidator for LayeredCacheRegistry {
63    async fn invalidate(&self, ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
64        let mut results = Vec::with_capacity(self.layers.len());
65        for registry in &self.layers {
66            results.push(registry.invalidate(ctx, caches).await);
67        }
68        results.into_iter().collect::<Result<Vec<_>>>().map(|_| ())
69    }
70}
71
72impl LayeredCacheRegistry {
73    /// Returns __cloned__ the value stored in the collection for the type `T`, if it exists.
74    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
75        for registry in &self.layers {
76            if let Some(cache) = registry.get::<T>() {
77                return Some(cache);
78            }
79        }
80
81        None
82    }
83}
84
85/// [CacheRegistryBuilder] provides ability of
86/// - Register the `cache` which implements the [CacheInvalidator] trait into [CacheRegistry].
87/// - Build a [CacheRegistry]
88#[derive(Default)]
89pub struct CacheRegistryBuilder {
90    registry: CacheRegistry,
91}
92
93impl CacheRegistryBuilder {
94    /// Adds the cache.
95    pub fn add_cache<T: CacheInvalidator + 'static>(mut self, cache: Arc<T>) -> Self {
96        self.registry.register(cache);
97        self
98    }
99
100    /// Builds [CacheRegistry].
101    pub fn build(self) -> CacheRegistry {
102        self.registry
103    }
104}
105
106/// [CacheRegistry] provides ability of
107/// - Get a specific `cache`.
108#[derive(Default)]
109pub struct CacheRegistry {
110    indexes: Vec<Arc<dyn CacheInvalidator>>,
111    registry: SendSyncAnyMap,
112}
113
114#[async_trait::async_trait]
115impl CacheInvalidator for CacheRegistry {
116    async fn invalidate(&self, ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
117        let tasks = self
118            .indexes
119            .iter()
120            .map(|invalidator| invalidator.invalidate(ctx, caches));
121        join_all(tasks)
122            .await
123            .into_iter()
124            .collect::<Result<Vec<_>>>()?;
125        Ok(())
126    }
127}
128
129impl CacheRegistry {
130    /// Sets the value stored in the collection for the type `T`.
131    /// Returns true if the collection already had a value of type `T`
132    fn register<T: CacheInvalidator + 'static>(&mut self, cache: Arc<T>) -> bool {
133        self.indexes.push(cache.clone());
134        self.registry.insert(cache).is_some()
135    }
136
137    /// Returns __cloned__ the value stored in the collection for the type `T`, if it exists.
138    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
139        self.registry.get().cloned()
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
    use std::sync::Arc;

    use moka::future::{Cache, CacheBuilder};

    use crate::cache::registry::CacheRegistryBuilder;
    use crate::cache::*;
    use crate::instruction::CacheIdent;

    /// Filter that accepts every [CacheIdent].
    fn always_true_filter(_: &CacheIdent) -> bool {
        true
    }

    /// Builds a `String`-keyed [CacheContainer] called `name` around the given `invalidator`.
    ///
    /// Its initializer bumps a local counter and always yields `"hi"`.
    fn test_cache(
        name: &str,
        invalidator: Invalidator<String, String, CacheIdent>,
    ) -> CacheContainer<String, String, CacheIdent> {
        let init_calls = Arc::new(AtomicI32::new(0));
        let init_calls_in_closure = init_calls.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            init_calls_in_closure.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let storage: Cache<String, String> = CacheBuilder::new(128).build();

        CacheContainer::new(
            name.to_string(),
            storage,
            invalidator,
            init,
            always_true_filter,
        )
    }

    /// Builds an `i32`-keyed [CacheContainer] called `name` around the given `invalidator`.
    ///
    /// Its initializer bumps a local counter and always yields `"foo"`.
    fn test_i32_cache(
        name: &str,
        invalidator: Invalidator<i32, String, CacheIdent>,
    ) -> CacheContainer<i32, String, CacheIdent> {
        let init_calls = Arc::new(AtomicI32::new(0));
        let init_calls_in_closure = init_calls.clone();
        let init: Initializer<i32, String> = Arc::new(move |_| {
            init_calls_in_closure.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("foo".to_string())) })
        });
        let storage: Cache<i32, String> = CacheBuilder::new(128).build();

        CacheContainer::new(
            name.to_string(),
            storage,
            invalidator,
            init,
            always_true_filter,
        )
    }

    /// Caches of two different types registered in one registry are both retrievable by type.
    #[tokio::test]
    async fn test_register() {
        let noop: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let i32_cache = Arc::new(test_i32_cache("i32_cache", noop));

        let noop: Invalidator<_, String, CacheIdent> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));
        let string_cache = Arc::new(test_cache("string_cache", noop));

        let registry = CacheRegistryBuilder::default()
            .add_cache(i32_cache)
            .add_cache(string_cache)
            .build();

        let found = registry
            .get::<Arc<CacheContainer<i32, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(found.name(), "i32_cache");

        let found = registry
            .get::<Arc<CacheContainer<String, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(found.name(), "string_cache");
    }

    /// Caches living in different layers are both retrievable by type.
    ///
    /// NOTE(review): nothing in this test calls `invalidate` on the registry, so the
    /// ordering assertions inside the two invalidators below appear never to run — confirm
    /// whether an explicit invalidation was intended here.
    #[tokio::test]
    async fn test_layered_registry() {
        // Flag shared by both invalidators: the first one sets it, the second one expects it set.
        let first_layer_done = Arc::new(AtomicBool::new(false));

        // 1st layer
        let flag = first_layer_done.clone();
        let invalidator: Invalidator<String, String, CacheIdent> = Box::new(move |_, _| {
            let flag = flag.clone();
            Box::pin(async move {
                assert!(!flag.load(Ordering::Relaxed));
                flag.store(true, Ordering::Relaxed);
                Ok(())
            })
        });
        let string_cache = Arc::new(test_cache("string_cache", invalidator));
        let first_layer = CacheRegistryBuilder::default()
            .add_cache(string_cache)
            .build();

        // 2nd layer
        let flag = first_layer_done.clone();
        let invalidator: Invalidator<i32, String, CacheIdent> = Box::new(move |_, _| {
            let flag = flag.clone();
            Box::pin(async move {
                assert!(flag.load(Ordering::Relaxed));
                Ok(())
            })
        });
        let i32_cache = Arc::new(test_i32_cache("i32_cache", invalidator));
        let second_layer = CacheRegistryBuilder::default()
            .add_cache(i32_cache)
            .build();

        let registry = LayeredCacheRegistryBuilder::default()
            .add_cache_registry(first_layer)
            .add_cache_registry(second_layer)
            .build();

        let found = registry
            .get::<Arc<CacheContainer<i32, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(found.name(), "i32_cache");
        let found = registry
            .get::<Arc<CacheContainer<String, String, CacheIdent>>>()
            .unwrap();
        assert_eq!(found.name(), "string_cache");
    }
}