1use std::mem::size_of;
16
17use fst::map::OpBuilder;
18use fst::{IntoStreamer, Streamer};
19use regex_automata::dfa::dense::DFA;
20use regex_automata::dfa::Automaton;
21use regex_automata::util::primitives::StateID;
22use regex_automata::util::start::Config;
23use regex_automata::Anchored;
24use snafu::{ensure, ResultExt};
25
26use crate::inverted_index::error::{
27 EmptyPredicatesSnafu, IntersectionApplierWithInListSnafu, ParseDFASnafu, Result,
28};
29use crate::inverted_index::search::fst_apply::FstApplier;
30use crate::inverted_index::search::predicate::{Predicate, Range};
31use crate::inverted_index::FstMap;
32
/// Applies the intersection of a set of key ranges and regex patterns to an `FstMap`.
///
/// Every range and every compiled pattern contributes one operand to the
/// intersection, so a key is selected only when it satisfies all of them.
pub struct IntersectionFstApplier {
    /// Key ranges to intersect; each bound is optional and is either inclusive or exclusive.
    ranges: Vec<Range>,

    /// Compiled regex DFAs, wrapped so they can be passed to the FST search.
    dfas: Vec<DfaFstAutomaton>,
}
41
/// Newtype around a dense `regex_automata` DFA so that `fst::Automaton` can be
/// implemented for it in this module.
#[derive(Debug)]
struct DfaFstAutomaton(DFA<Vec<u32>>);
44
45impl fst::Automaton for DfaFstAutomaton {
46 type State = StateID;
47
48 #[inline]
49 fn start(&self) -> Self::State {
50 let config = Config::new().anchored(Anchored::No);
51 self.0.start_state(&config).unwrap()
52 }
53
54 #[inline]
55 fn is_match(&self, state: &Self::State) -> bool {
56 self.0.is_match_state(*state)
57 }
58
59 #[inline]
60 fn can_match(&self, state: &Self::State) -> bool {
61 !self.0.is_dead_state(*state)
62 }
63
64 #[inline]
65 fn accept_eof(&self, state: &StateID) -> Option<StateID> {
66 if self.0.is_match_state(*state) {
67 return Some(*state);
68 }
69 Some(self.0.next_eoi_state(*state))
70 }
71
72 #[inline]
73 fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
74 if self.0.is_match_state(*state) {
75 return *state;
76 }
77 self.0.next_state(*state, byte)
78 }
79}
80
81impl IntersectionFstApplier {
82 fn new(ranges: Vec<Range>, dfas: Vec<DFA<Vec<u32>>>) -> Self {
83 let dfas = dfas.into_iter().map(DfaFstAutomaton).collect();
84 Self { ranges, dfas }
85 }
86}
87
88impl FstApplier for IntersectionFstApplier {
89 fn apply(&self, fst: &FstMap) -> Vec<u64> {
90 let mut op = OpBuilder::new();
91
92 for range in &self.ranges {
93 match (range.lower.as_ref(), range.upper.as_ref()) {
94 (Some(lower), Some(upper)) => match (lower.inclusive, upper.inclusive) {
95 (true, true) => op.push(fst.range().ge(&lower.value).le(&upper.value)),
96 (true, false) => op.push(fst.range().ge(&lower.value).lt(&upper.value)),
97 (false, true) => op.push(fst.range().gt(&lower.value).le(&upper.value)),
98 (false, false) => op.push(fst.range().gt(&lower.value).lt(&upper.value)),
99 },
100 (Some(lower), None) => match lower.inclusive {
101 true => op.push(fst.range().ge(&lower.value)),
102 false => op.push(fst.range().gt(&lower.value)),
103 },
104 (None, Some(upper)) => match upper.inclusive {
105 true => op.push(fst.range().le(&upper.value)),
106 false => op.push(fst.range().lt(&upper.value)),
107 },
108 (None, None) => op.push(fst),
109 }
110 }
111
112 for dfa in &self.dfas {
113 op.push(fst.search(dfa));
114 }
115
116 let mut stream = op.intersection().into_stream();
117 let mut values = Vec::new();
118 while let Some((_, v)) = stream.next() {
119 values.push(v[0].value)
120 }
121 values
122 }
123
124 fn memory_usage(&self) -> usize {
125 let mut size = self.ranges.capacity() * size_of::<Range>();
126 for range in &self.ranges {
127 size += range
128 .lower
129 .as_ref()
130 .map_or(0, |bound| bound.value.capacity());
131 size += range
132 .upper
133 .as_ref()
134 .map_or(0, |bound| bound.value.capacity());
135 }
136
137 size += self.dfas.capacity() * size_of::<DFA<Vec<u32>>>();
138 for dfa in &self.dfas {
139 size += dfa.0.memory_usage();
140 }
141 size
142 }
143}
144
145impl IntersectionFstApplier {
146 pub fn try_from(predicates: Vec<Predicate>) -> Result<Self> {
152 ensure!(!predicates.is_empty(), EmptyPredicatesSnafu);
153
154 let mut dfas = Vec::with_capacity(predicates.len());
155 let mut ranges = Vec::with_capacity(predicates.len());
156
157 for predicate in predicates {
158 match predicate {
159 Predicate::Range(range) => ranges.push(range.range),
160 Predicate::RegexMatch(regex) => {
161 let dfa = DFA::new(®ex.pattern);
162 let dfa = dfa.map_err(Box::new).context(ParseDFASnafu)?;
163 dfas.push(dfa);
164 }
165 Predicate::InList(_) => {
167 return IntersectionApplierWithInListSnafu.fail();
168 }
169 }
170 }
171
172 Ok(Self::new(ranges, dfas))
173 }
174}
175
176impl TryFrom<Vec<Predicate>> for IntersectionFstApplier {
177 type Error = crate::inverted_index::error::Error;
178
179 fn try_from(predicates: Vec<Predicate>) -> Result<Self> {
180 Self::try_from(predicates)
181 }
182}
183
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::inverted_index::error::Error;
    use crate::inverted_index::search::predicate::{
        Bound, InListPredicate, RangePredicate, RegexMatchPredicate,
    };

    /// Shorthand for a present bound with the given key bytes and inclusiveness.
    fn bound(value: &[u8], inclusive: bool) -> Option<Bound> {
        Some(Bound {
            value: value.to_vec(),
            inclusive,
        })
    }

    /// Builds an applier from a single range predicate.
    fn create_applier_from_range(range: Range) -> Result<IntersectionFstApplier> {
        let predicate = Predicate::Range(RangePredicate { range });
        IntersectionFstApplier::try_from(vec![predicate])
    }

    /// Builds an applier from a single regex predicate.
    fn create_applier_from_pattern(pattern: &str) -> Result<IntersectionFstApplier> {
        let predicate = Predicate::RegexMatch(RegexMatchPredicate {
            pattern: pattern.to_string(),
        });
        IntersectionFstApplier::try_from(vec![predicate])
    }

    #[test]
    fn test_intersection_fst_applier_with_ranges() {
        let test_fst = FstMap::from_iter([("aa", 1), ("bb", 2), ("cc", 3)]).unwrap();

        // (lower bound, upper bound, expected values)
        let cases = vec![
            // inclusive lower only
            (bound(b"bb", true), None, vec![2, 3]),
            // exclusive lower only
            (bound(b"bb", false), None, vec![3]),
            // inclusive upper only
            (None, bound(b"bb", true), vec![1, 2]),
            // exclusive upper only
            (None, bound(b"bb", false), vec![1]),
            // both bounds inclusive
            (bound(b"aa", true), bound(b"cc", true), vec![1, 2, 3]),
            // both bounds exclusive
            (bound(b"aa", false), bound(b"cc", false), vec![2]),
        ];

        for (lower, upper, expected) in cases {
            let applier = create_applier_from_range(Range { lower, upper }).unwrap();
            let results = applier.apply(&test_fst);
            assert_eq!(results, expected);
        }
    }

    #[test]
    fn test_intersection_fst_applier_with_valid_pattern() {
        let test_fst = FstMap::from_iter([("123", 1), ("abc", 2)]).unwrap();

        // (pattern, expected values)
        let cases = vec![
            ("1", vec![1]),
            ("2", vec![1]),
            ("3", vec![1]),
            ("^1", vec![1]),
            ("^2", vec![]),
            ("^3", vec![]),
            ("^1.*", vec![1]),
            ("^.*2", vec![1]),
            ("^.*3", vec![1]),
            ("1$", vec![]),
            ("2$", vec![]),
            ("3$", vec![1]),
            ("1.*$", vec![1]),
            ("2.*$", vec![1]),
            ("3.*$", vec![1]),
            ("^1..$", vec![1]),
            ("^.2.$", vec![1]),
            ("^..3$", vec![1]),
            ("^[0-9]", vec![1]),
            ("^[0-9]+$", vec![1]),
            ("^[0-9][0-9]$", vec![]),
            ("^[0-9][0-9][0-9]$", vec![1]),
            ("^123$", vec![1]),
            ("a", vec![2]),
            ("b", vec![2]),
            ("c", vec![2]),
            ("^a", vec![2]),
            ("^b", vec![]),
            ("^c", vec![]),
            ("^a.*", vec![2]),
            ("^.*b", vec![2]),
            ("^.*c", vec![2]),
            ("a$", vec![]),
            ("b$", vec![]),
            ("c$", vec![2]),
            ("a.*$", vec![2]),
            ("b.*$", vec![2]),
            ("c.*$", vec![2]),
            ("^.[a-z]", vec![2]),
            ("^abc$", vec![2]),
            ("^ab$", vec![]),
            ("abc$", vec![2]),
            ("^a.c$", vec![2]),
            ("^..c$", vec![2]),
            ("ab", vec![2]),
            (".*", vec![1, 2]),
            ("", vec![1, 2]),
            ("^$", vec![]),
            ("1|a", vec![1, 2]),
            ("^123$|^abc$", vec![1, 2]),
            ("^123$|d", vec![1]),
        ];

        for (pattern, expected) in cases {
            let results = create_applier_from_pattern(pattern)
                .unwrap()
                .apply(&test_fst);
            assert_eq!(results, expected);
        }
    }

    #[test]
    fn test_intersection_fst_applier_with_composite_predicates() {
        let test_fst = FstMap::from_iter([("aa", 1), ("bb", 2), ("cc", 3)]).unwrap();

        // Range ["aa" or ("aa", "cc"] combined with the regex `a.?`; only the
        // inclusiveness of the lower bound varies between the two scenarios.
        let apply_with_lower = |lower_inclusive: bool| {
            let predicates = vec![
                Predicate::Range(RangePredicate {
                    range: Range {
                        lower: bound(b"aa", lower_inclusive),
                        upper: bound(b"cc", true),
                    },
                }),
                Predicate::RegexMatch(RegexMatchPredicate {
                    pattern: "a.?".to_string(),
                }),
            ];
            IntersectionFstApplier::try_from(predicates)
                .unwrap()
                .apply(&test_fst)
        };

        assert_eq!(apply_with_lower(true), vec![1]);
        assert!(apply_with_lower(false).is_empty());
    }

    #[test]
    fn test_intersection_fst_applier_with_invalid_pattern() {
        assert!(matches!(
            create_applier_from_pattern("a("),
            Err(Error::ParseDFA { .. })
        ));
    }

    #[test]
    fn test_intersection_fst_applier_with_empty_predicates() {
        assert!(matches!(
            IntersectionFstApplier::try_from(vec![]),
            Err(Error::EmptyPredicates { .. })
        ));
    }

    #[test]
    fn test_intersection_fst_applier_with_in_list_predicate() {
        let in_list = Predicate::InList(InListPredicate {
            list: HashSet::from_iter([b"one".to_vec(), b"two".to_vec()]),
        });
        let result = IntersectionFstApplier::try_from(vec![in_list]);
        assert!(matches!(
            result,
            Err(Error::IntersectionApplierWithInList { .. })
        ));
    }

    #[test]
    fn test_intersection_fst_applier_memory_usage() {
        // Nothing allocated: no ranges, no DFAs.
        let empty = IntersectionFstApplier::new(vec![], vec![]);
        assert_eq!(empty.memory_usage(), 0);

        let dfa = DFA::new("^abc$").unwrap();
        assert_eq!(dfa.memory_usage(), 320);

        let range = Range {
            lower: bound(b"aa", true),
            upper: bound(b"cc", true),
        };
        let applier = IntersectionFstApplier::new(vec![range], vec![dfa]);

        // One range slot, 2 + 2 bytes of bound values, one DFA slot, and the
        // DFA's own reported usage.
        let expected = size_of::<Range>() + 4 + size_of::<DFA<Vec<u32>>>() + 320;
        assert_eq!(applier.memory_usage(), expected);
    }
}