index/inverted_index/search/fst_apply/
intersection_apply.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::mem::size_of;
16
17use fst::map::OpBuilder;
18use fst::{IntoStreamer, Streamer};
19use regex_automata::dfa::dense::DFA;
20use regex_automata::dfa::Automaton;
21use regex_automata::util::primitives::StateID;
22use regex_automata::util::start::Config;
23use regex_automata::Anchored;
24use snafu::{ensure, ResultExt};
25
26use crate::inverted_index::error::{
27    EmptyPredicatesSnafu, IntersectionApplierWithInListSnafu, ParseDFASnafu, Result,
28};
29use crate::inverted_index::search::fst_apply::FstApplier;
30use crate::inverted_index::search::predicate::{Predicate, Range};
31use crate::inverted_index::FstMap;
32
/// `IntersectionFstApplier` applies intersection operations on an FstMap using specified ranges and regex patterns.
///
/// Each range and each compiled pattern contributes one stream over the map; a key is
/// reported only when it is present in every one of those streams.
pub struct IntersectionFstApplier {
    /// A list of `Range` which define inclusive or exclusive ranges for keys to be queried in the FstMap.
    /// Either side of a range may be absent, which leaves that side unbounded.
    ranges: Vec<Range>,

    /// A list of `Dfa` compiled from regular expression patterns, each wrapped in
    /// `DfaFstAutomaton` so that it can drive a search over the FstMap.
    dfas: Vec<DfaFstAutomaton>,
}
41
/// Newtype around a dense regex DFA; it exists so the DFA can implement `fst::Automaton`
/// and be passed to an FST search.
#[derive(Debug)]
struct DfaFstAutomaton(DFA<Vec<u32>>);
44
45impl fst::Automaton for DfaFstAutomaton {
46    type State = StateID;
47
48    #[inline]
49    fn start(&self) -> Self::State {
50        let config = Config::new().anchored(Anchored::No);
51        self.0.start_state(&config).unwrap()
52    }
53
54    #[inline]
55    fn is_match(&self, state: &Self::State) -> bool {
56        self.0.is_match_state(*state)
57    }
58
59    #[inline]
60    fn can_match(&self, state: &Self::State) -> bool {
61        !self.0.is_dead_state(*state)
62    }
63
64    #[inline]
65    fn accept_eof(&self, state: &StateID) -> Option<StateID> {
66        if self.0.is_match_state(*state) {
67            return Some(*state);
68        }
69        Some(self.0.next_eoi_state(*state))
70    }
71
72    #[inline]
73    fn accept(&self, state: &Self::State, byte: u8) -> Self::State {
74        if self.0.is_match_state(*state) {
75            return *state;
76        }
77        self.0.next_state(*state, byte)
78    }
79}
80
81impl IntersectionFstApplier {
82    fn new(ranges: Vec<Range>, dfas: Vec<DFA<Vec<u32>>>) -> Self {
83        let dfas = dfas.into_iter().map(DfaFstAutomaton).collect();
84        Self { ranges, dfas }
85    }
86}
87
88impl FstApplier for IntersectionFstApplier {
89    fn apply(&self, fst: &FstMap) -> Vec<u64> {
90        let mut op = OpBuilder::new();
91
92        for range in &self.ranges {
93            match (range.lower.as_ref(), range.upper.as_ref()) {
94                (Some(lower), Some(upper)) => match (lower.inclusive, upper.inclusive) {
95                    (true, true) => op.push(fst.range().ge(&lower.value).le(&upper.value)),
96                    (true, false) => op.push(fst.range().ge(&lower.value).lt(&upper.value)),
97                    (false, true) => op.push(fst.range().gt(&lower.value).le(&upper.value)),
98                    (false, false) => op.push(fst.range().gt(&lower.value).lt(&upper.value)),
99                },
100                (Some(lower), None) => match lower.inclusive {
101                    true => op.push(fst.range().ge(&lower.value)),
102                    false => op.push(fst.range().gt(&lower.value)),
103                },
104                (None, Some(upper)) => match upper.inclusive {
105                    true => op.push(fst.range().le(&upper.value)),
106                    false => op.push(fst.range().lt(&upper.value)),
107                },
108                (None, None) => op.push(fst),
109            }
110        }
111
112        for dfa in &self.dfas {
113            op.push(fst.search(dfa));
114        }
115
116        let mut stream = op.intersection().into_stream();
117        let mut values = Vec::new();
118        while let Some((_, v)) = stream.next() {
119            values.push(v[0].value)
120        }
121        values
122    }
123
124    fn memory_usage(&self) -> usize {
125        let mut size = self.ranges.capacity() * size_of::<Range>();
126        for range in &self.ranges {
127            size += range
128                .lower
129                .as_ref()
130                .map_or(0, |bound| bound.value.capacity());
131            size += range
132                .upper
133                .as_ref()
134                .map_or(0, |bound| bound.value.capacity());
135        }
136
137        size += self.dfas.capacity() * size_of::<DFA<Vec<u32>>>();
138        for dfa in &self.dfas {
139            size += dfa.0.memory_usage();
140        }
141        size
142    }
143}
144
145impl IntersectionFstApplier {
146    /// Attempts to create an `IntersectionFstApplier` from a list of `Predicate`.
147    ///
148    /// This function only accepts predicates of the variants `Range` and `RegexMatch`.
149    /// It does not accept `InList` predicates and will return an error if any are found.
150    /// `InList` predicates are handled by `KeysFstApplier`.
151    pub fn try_from(predicates: Vec<Predicate>) -> Result<Self> {
152        ensure!(!predicates.is_empty(), EmptyPredicatesSnafu);
153
154        let mut dfas = Vec::with_capacity(predicates.len());
155        let mut ranges = Vec::with_capacity(predicates.len());
156
157        for predicate in predicates {
158            match predicate {
159                Predicate::Range(range) => ranges.push(range.range),
160                Predicate::RegexMatch(regex) => {
161                    let dfa = DFA::new(&regex.pattern);
162                    let dfa = dfa.map_err(Box::new).context(ParseDFASnafu)?;
163                    dfas.push(dfa);
164                }
165                // Rejection of `InList` predicates is enforced here.
166                Predicate::InList(_) => {
167                    return IntersectionApplierWithInListSnafu.fail();
168                }
169            }
170        }
171
172        Ok(Self::new(ranges, dfas))
173    }
174}
175
/// Trait-based conversion that forwards to the inherent `IntersectionFstApplier::try_from`,
/// so `TryFrom`/`TryInto` callers get the same validation and errors.
impl TryFrom<Vec<Predicate>> for IntersectionFstApplier {
    type Error = crate::inverted_index::error::Error;

    fn try_from(predicates: Vec<Predicate>) -> Result<Self> {
        // `Self::try_from` resolves to the inherent associated function (inherent items take
        // precedence over trait items in path resolution), so this does not call itself.
        Self::try_from(predicates)
    }
}
183
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::inverted_index::error::Error;
    use crate::inverted_index::search::predicate::{
        Bound, InListPredicate, RangePredicate, RegexMatchPredicate,
    };

    /// Builds a present bound on `value` that is either inclusive or exclusive.
    fn bound(value: &[u8], inclusive: bool) -> Option<Bound> {
        Some(Bound {
            value: value.to_vec(),
            inclusive,
        })
    }

    /// Creates an applier from a single range predicate.
    fn create_applier_from_range(range: Range) -> Result<IntersectionFstApplier> {
        IntersectionFstApplier::try_from(vec![Predicate::Range(RangePredicate { range })])
    }

    /// Creates an applier from a single regex predicate.
    fn create_applier_from_pattern(pattern: &str) -> Result<IntersectionFstApplier> {
        IntersectionFstApplier::try_from(vec![Predicate::RegexMatch(RegexMatchPredicate {
            pattern: pattern.to_string(),
        })])
    }

    #[test]
    fn test_intersection_fst_applier_with_ranges() {
        let test_fst = FstMap::from_iter([("aa", 1), ("bb", 2), ("cc", 3)]).unwrap();

        // (case label, lower bound, upper bound, expected values)
        let cases = [
            ("inclusive lower", bound(b"bb", true), None, vec![2, 3]),
            ("exclusive lower", bound(b"bb", false), None, vec![3]),
            ("inclusive upper", None, bound(b"bb", true), vec![1, 2]),
            ("exclusive upper", None, bound(b"bb", false), vec![1]),
            (
                "inclusive bounds",
                bound(b"aa", true),
                bound(b"cc", true),
                vec![1, 2, 3],
            ),
            (
                "exclusive bounds",
                bound(b"aa", false),
                bound(b"cc", false),
                vec![2],
            ),
        ];

        for (label, lower, upper, expected) in cases {
            let applier = create_applier_from_range(Range { lower, upper }).unwrap();
            let results = applier.apply(&test_fst);
            assert_eq!(results, expected, "range case: {}", label);
        }
    }

    #[test]
    fn test_intersection_fst_applier_with_valid_pattern() {
        let test_fst = FstMap::from_iter([("123", 1), ("abc", 2)]).unwrap();

        let cases = vec![
            ("1", vec![1]),
            ("2", vec![1]),
            ("3", vec![1]),
            ("^1", vec![1]),
            ("^2", vec![]),
            ("^3", vec![]),
            ("^1.*", vec![1]),
            ("^.*2", vec![1]),
            ("^.*3", vec![1]),
            ("1$", vec![]),
            ("2$", vec![]),
            ("3$", vec![1]),
            ("1.*$", vec![1]),
            ("2.*$", vec![1]),
            ("3.*$", vec![1]),
            ("^1..$", vec![1]),
            ("^.2.$", vec![1]),
            ("^..3$", vec![1]),
            ("^[0-9]", vec![1]),
            ("^[0-9]+$", vec![1]),
            ("^[0-9][0-9]$", vec![]),
            ("^[0-9][0-9][0-9]$", vec![1]),
            ("^123$", vec![1]),
            ("a", vec![2]),
            ("b", vec![2]),
            ("c", vec![2]),
            ("^a", vec![2]),
            ("^b", vec![]),
            ("^c", vec![]),
            ("^a.*", vec![2]),
            ("^.*b", vec![2]),
            ("^.*c", vec![2]),
            ("a$", vec![]),
            ("b$", vec![]),
            ("c$", vec![2]),
            ("a.*$", vec![2]),
            ("b.*$", vec![2]),
            ("c.*$", vec![2]),
            ("^.[a-z]", vec![2]),
            ("^abc$", vec![2]),
            ("^ab$", vec![]),
            ("abc$", vec![2]),
            ("^a.c$", vec![2]),
            ("^..c$", vec![2]),
            ("ab", vec![2]),
            (".*", vec![1, 2]),
            ("", vec![1, 2]),
            ("^$", vec![]),
            ("1|a", vec![1, 2]),
            ("^123$|^abc$", vec![1, 2]),
            ("^123$|d", vec![1]),
        ];

        for (pattern, expected) in cases {
            let applier = create_applier_from_pattern(pattern).unwrap();
            let results = applier.apply(&test_fst);
            // Name the pattern so a failure points at the offending case.
            assert_eq!(results, expected, "pattern: {:?}", pattern);
        }
    }

    #[test]
    fn test_intersection_fst_applier_with_composite_predicates() {
        let test_fst = FstMap::from_iter([("aa", 1), ("bb", 2), ("cc", 3)]).unwrap();

        // Range from "aa" (inclusive or not) up to and including "cc".
        let range_from_aa = |lower_inclusive: bool| {
            Predicate::Range(RangePredicate {
                range: Range {
                    lower: bound(b"aa", lower_inclusive),
                    upper: bound(b"cc", true),
                },
            })
        };
        let regex = || {
            Predicate::RegexMatch(RegexMatchPredicate {
                pattern: "a.?".to_string(),
            })
        };

        let applier = IntersectionFstApplier::try_from(vec![range_from_aa(true), regex()]).unwrap();
        let results = applier.apply(&test_fst);
        assert_eq!(results, vec![1]);

        // Excluding "aa" from the range leaves nothing for the pattern to match.
        let applier =
            IntersectionFstApplier::try_from(vec![range_from_aa(false), regex()]).unwrap();
        let results = applier.apply(&test_fst);
        assert!(results.is_empty());
    }

    #[test]
    fn test_intersection_fst_applier_with_invalid_pattern() {
        let result = create_applier_from_pattern("a(");
        assert!(matches!(result, Err(Error::ParseDFA { .. })));
    }

    #[test]
    fn test_intersection_fst_applier_with_empty_predicates() {
        let result = IntersectionFstApplier::try_from(vec![]);
        assert!(matches!(result, Err(Error::EmptyPredicates { .. })));
    }

    #[test]
    fn test_intersection_fst_applier_with_in_list_predicate() {
        let result = IntersectionFstApplier::try_from(vec![Predicate::InList(InListPredicate {
            list: HashSet::from_iter([b"one".to_vec(), b"two".to_vec()]),
        })]);
        assert!(matches!(
            result,
            Err(Error::IntersectionApplierWithInList { .. })
        ));
    }

    #[test]
    fn test_intersection_fst_applier_memory_usage() {
        let applier = IntersectionFstApplier::new(vec![], vec![]);

        assert_eq!(applier.memory_usage(), 0);

        let dfa = DFA::new("^abc$").unwrap();
        assert_eq!(dfa.memory_usage(), 320);

        let applier = IntersectionFstApplier::new(
            vec![Range {
                lower: bound(b"aa", true),
                upper: bound(b"cc", true),
            }],
            vec![dfa],
        );
        // One range slot, two 2-byte bound buffers, one DFA slot and the DFA's own usage.
        assert_eq!(
            applier.memory_usage(),
            size_of::<Range>() + 4 + size_of::<DFA<Vec<u32>>>() + 320
        );
    }
}