common_meta/cache/
container.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Borrow;
16use std::hash::Hash;
17use std::sync::Arc;
18
19use futures::future::{join_all, BoxFuture};
20use moka::future::Cache;
21use snafu::{OptionExt, ResultExt};
22
23use crate::cache_invalidator::{CacheInvalidator, Context};
24use crate::error::{self, Error, Result};
25use crate::instruction::CacheIdent;
26use crate::metrics;
27
28/// Filters out unused [CacheToken]s
29pub type TokenFilter<CacheToken> = Box<dyn Fn(&CacheToken) -> bool + Send + Sync>;
30
31/// Invalidates cached values by [CacheToken]s.
32pub type Invalidator<K, V, CacheToken> =
33    Box<dyn for<'a> Fn(&'a Cache<K, V>, &'a CacheToken) -> BoxFuture<'a, Result<()>> + Send + Sync>;
34
35/// Initializes value (i.e., fetches from remote).
36pub type Initializer<K, V> = Arc<dyn Fn(&'_ K) -> BoxFuture<'_, Result<Option<V>>> + Send + Sync>;
37
/// [CacheContainer] provides ability to:
/// - Cache value loaded by [Initializer].
/// - Invalidate caches by [Invalidator].
///
/// `K` is the cache key type, `V` the cached value type, and `CacheToken` the
/// token type consumed by the [Invalidator] and the token filter.
pub struct CacheContainer<K, V, CacheToken> {
    /// Container name; also used as the label value for the container metrics.
    name: String,
    /// Underlying moka cache holding loaded values.
    cache: Cache<K, V>,
    /// Called once per accepted token to invalidate the values it covers.
    invalidator: Invalidator<K, V, CacheToken>,
    /// Loads a value on a cache miss; `Ok(None)` means the value does not exist.
    initializer: Initializer<K, V>,
    /// Tokens for which this returns `false` are skipped during invalidation.
    // NOTE(review): this is a plain `fn` pointer rather than the boxed
    // `TokenFilter` alias, so capturing closures cannot be supplied here.
    token_filter: fn(&CacheToken) -> bool,
}
48
49impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
50where
51    K: Send + Sync,
52    V: Send + Sync,
53    CacheToken: Send + Sync,
54{
55    /// Constructs an [CacheContainer].
56    pub fn new(
57        name: String,
58        cache: Cache<K, V>,
59        invalidator: Invalidator<K, V, CacheToken>,
60        initializer: Initializer<K, V>,
61        token_filter: fn(&CacheToken) -> bool,
62    ) -> Self {
63        Self {
64            name,
65            cache,
66            invalidator,
67            initializer,
68            token_filter,
69        }
70    }
71
72    /// Returns the `name`.
73    pub fn name(&self) -> &str {
74        &self.name
75    }
76}
77
78#[async_trait::async_trait]
79impl<K, V> CacheInvalidator for CacheContainer<K, V, CacheIdent>
80where
81    K: Send + Sync,
82    V: Send + Sync,
83{
84    async fn invalidate(&self, _ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
85        let tasks = caches
86            .iter()
87            .filter(|token| (self.token_filter)(token))
88            .map(|token| (self.invalidator)(&self.cache, token));
89        join_all(tasks)
90            .await
91            .into_iter()
92            .collect::<Result<Vec<_>>>()?;
93        Ok(())
94    }
95}
96
97impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
98where
99    K: Copy + Hash + Eq + Send + Sync + 'static,
100    V: Clone + Send + Sync + 'static,
101{
102    /// Returns a _clone_ of the value corresponding to the key.
103    pub async fn get(&self, key: K) -> Result<Option<V>> {
104        metrics::CACHE_CONTAINER_CACHE_GET
105            .with_label_values(&[&self.name])
106            .inc();
107        let moved_init = self.initializer.clone();
108        let moved_key = key;
109        let init = async move {
110            metrics::CACHE_CONTAINER_CACHE_MISS
111                .with_label_values(&[&self.name])
112                .inc();
113            let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
114                .with_label_values(&[&self.name])
115                .start_timer();
116            moved_init(&moved_key)
117                .await
118                .transpose()
119                .context(error::ValueNotExistSnafu)?
120        };
121
122        match self.cache.try_get_with(key, init).await {
123            Ok(value) => Ok(Some(value)),
124            Err(err) => match err.as_ref() {
125                Error::ValueNotExist { .. } => Ok(None),
126                _ => Err(err).context(error::GetCacheSnafu),
127            },
128        }
129    }
130}
131
132impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
133where
134    K: Hash + Eq + Send + Sync + 'static,
135    V: Clone + Send + Sync + 'static,
136{
137    /// Invalidates cache by [CacheToken].
138    pub async fn invalidate(&self, caches: &[CacheToken]) -> Result<()> {
139        let tasks = caches
140            .iter()
141            .filter(|token| (self.token_filter)(token))
142            .map(|token| (self.invalidator)(&self.cache, token));
143        join_all(tasks)
144            .await
145            .into_iter()
146            .collect::<Result<Vec<_>>>()?;
147        Ok(())
148    }
149
150    /// Returns true if the cache contains a value for the key.
151    pub fn contains_key<Q>(&self, key: &Q) -> bool
152    where
153        K: Borrow<Q>,
154        Q: Hash + Eq + ?Sized,
155    {
156        self.cache.contains_key(key)
157    }
158
159    /// Returns a _clone_ of the value corresponding to the key.
160    pub async fn get_by_ref<Q>(&self, key: &Q) -> Result<Option<V>>
161    where
162        K: Borrow<Q>,
163        Q: ToOwned<Owned = K> + Hash + Eq + ?Sized,
164    {
165        metrics::CACHE_CONTAINER_CACHE_GET
166            .with_label_values(&[&self.name])
167            .inc();
168        let moved_init = self.initializer.clone();
169        let moved_key = key.to_owned();
170
171        let init = async move {
172            metrics::CACHE_CONTAINER_CACHE_MISS
173                .with_label_values(&[&self.name])
174                .inc();
175            let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
176                .with_label_values(&[&self.name])
177                .start_timer();
178
179            moved_init(&moved_key)
180                .await
181                .transpose()
182                .context(error::ValueNotExistSnafu)?
183        };
184
185        match self.cache.try_get_with_by_ref(key, init).await {
186            Ok(value) => Ok(Some(value)),
187            Err(err) => match err.as_ref() {
188                Error::ValueNotExist { .. } => Ok(None),
189                _ => Err(err).context(error::GetCacheSnafu),
190            },
191        }
192    }
193}
194
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;

    use moka::future::{Cache, CacheBuilder};

    use super::*;

    #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
    struct NameKey<'a> {
        name: &'a str,
    }

    fn always_true_filter(_: &String) -> bool {
        true
    }

    #[tokio::test]
    async fn test_get() {
        let cache: Cache<NameKey, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<NameKey, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<NameKey, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let key = NameKey { name: "key" };
        let value = adv_cache.get(key).await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        let key = NameKey { name: "key" };
        let value = adv_cache.get(key).await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_get_by_ref() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<String, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        let value = adv_cache.get_by_ref("bar").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_get_value_not_exists() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let init: Initializer<String, String> =
            Arc::new(move |_| Box::pin(async { error::ValueNotExistSnafu {}.fail() }));
        let invalidator: Invalidator<String, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap();
        assert!(value.is_none());
    }

    #[tokio::test]
    async fn test_get_initializer_returns_none() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(None) })
        });
        let invalidator: Invalidator<String, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        assert!(adv_cache.get_by_ref("foo").await.unwrap().is_none());
        // A missing value is not cached, so the next lookup runs the initializer again.
        assert!(adv_cache.get_by_ref("foo").await.unwrap().is_none());
        assert_eq!(counter.load(Ordering::Relaxed), 2);
        assert!(!adv_cache.contains_key("foo"));
    }

    #[tokio::test]
    async fn test_invalidate() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<String, String, String> = Box::new(|cache, key| {
            Box::pin(async move {
                cache.invalidate(key).await;
                Ok(())
            })
        });

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        adv_cache
            .invalidate(&["foo".to_string(), "bar".to_string()])
            .await
            .unwrap();
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_invalidate_respects_token_filter() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<String, String, String> = Box::new(|cache, key| {
            Box::pin(async move {
                cache.invalidate(key).await;
                Ok(())
            })
        });

        // Non-capturing closure coerces to the `fn` pointer expected by `new`.
        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            |token: &String| token.as_str() != "bar",
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("bar").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
        adv_cache
            .invalidate(&["foo".to_string(), "bar".to_string()])
            .await
            .unwrap();
        // "bar" is rejected by the filter, so it stays cached and is not reloaded.
        let value = adv_cache.get_by_ref("bar").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
        // "foo" was invalidated, so it is loaded again.
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }
}