1use std::borrow::Borrow;
16use std::hash::Hash;
17use std::sync::Arc;
18
19use futures::future::{join_all, BoxFuture};
20use moka::future::Cache;
21use snafu::{OptionExt, ResultExt};
22
23use crate::cache_invalidator::{CacheInvalidator, Context};
24use crate::error::{self, Error, Result};
25use crate::instruction::CacheIdent;
26use crate::metrics;
27
/// Boxed predicate that decides whether a cache token should be handled.
///
/// NOTE(review): `CacheContainer` stores its filter as a plain
/// `fn(&CacheToken) -> bool`, not as this boxed alias.
pub type TokenFilter<CacheToken> = Box<dyn Fn(&CacheToken) -> bool + Send + Sync>;

/// Async callback that invalidates the entries of `cache` for a single token.
///
/// The returned future may borrow both the cache and the token for `'a`.
pub type Invalidator<K, V, CacheToken> =
    Box<dyn for<'a> Fn(&'a Cache<K, V>, &'a CacheToken) -> BoxFuture<'a, Result<()>> + Send + Sync>;

/// Async loader that produces the value for a key missing from the cache.
///
/// Returning `Ok(None)` signals that no value exists for the key.
/// `CacheContainer::get` and `CacheContainer::get_by_ref` then report `Ok(None)`.
pub type Initializer<K, V> = Arc<dyn Fn(&'_ K) -> BoxFuture<'_, Result<Option<V>>> + Send + Sync>;
37
/// A named [`Cache`] bundled with the callbacks used to fill it and invalidate it.
pub struct CacheContainer<K, V, CacheToken> {
    /// Name returned by `name()` and used as the label value for this container's metrics.
    name: String,
    /// Backing cache holding the loaded values.
    cache: Cache<K, V>,
    /// Callback invoked once per accepted token by `invalidate`.
    invalidator: Invalidator<K, V, CacheToken>,
    /// Loader used to populate keys that are missing from `cache`.
    initializer: Initializer<K, V>,
    /// Predicate selecting which tokens are forwarded to `invalidator`.
    token_filter: fn(&CacheToken) -> bool,
}
48
49impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
50where
51 K: Send + Sync,
52 V: Send + Sync,
53 CacheToken: Send + Sync,
54{
55 pub fn new(
57 name: String,
58 cache: Cache<K, V>,
59 invalidator: Invalidator<K, V, CacheToken>,
60 initializer: Initializer<K, V>,
61 token_filter: fn(&CacheToken) -> bool,
62 ) -> Self {
63 Self {
64 name,
65 cache,
66 invalidator,
67 initializer,
68 token_filter,
69 }
70 }
71
72 pub fn name(&self) -> &str {
74 &self.name
75 }
76}
77
78#[async_trait::async_trait]
79impl<K, V> CacheInvalidator for CacheContainer<K, V, CacheIdent>
80where
81 K: Send + Sync,
82 V: Send + Sync,
83{
84 async fn invalidate(&self, _ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
85 let tasks = caches
86 .iter()
87 .filter(|token| (self.token_filter)(token))
88 .map(|token| (self.invalidator)(&self.cache, token));
89 join_all(tasks)
90 .await
91 .into_iter()
92 .collect::<Result<Vec<_>>>()?;
93 Ok(())
94 }
95}
96
97impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
98where
99 K: Copy + Hash + Eq + Send + Sync + 'static,
100 V: Clone + Send + Sync + 'static,
101{
102 pub async fn get(&self, key: K) -> Result<Option<V>> {
104 metrics::CACHE_CONTAINER_CACHE_GET
105 .with_label_values(&[&self.name])
106 .inc();
107 let moved_init = self.initializer.clone();
108 let moved_key = key;
109 let init = async move {
110 metrics::CACHE_CONTAINER_CACHE_MISS
111 .with_label_values(&[&self.name])
112 .inc();
113 let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
114 .with_label_values(&[&self.name])
115 .start_timer();
116 moved_init(&moved_key)
117 .await
118 .transpose()
119 .context(error::ValueNotExistSnafu)?
120 };
121
122 match self.cache.try_get_with(key, init).await {
123 Ok(value) => Ok(Some(value)),
124 Err(err) => match err.as_ref() {
125 Error::ValueNotExist { .. } => Ok(None),
126 _ => Err(err).context(error::GetCacheSnafu),
127 },
128 }
129 }
130}
131
132impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
133where
134 K: Hash + Eq + Send + Sync + 'static,
135 V: Clone + Send + Sync + 'static,
136{
137 pub async fn invalidate(&self, caches: &[CacheToken]) -> Result<()> {
139 let tasks = caches
140 .iter()
141 .filter(|token| (self.token_filter)(token))
142 .map(|token| (self.invalidator)(&self.cache, token));
143 join_all(tasks)
144 .await
145 .into_iter()
146 .collect::<Result<Vec<_>>>()?;
147 Ok(())
148 }
149
150 pub fn contains_key<Q>(&self, key: &Q) -> bool
152 where
153 K: Borrow<Q>,
154 Q: Hash + Eq + ?Sized,
155 {
156 self.cache.contains_key(key)
157 }
158
159 pub async fn get_by_ref<Q>(&self, key: &Q) -> Result<Option<V>>
161 where
162 K: Borrow<Q>,
163 Q: ToOwned<Owned = K> + Hash + Eq + ?Sized,
164 {
165 metrics::CACHE_CONTAINER_CACHE_GET
166 .with_label_values(&[&self.name])
167 .inc();
168 let moved_init = self.initializer.clone();
169 let moved_key = key.to_owned();
170
171 let init = async move {
172 metrics::CACHE_CONTAINER_CACHE_MISS
173 .with_label_values(&[&self.name])
174 .inc();
175 let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
176 .with_label_values(&[&self.name])
177 .start_timer();
178
179 moved_init(&moved_key)
180 .await
181 .transpose()
182 .context(error::ValueNotExistSnafu)?
183 };
184
185 match self.cache.try_get_with_by_ref(key, init).await {
186 Ok(value) => Ok(Some(value)),
187 Err(err) => match err.as_ref() {
188 Error::ValueNotExist { .. } => Ok(None),
189 _ => Err(err).context(error::GetCacheSnafu),
190 },
191 }
192 }
193}
194
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;

    use moka::future::{Cache, CacheBuilder};

    use super::*;

    /// `Copy` key type, as required by `CacheContainer::get`.
    #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
    struct NameKey<'a> {
        name: &'a str,
    }

    /// Token filter that forwards every token to the invalidator.
    fn always_true_filter(_: &String) -> bool {
        true
    }

    /// Token filter that drops every token, so the invalidator never runs.
    fn always_false_filter(_: &String) -> bool {
        false
    }

    #[tokio::test]
    async fn test_get() {
        let cache: Cache<NameKey, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<NameKey, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<NameKey, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let key = NameKey { name: "key" };
        let value = adv_cache.get(key).await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        let key = NameKey { name: "key" };
        let value = adv_cache.get(key).await.unwrap().unwrap();
        assert_eq!(value, "hi");
        // Second lookup is a hit: the initializer must not run again.
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_get_by_ref() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<String, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        let value = adv_cache.get_by_ref("bar").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_get_value_not_exists() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let init: Initializer<String, String> =
            Arc::new(move |_| Box::pin(async { error::ValueNotExistSnafu {}.fail() }));
        let invalidator: Invalidator<String, String, String> =
            Box::new(|_, _| Box::pin(async { Ok(()) }));

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap();
        assert!(value.is_none());
    }

    #[tokio::test]
    async fn test_invalidate() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<String, String, String> = Box::new(|cache, key| {
            Box::pin(async move {
                cache.invalidate(key).await;
                Ok(())
            })
        });

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_true_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        adv_cache
            .invalidate(&["foo".to_string(), "bar".to_string()])
            .await
            .unwrap();
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        // "foo" was evicted, so the initializer had to run again.
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_invalidate_skips_filtered_tokens() {
        let cache: Cache<String, String> = CacheBuilder::new(128).build();
        let counter = Arc::new(AtomicI32::new(0));
        let moved_counter = counter.clone();
        let init: Initializer<String, String> = Arc::new(move |_| {
            moved_counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { Ok(Some("hi".to_string())) })
        });
        let invalidator: Invalidator<String, String, String> = Box::new(|cache, key| {
            Box::pin(async move {
                cache.invalidate(key).await;
                Ok(())
            })
        });

        let adv_cache = CacheContainer::new(
            "test".to_string(),
            cache,
            invalidator,
            init,
            always_false_filter,
        );
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        assert_eq!(counter.load(Ordering::Relaxed), 1);
        adv_cache.invalidate(&["foo".to_string()]).await.unwrap();
        let value = adv_cache.get_by_ref("foo").await.unwrap().unwrap();
        assert_eq!(value, "hi");
        // The filter rejected the token, so the entry survived and no reload happened.
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }
}