Skip to main content

common_meta/cache/
container.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Borrow;
16use std::hash::Hash;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::time::Duration;
20
21use backon::{BackoffBuilder, ExponentialBuilder};
22use futures::future::BoxFuture;
23use moka::future::Cache;
24use snafu::{OptionExt, ResultExt};
25use tokio::time::sleep;
26
27use crate::cache_invalidator::{CacheInvalidator, Context};
28use crate::error::{self, Error, Result};
29use crate::instruction::CacheIdent;
30use crate::metrics;
31
/// Filters out unused [CacheToken]s
///
/// A boxed, thread-safe predicate over a token reference.
// NOTE(review): [CacheContainer] in this file stores a plain
// `fn(&CacheToken) -> bool` pointer rather than this boxed alias; confirm
// whether this alias is still used by other modules.
pub type TokenFilter<CacheToken> = Box<dyn Fn(&CacheToken) -> bool + Send + Sync>;
34
/// Invalidates cached values by [CacheToken]s.
///
/// The callback receives the underlying [Cache] together with the tokens that
/// passed the container's token filter, and returns a boxed future resolving
/// to `Result<()>`. Both arguments are borrowed for as long as the returned
/// future lives (`'a`).
pub type Invalidator<K, V, CacheToken> = Box<
    dyn for<'a> Fn(&'a Cache<K, V>, &'a [&CacheToken]) -> BoxFuture<'a, Result<()>> + Send + Sync,
>;
39
/// Initializes value (i.e., fetches from remote).
///
/// Invoked with the missing key on a cache miss. `Ok(Some(v))` yields the
/// value to cache; `Ok(None)` is converted into a `ValueNotExist` error by the
/// loaders in this file, which `get`/`get_by_ref` then report as `Ok(None)`.
/// Wrapped in an [Arc] so every lookup can clone it into its loader future.
pub type Initializer<K, V> = Arc<dyn Fn(&'_ K) -> BoxFuture<'_, Result<Option<V>>> + Send + Sync>;
42
/// Initialization strategy for cache-miss loading.
///
/// This strategy is selected when building [CacheContainer] and remains immutable
/// for the lifetime of the container instance.
///
/// The type is a plain, field-less enum, so it is cheap to copy and can be
/// compared for equality (useful in assertions and configuration checks).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitStrategy {
    /// Fast path: load once without version conflict retry.
    ///
    /// Under concurrent invalidation, callers may observe stale/dirty value.
    Unchecked,
    /// Strict path: retry load when version changes during initialization.
    ///
    /// This avoids returning dirty value under invalidate/load races.
    VersionChecked,
}
58
/// [CacheContainer] provides ability to:
/// - Cache value loaded by [Initializer].
/// - Invalidate caches by [Invalidator].
pub struct CacheContainer<K, V, CacheToken> {
    /// Container name; returned by `name()` and used as the label value for
    /// the metrics recorded in this file.
    name: String,
    /// Underlying cache that stores the loaded values.
    cache: Cache<K, V>,
    /// Callback invoked with `cache` and the filtered tokens on invalidation.
    invalidator: Invalidator<K, V, CacheToken>,
    /// Loader invoked on a cache miss to produce the value for a key.
    initializer: Initializer<K, V>,
    /// Predicate selecting which tokens are forwarded to `invalidator`;
    /// tokens for which it returns `false` are ignored.
    token_filter: fn(&CacheToken) -> bool,
    /// Counter bumped on every invalidation that has at least one matching
    /// token; version-checked loads compare it before and after loading.
    version: Arc<AtomicUsize>,
    /// How cache-miss loading reacts to concurrent invalidation; fixed at
    /// construction time.
    init_strategy: InitStrategy,
}
71
72fn latest_get_backoff() -> impl Iterator<Item = Duration> {
73    ExponentialBuilder::default()
74        .with_min_delay(Duration::from_millis(10))
75        .with_max_delay(Duration::from_millis(100))
76        .with_max_times(3)
77        .build()
78}
79
80impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
81where
82    K: Send + Sync,
83    V: Send + Sync,
84    CacheToken: Send + Sync,
85{
86    /// Constructs an [CacheContainer] with [InitStrategy::Unchecked].
87    ///
88    /// This keeps the historical behavior and can return stale/dirty value under
89    /// concurrent invalidation.
90    pub fn new(
91        name: String,
92        cache: Cache<K, V>,
93        invalidator: Invalidator<K, V, CacheToken>,
94        initializer: Initializer<K, V>,
95        token_filter: fn(&CacheToken) -> bool,
96    ) -> Self {
97        Self::with_strategy(
98            name,
99            cache,
100            invalidator,
101            initializer,
102            token_filter,
103            InitStrategy::Unchecked,
104        )
105    }
106
107    /// Constructs an [CacheContainer] with explicit [InitStrategy].
108    ///
109    /// The strategy is fixed at construction time and cannot be changed later.
110    pub fn with_strategy(
111        name: String,
112        cache: Cache<K, V>,
113        invalidator: Invalidator<K, V, CacheToken>,
114        initializer: Initializer<K, V>,
115        token_filter: fn(&CacheToken) -> bool,
116        init_strategy: InitStrategy,
117    ) -> Self {
118        Self {
119            name,
120            cache,
121            invalidator,
122            initializer,
123            token_filter,
124            version: Arc::new(AtomicUsize::new(0)),
125            init_strategy,
126        }
127    }
128
129    /// Returns the `name`.
130    pub fn name(&self) -> &str {
131        &self.name
132    }
133}
134
impl<K, V, CacheToken> CacheContainer<K, V, CacheToken> {
    /// Bumps the invalidation version by one.
    ///
    /// Both `invalidate` paths in this file call this before running the
    /// invalidator, so a version-checked load that started earlier sees a
    /// different value when it re-reads the counter.
    // NOTE(review): `Relaxed` mirrors the loads in `init_with_retry`; confirm
    // that no stronger ordering is required relative to the cache operations.
    fn inc_version(&self) {
        self.version.fetch_add(1, Ordering::Relaxed);
    }
}
140
141async fn init<'a, K, V>(init: Initializer<K, V>, key: K, cache_name: &'a str) -> Result<V>
142where
143    K: Send + Sync + 'a,
144    V: Send + 'a,
145{
146    metrics::CACHE_CONTAINER_CACHE_MISS
147        .with_label_values(&[cache_name])
148        .inc();
149    let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
150        .with_label_values(&[cache_name])
151        .start_timer();
152    init(&key)
153        .await
154        .transpose()
155        .context(error::ValueNotExistSnafu)?
156}
157
158async fn init_with_retry<'a, K, V>(
159    init: Initializer<K, V>,
160    key: K,
161    mut backoff: impl Iterator<Item = Duration> + 'a,
162    version: Arc<AtomicUsize>,
163    cache_name: &'a str,
164) -> Result<V>
165where
166    K: Send + Sync + 'a,
167    V: Send + 'a,
168{
169    let mut attempts = 1usize;
170    loop {
171        let pre_version = version.load(Ordering::Relaxed);
172        metrics::CACHE_CONTAINER_CACHE_MISS
173            .with_label_values(&[cache_name])
174            .inc();
175        let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
176            .with_label_values(&[cache_name])
177            .start_timer();
178        let value = init(&key)
179            .await
180            .transpose()
181            .context(error::ValueNotExistSnafu)??;
182
183        if pre_version == version.load(Ordering::Relaxed) {
184            return Ok(value);
185        }
186
187        if let Some(duration) = backoff.next() {
188            sleep(duration).await;
189            attempts += 1;
190        } else {
191            return error::GetLatestCacheRetryExceededSnafu { attempts }.fail();
192        }
193    }
194}
195
196#[async_trait::async_trait]
197impl<K, V> CacheInvalidator for CacheContainer<K, V, CacheIdent>
198where
199    K: Send + Sync,
200    V: Send + Sync,
201{
202    async fn invalidate(&self, _ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
203        let idents = caches
204            .iter()
205            .filter(|token| (self.token_filter)(token))
206            .collect::<Vec<_>>();
207        if !idents.is_empty() {
208            self.inc_version();
209            (self.invalidator)(&self.cache, &idents).await?;
210        }
211
212        Ok(())
213    }
214}
215
216impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
217where
218    K: Copy + Hash + Eq + Send + Sync + 'static,
219    V: Clone + Send + Sync + 'static,
220{
221    /// Returns a value from cache for copyable keys.
222    ///
223    /// With [InitStrategy::Unchecked], this method prioritizes latency and may
224    /// return stale/dirty value. With [InitStrategy::VersionChecked], this method
225    /// retries initialization on version change and avoids dirty returns.
226    pub async fn get(&self, key: K) -> Result<Option<V>> {
227        metrics::CACHE_CONTAINER_CACHE_GET
228            .with_label_values(&[&self.name])
229            .inc();
230
231        let result = match self.init_strategy {
232            InitStrategy::Unchecked => {
233                self.cache
234                    .try_get_with(key, init(self.initializer.clone(), key, &self.name))
235                    .await
236            }
237            InitStrategy::VersionChecked => {
238                self.cache
239                    .try_get_with(
240                        key,
241                        init_with_retry(
242                            self.initializer.clone(),
243                            key,
244                            latest_get_backoff(),
245                            self.version.clone(),
246                            &self.name,
247                        ),
248                    )
249                    .await
250            }
251        };
252
253        match result {
254            Ok(value) => Ok(Some(value)),
255            Err(err) => match err.as_ref() {
256                Error::ValueNotExist { .. } => Ok(None),
257                _ => Err(err).context(error::GetCacheSnafu),
258            },
259        }
260    }
261}
262
263impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
264where
265    K: Hash + Eq + Send + Sync + 'static,
266    V: Clone + Send + Sync + 'static,
267{
268    /// Invalidates cache by [CacheToken].
269    pub async fn invalidate(&self, caches: &[CacheToken]) -> Result<()> {
270        let idents = caches
271            .iter()
272            .filter(|token| (self.token_filter)(token))
273            .collect::<Vec<_>>();
274        if !idents.is_empty() {
275            self.inc_version();
276            (self.invalidator)(&self.cache, &idents).await?;
277        }
278
279        Ok(())
280    }
281
282    /// Returns true if the cache contains a value for the key.
283    pub fn contains_key<Q>(&self, key: &Q) -> bool
284    where
285        K: Borrow<Q>,
286        Q: Hash + Eq + ?Sized,
287    {
288        self.cache.contains_key(key)
289    }
290
291    /// Returns a value from cache by key reference.
292    ///
293    /// With [InitStrategy::Unchecked], this method prioritizes latency and may
294    /// return stale/dirty value. With [InitStrategy::VersionChecked], this method
295    /// retries initialization on version change and avoids dirty returns.
296    pub async fn get_by_ref<Q>(&self, key: &Q) -> Result<Option<V>>
297    where
298        K: Borrow<Q>,
299        Q: ToOwned<Owned = K> + Hash + Eq + ?Sized,
300    {
301        metrics::CACHE_CONTAINER_CACHE_GET
302            .with_label_values(&[&self.name])
303            .inc();
304        let result = match self.init_strategy {
305            InitStrategy::Unchecked => {
306                self.cache
307                    .try_get_with_by_ref(
308                        key,
309                        init(self.initializer.clone(), key.to_owned(), &self.name),
310                    )
311                    .await
312            }
313            InitStrategy::VersionChecked => {
314                self.cache
315                    .try_get_with_by_ref(
316                        key,
317                        init_with_retry(
318                            self.initializer.clone(),
319                            key.to_owned(),
320                            latest_get_backoff(),
321                            self.version.clone(),
322                            &self.name,
323                        ),
324                    )
325                    .await
326            }
327        };
328
329        match result {
330            Ok(value) => Ok(Some(value)),
331            Err(err) => match err.as_ref() {
332                Error::ValueNotExist { .. } => Ok(None),
333                _ => Err(err).context(error::GetCacheSnafu),
334            },
335        }
336    }
337}
338
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicI32, Ordering};

    use moka::future::{Cache, CacheBuilder};

    use super::*;

    /// Copyable key type used to exercise `CacheContainer::get`.
    #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
    struct NameKey<'a> {
        name: &'a str,
    }

    /// Token filter that accepts every token.
    fn always_true_filter(_: &String) -> bool {
        true
    }

    /// A second `get` for the same key is served from cache (one load only).
    #[tokio::test]
    async fn test_get() {
        let cache: Cache<NameKey, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<NameKey, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<NameKey, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let key = NameKey { name: "key" };
        let value = adv_cache.get(key).await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        let key = NameKey { name: "key" };
        let value = adv_cache.get(key).await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    /// `get_by_ref` loads once per distinct key.
    #[tokio::test]
    async fn test_get_by_ref() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<String, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        let value = adv_cache.get_by_ref("bar").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    /// An initializer failing with `ValueNotExist` is reported as `Ok(None)`.
    #[tokio::test]
    async fn test_get_value_not_exists() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let init: Initializer<String, String> =
            Arc::new(move |_| Box::pin(async { error::ValueNotExistSnafu {}.fail() }));
        let invalidator: Invalidator<String, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap();
        assert!(value.is_none());
    }

    /// An initializer returning `Ok(None)` is reported as `Ok(None)` and
    /// nothing is cached for the key.
    #[tokio::test]
    async fn test_get_none_value() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let init: Initializer<String, String> = Arc::new(move |_| Box::pin(async { Ok(None) }));
        let invalidator: Invalidator<String, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap();
        assert!(value.is_none());
        assert!(!adv_cache.contains_key("foo"));
    }

    /// After invalidating a key, the next lookup loads it again.
    #[tokio::test]
    async fn test_invalidate() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<String, String, String> = Box::new(|cache, keys| {
            Box::pin(async move {
                for key in keys {
                    cache.invalidate(*key).await;
                }
                Ok(())
            })
        });

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        adv_cache
            .invalidate(&["foo".to_string(), "bar".to_string()])
            .await
            .unwrap();
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    /// With `VersionChecked`, a load that overlaps an invalidation is
    /// discarded and repeated, so the caller gets the second load's value.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_by_ref_returns_fresh_value_after_invalidate() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        // Each load takes 100ms and returns "v<n>" for the n-th load.
        let init: Initializer<String, String> = Arc::new(move |_| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                let n = counter.fetch_add(1, Ordering::Relaxed) + 1;
                sleep(Duration::from_millis(100)).await;
                Ok(Some(format!("v{n}")))
            })
        });
        let invalidator: Invalidator<String, String, String> = Box::new(|cache, keys| {
            Box::pin(async move {
                for key in keys {
                    cache.invalidate(*key).await;
                }
                Ok(())
            })
        });

        let adv_cache = Arc::new(CacheContainer::with_strategy(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
            InitStrategy::VersionChecked,
        ));

        let moved_cache = adv_cache.clone();
        let get_task = tokio::spawn(async move { moved_cache.get_by_ref("foo").await });

        // Invalidate while the first load is still in flight.
        sleep(Duration::from_millis(50)).await;
        adv_cache.invalidate(&["foo".to_string()]).await.unwrap();

        let value = get_task.await.unwrap().unwrap().unwrap();
        assert_eq!(value, "v2");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }
}