1use std::borrow::Borrow;
16use std::hash::Hash;
17use std::sync::Arc;
18use std::sync::atomic::{AtomicUsize, Ordering};
19use std::time::Duration;
20
21use backon::{BackoffBuilder, ExponentialBuilder};
22use futures::future::BoxFuture;
23use moka::future::Cache;
24use snafu::{OptionExt, ResultExt};
25use tokio::time::sleep;
26
27use crate::cache_invalidator::{CacheInvalidator, Context};
28use crate::error::{self, Error, Result};
29use crate::instruction::CacheIdent;
30use crate::metrics;
31
/// A boxed, thread-safe predicate over a borrowed cache token.
///
/// The predicate returns `true` for tokens that should be acted upon.
/// Note that [`CacheContainer`] itself stores a plain
/// `fn(&CacheToken) -> bool` pointer rather than this boxed alias.
pub type TokenFilter<CacheToken> = Box<dyn for<'t> Fn(&'t CacheToken) -> bool + Send + Sync>;
34
35pub type Invalidator<K, V, CacheToken> = Box<
37 dyn for<'a> Fn(&'a Cache<K, V>, &'a [&CacheToken]) -> BoxFuture<'a, Result<()>> + Send + Sync,
38>;
39
40pub type Initializer<K, V> = Arc<dyn Fn(&'_ K) -> BoxFuture<'_, Result<Option<V>>> + Send + Sync>;
42
/// Selects how [`CacheContainer`] loads a value on a cache miss.
///
/// The enum is `Copy` and comparable, so it can be freely passed around and
/// asserted on. The default is [`InitStrategy::Unchecked`], which is also the
/// strategy `CacheContainer::new` selects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum InitStrategy {
    /// Run the initializer once and cache whatever it returns, without
    /// consulting the container's version counter.
    #[default]
    Unchecked,
    /// Compare the container's version counter before and after running the
    /// initializer; if it changed in between (an invalidation happened), the
    /// load is retried with backoff until the retry budget is exhausted.
    VersionChecked,
}
58
59pub struct CacheContainer<K, V, CacheToken> {
63 name: String,
64 cache: Cache<K, V>,
65 invalidator: Invalidator<K, V, CacheToken>,
66 initializer: Initializer<K, V>,
67 token_filter: fn(&CacheToken) -> bool,
68 version: Arc<AtomicUsize>,
69 init_strategy: InitStrategy,
70}
71
72fn latest_get_backoff() -> impl Iterator<Item = Duration> {
73 ExponentialBuilder::default()
74 .with_min_delay(Duration::from_millis(10))
75 .with_max_delay(Duration::from_millis(100))
76 .with_max_times(3)
77 .build()
78}
79
80impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
81where
82 K: Send + Sync,
83 V: Send + Sync,
84 CacheToken: Send + Sync,
85{
86 pub fn new(
91 name: String,
92 cache: Cache<K, V>,
93 invalidator: Invalidator<K, V, CacheToken>,
94 initializer: Initializer<K, V>,
95 token_filter: fn(&CacheToken) -> bool,
96 ) -> Self {
97 Self::with_strategy(
98 name,
99 cache,
100 invalidator,
101 initializer,
102 token_filter,
103 InitStrategy::Unchecked,
104 )
105 }
106
107 pub fn with_strategy(
111 name: String,
112 cache: Cache<K, V>,
113 invalidator: Invalidator<K, V, CacheToken>,
114 initializer: Initializer<K, V>,
115 token_filter: fn(&CacheToken) -> bool,
116 init_strategy: InitStrategy,
117 ) -> Self {
118 Self {
119 name,
120 cache,
121 invalidator,
122 initializer,
123 token_filter,
124 version: Arc::new(AtomicUsize::new(0)),
125 init_strategy,
126 }
127 }
128
129 pub fn name(&self) -> &str {
131 &self.name
132 }
133}
134
135impl<K, V, CacheToken> CacheContainer<K, V, CacheToken> {
136 fn inc_version(&self) {
137 self.version.fetch_add(1, Ordering::Relaxed);
138 }
139}
140
141async fn init<'a, K, V>(init: Initializer<K, V>, key: K, cache_name: &'a str) -> Result<V>
142where
143 K: Send + Sync + 'a,
144 V: Send + 'a,
145{
146 metrics::CACHE_CONTAINER_CACHE_MISS
147 .with_label_values(&[cache_name])
148 .inc();
149 let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
150 .with_label_values(&[cache_name])
151 .start_timer();
152 init(&key)
153 .await
154 .transpose()
155 .context(error::ValueNotExistSnafu)?
156}
157
158async fn init_with_retry<'a, K, V>(
159 init: Initializer<K, V>,
160 key: K,
161 mut backoff: impl Iterator<Item = Duration> + 'a,
162 version: Arc<AtomicUsize>,
163 cache_name: &'a str,
164) -> Result<V>
165where
166 K: Send + Sync + 'a,
167 V: Send + 'a,
168{
169 let mut attempts = 1usize;
170 loop {
171 let pre_version = version.load(Ordering::Relaxed);
172 metrics::CACHE_CONTAINER_CACHE_MISS
173 .with_label_values(&[cache_name])
174 .inc();
175 let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
176 .with_label_values(&[cache_name])
177 .start_timer();
178 let value = init(&key)
179 .await
180 .transpose()
181 .context(error::ValueNotExistSnafu)??;
182
183 if pre_version == version.load(Ordering::Relaxed) {
184 return Ok(value);
185 }
186
187 if let Some(duration) = backoff.next() {
188 sleep(duration).await;
189 attempts += 1;
190 } else {
191 return error::GetLatestCacheRetryExceededSnafu { attempts }.fail();
192 }
193 }
194}
195
196#[async_trait::async_trait]
197impl<K, V> CacheInvalidator for CacheContainer<K, V, CacheIdent>
198where
199 K: Send + Sync,
200 V: Send + Sync,
201{
202 async fn invalidate(&self, _ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
203 let idents = caches
204 .iter()
205 .filter(|token| (self.token_filter)(token))
206 .collect::<Vec<_>>();
207 if !idents.is_empty() {
208 self.inc_version();
209 (self.invalidator)(&self.cache, &idents).await?;
210 }
211
212 Ok(())
213 }
214}
215
216impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
217where
218 K: Copy + Hash + Eq + Send + Sync + 'static,
219 V: Clone + Send + Sync + 'static,
220{
221 pub async fn get(&self, key: K) -> Result<Option<V>> {
227 metrics::CACHE_CONTAINER_CACHE_GET
228 .with_label_values(&[&self.name])
229 .inc();
230
231 let result = match self.init_strategy {
232 InitStrategy::Unchecked => {
233 self.cache
234 .try_get_with(key, init(self.initializer.clone(), key, &self.name))
235 .await
236 }
237 InitStrategy::VersionChecked => {
238 self.cache
239 .try_get_with(
240 key,
241 init_with_retry(
242 self.initializer.clone(),
243 key,
244 latest_get_backoff(),
245 self.version.clone(),
246 &self.name,
247 ),
248 )
249 .await
250 }
251 };
252
253 match result {
254 Ok(value) => Ok(Some(value)),
255 Err(err) => match err.as_ref() {
256 Error::ValueNotExist { .. } => Ok(None),
257 _ => Err(err).context(error::GetCacheSnafu),
258 },
259 }
260 }
261}
262
263impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
264where
265 K: Hash + Eq + Send + Sync + 'static,
266 V: Clone + Send + Sync + 'static,
267{
268 pub async fn invalidate(&self, caches: &[CacheToken]) -> Result<()> {
270 let idents = caches
271 .iter()
272 .filter(|token| (self.token_filter)(token))
273 .collect::<Vec<_>>();
274 if !idents.is_empty() {
275 self.inc_version();
276 (self.invalidator)(&self.cache, &idents).await?;
277 }
278
279 Ok(())
280 }
281
282 pub fn contains_key<Q>(&self, key: &Q) -> bool
284 where
285 K: Borrow<Q>,
286 Q: Hash + Eq + ?Sized,
287 {
288 self.cache.contains_key(key)
289 }
290
291 pub async fn get_by_ref<Q>(&self, key: &Q) -> Result<Option<V>>
297 where
298 K: Borrow<Q>,
299 Q: ToOwned<Owned = K> + Hash + Eq + ?Sized,
300 {
301 metrics::CACHE_CONTAINER_CACHE_GET
302 .with_label_values(&[&self.name])
303 .inc();
304 let result = match self.init_strategy {
305 InitStrategy::Unchecked => {
306 self.cache
307 .try_get_with_by_ref(
308 key,
309 init(self.initializer.clone(), key.to_owned(), &self.name),
310 )
311 .await
312 }
313 InitStrategy::VersionChecked => {
314 self.cache
315 .try_get_with_by_ref(
316 key,
317 init_with_retry(
318 self.initializer.clone(),
319 key.to_owned(),
320 latest_get_backoff(),
321 self.version.clone(),
322 &self.name,
323 ),
324 )
325 .await
326 }
327 };
328
329 match result {
330 Ok(value) => Ok(Some(value)),
331 Err(err) => match err.as_ref() {
332 Error::ValueNotExist { .. } => Ok(None),
333 _ => Err(err).context(error::GetCacheSnafu),
334 },
335 }
336 }
337}
338
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicI32, Ordering};

    use moka::future::{Cache, CacheBuilder};

    use super::*;

    /// A `Copy` key type, needed to exercise `CacheContainer::get`.
    #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
    struct NameKey<'a> {
        name: &'a str,
    }

    /// Token filter that accepts every token.
    fn always_true_filter(_: &String) -> bool {
        true
    }

    /// Invalidator that ignores its tokens and leaves the cache untouched.
    fn noop_invalidator() -> Invalidator<String, String, String> {
        Box::new(|_, _| Box::pin(async { Ok(()) }))
    }

    /// Invalidator that treats every token as a cache key and evicts it.
    fn evicting_invalidator() -> Invalidator<String, String, String> {
        Box::new(|cache, keys| {
            Box::pin(async move {
                for key in keys {
                    cache.invalidate(*key).await;
                }
                Ok(())
            })
        })
    }

    /// Initializer that always yields `"hi"` and counts its invocations.
    fn counting_initializer(counter: Arc<AtomicI32>) -> Initializer<String, String> {
        Arc::new(move |_| {
            counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        })
    }

    // A second `get` for the same key must be served from the cache.
    #[tokio::test]
    async fn test_get() {
        let cache: Cache<NameKey, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<NameKey, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<NameKey, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );

        let first = adv_cache.get(NameKey { name: "key" }).await.unwrap().unwrap();
        assert_eq!(first, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        let second = adv_cache.get(NameKey { name: "key" }).await.unwrap().unwrap();
        assert_eq!(second, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    // Lookups by reference load once per distinct key.
    #[tokio::test]
    async fn test_get_by_ref() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            noop_invalidator(),
            counting_initializer(counter.clone()),
            always_true_filter,
        );

        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        let value = adv_cache.get_by_ref("bar").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    // An initializer failing with `ValueNotExist` surfaces as `Ok(None)`.
    #[tokio::test]
    async fn test_get_value_not_exists() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let init: Initializer<String, String> =
            Arc::new(move |_| Box::pin(async { error::ValueNotExistSnafu {}.fail() }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            noop_invalidator(),
            init,
            always_true_filter,
        );

        let value = adv_cache.get_by_ref("foo").await.unwrap();
        assert!(value.is_none());
    }

    // An initializer returning `Ok(None)` also surfaces as `Ok(None)`.
    #[tokio::test]
    async fn test_get_initializer_returns_none() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let init: Initializer<String, String> = Arc::new(|_| Box::pin(async { Ok(None) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            noop_invalidator(),
            init,
            always_true_filter,
        );

        let value = adv_cache.get_by_ref("foo").await.unwrap();
        assert!(value.is_none());
    }

    // After an invalidation the next lookup must run the initializer again.
    #[tokio::test]
    async fn test_invalidate() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            evicting_invalidator(),
            counting_initializer(counter.clone()),
            always_true_filter,
        );

        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);

        adv_cache
            .invalidate(&["foo".to_string(), "bar".to_string()])
            .await
            .unwrap();

        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    // With `VersionChecked`, an invalidation that lands while the first load
    // is still running forces a second load, so the caller sees "v2".
    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_by_ref_returns_fresh_value_after_invalidate() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        // Each load takes 100 ms and returns "v<n>" for the n-th invocation.
        let init: Initializer<String, String> = Arc::new(move |_| {
            let counter = moved_counter.clone();
            Box::pin(async move {
                let n = counter.fetch_add(1, Ordering::Relaxed) + 1;
                sleep(Duration::from_millis(100)).await;
                Ok(Some(format!("v{n}")))
            })
        });

        let adv_cache = Arc::new(CacheContainer::with_strategy(
            "test".to_string(),
            cache,
            evicting_invalidator(),
            init,
            always_true_filter,
            InitStrategy::VersionChecked,
        ));

        let moved_cache = adv_cache.clone();
        let get_task = tokio::spawn(async move { moved_cache.get_by_ref("foo").await });

        // Invalidate halfway through the first (100 ms) load.
        sleep(Duration::from_millis(50)).await;
        adv_cache.invalidate(&["foo".to_string()]).await.unwrap();

        let value = get_task.await.unwrap().unwrap().unwrap();
        assert_eq!(value, "v2");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }
}