Skip to main content

common_function/scalars/
matches_term.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::sync::Arc;
17
18use datafusion_common::arrow::array::{Array, AsArray, BooleanArray, BooleanBuilder};
19use datafusion_common::arrow::compute;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_common::{DataFusionError, ScalarValue};
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23use icu_properties::props::Script;
24use icu_properties::{CodePointMapData, CodePointMapDataBorrowed};
25use memchr::memmem;
26
27use crate::function::Function;
28use crate::function_registry::FunctionRegistry;
29
30/// Exact term/phrase matching function for text columns.
31///
32/// This function uses script-aware matching rules:
33/// - ASCII-only terms keep whole-word style boundary matching, like Whole-word matching (e.g. "cat" in "cat!" but not in "category")
34/// - Phrase matching (e.g. "hello world" in "note:hello world!")
35/// - Terms containing Han characters match as contiguous substrings
36/// - Mixed-script identifiers and numeric terms remain searchable in Chinese text
37///
38/// # Signature
39/// `matches_term(text: String, term: String) -> Boolean`
40///
41/// # Arguments
42/// * `text` - String column to search
43/// * `term` - Search term/phrase
44///
45/// # Returns
46/// BooleanVector where each element indicates if the corresponding text
47/// contains an exact match of the term, following these rules:
48/// 1. Exact substring match found (case-sensitive)
49/// 2. For ASCII-only terms, adjacent ASCII word characters block the match
50/// 3. For Han-containing terms, contiguous substring match is sufficient
51///
52/// # Examples
53/// ```
54/// -- SQL examples --
55/// -- Match phrase with space --
56/// SELECT matches_term(column, 'hello world') FROM table;
57/// -- Text: "warning:hello world!" => true
58/// -- Text: "hello-world"          => false (hyphen instead of space)
59/// -- Text: "hello world2023"      => false (ending with numbers)
60///
61/// -- Match multiple words with boundaries --
62/// SELECT matches_term(column, 'critical error') FROM logs;
63/// -- Match in: "ERROR:critical error!"
64/// -- No match: "critical_errors"
65/// -- Chinese substring examples --
66/// SELECT matches_term(column, '手机') FROM table;
67/// -- Text: "登录手机号18888888888的动态key" => true
68///
69/// -- Empty string handling --
70/// SELECT matches_term(column, '') FROM table;
71/// -- Text: "" => true
72/// -- Text: "any" => false
73///
74/// -- Case sensitivity --
75/// SELECT matches_term(column, 'Cat') FROM table;
76/// -- Text: "Cat" => true
77/// -- Text: "cat" => false
78/// ```
79pub struct MatchesTermFunction {
80    signature: Signature,
81}
82
83impl MatchesTermFunction {
84    pub fn register(registry: &FunctionRegistry) {
85        registry.register_scalar(MatchesTermFunction::default());
86    }
87}
88
89impl Default for MatchesTermFunction {
90    fn default() -> Self {
91        Self {
92            signature: Signature::string(2, Volatility::Immutable),
93        }
94    }
95}
96
97impl fmt::Display for MatchesTermFunction {
98    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99        write!(f, "MATCHES_TERM")
100    }
101}
102
103impl Function for MatchesTermFunction {
104    fn name(&self) -> &str {
105        "matches_term"
106    }
107
108    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
109        Ok(DataType::Boolean)
110    }
111
112    fn signature(&self) -> &Signature {
113        &self.signature
114    }
115
116    fn invoke_with_args(
117        &self,
118        args: ScalarFunctionArgs,
119    ) -> datafusion_common::Result<ColumnarValue> {
120        let [arg0, arg1] = datafusion_common::utils::take_function_args(self.name(), &args.args)?;
121
122        fn as_str(v: &ScalarValue) -> Option<&str> {
123            match v {
124                ScalarValue::Utf8View(Some(x))
125                | ScalarValue::Utf8(Some(x))
126                | ScalarValue::LargeUtf8(Some(x)) => Some(x.as_str()),
127                _ => None,
128            }
129        }
130
131        if let (ColumnarValue::Scalar(text), ColumnarValue::Scalar(term)) = (arg0, arg1) {
132            let text = as_str(text);
133            let term = as_str(term);
134            let result = match (text, term) {
135                (Some(text), Some(term)) => Some(MatchesTermFinder::new(term).find(text)),
136                _ => None,
137            };
138            return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)));
139        };
140
141        let v = match (arg0, arg1) {
142            (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
143                // Unreachable because we have checked this case above and returned if matched.
144                unreachable!()
145            }
146            (ColumnarValue::Scalar(text), ColumnarValue::Array(terms)) => {
147                let text = as_str(text);
148                if let Some(text) = text {
149                    let terms = compute::cast(terms, &DataType::Utf8View)?;
150                    let terms = terms.as_string_view();
151
152                    let mut builder = BooleanBuilder::with_capacity(terms.len());
153                    terms.iter().for_each(|term| {
154                        builder.append_option(term.map(|x| MatchesTermFinder::new(x).find(text)))
155                    });
156                    ColumnarValue::Array(Arc::new(builder.finish()))
157                } else {
158                    ColumnarValue::Array(Arc::new(BooleanArray::new_null(terms.len())))
159                }
160            }
161            (ColumnarValue::Array(texts), ColumnarValue::Scalar(term)) => {
162                let term = as_str(term);
163                if let Some(term) = term {
164                    let finder = MatchesTermFinder::new(term);
165
166                    let texts = compute::cast(texts, &DataType::Utf8View)?;
167                    let texts = texts.as_string_view();
168
169                    let mut builder = BooleanBuilder::with_capacity(texts.len());
170                    texts
171                        .iter()
172                        .for_each(|text| builder.append_option(text.map(|x| finder.find(x))));
173                    ColumnarValue::Array(Arc::new(builder.finish()))
174                } else {
175                    ColumnarValue::Array(Arc::new(BooleanArray::new_null(texts.len())))
176                }
177            }
178            (ColumnarValue::Array(texts), ColumnarValue::Array(terms)) => {
179                let terms = compute::cast(terms, &DataType::Utf8View)?;
180                let terms = terms.as_string_view();
181                let texts = compute::cast(texts, &DataType::Utf8View)?;
182                let texts = texts.as_string_view();
183
184                let len = texts.len();
185                if terms.len() != len {
186                    return Err(DataFusionError::Internal(format!(
187                        "input arrays have different lengths: {len}, {}",
188                        terms.len()
189                    )));
190                }
191
192                let mut builder = BooleanBuilder::with_capacity(len);
193                for (text, term) in texts.iter().zip(terms.iter()) {
194                    let result = match (text, term) {
195                        (Some(text), Some(term)) => Some(MatchesTermFinder::new(term).find(text)),
196                        _ => None,
197                    };
198                    builder.append_option(result);
199                }
200                ColumnarValue::Array(Arc::new(builder.finish()))
201            }
202        };
203        Ok(v)
204    }
205}
206
/// A compiled finder for the `matches_term` function. It holds the term, a
/// prebuilt substring searcher for it, and the metadata needed for the
/// boundary checks, so one finder can be reused across many texts.
///
/// A term is considered matched when the exact sequence appears in the text
/// and the occurrence passes the boundary rule for the term's kind:
/// 1. ASCII-only terms must not be adjacent to ASCII letters or digits
/// 2. Terms containing non-ASCII, non-Han letters or digits must not be
///    adjacent to any letter or digit (ASCII or not) or to a Han character
/// 3. Han-containing terms match as contiguous substrings
///
/// A side of the term that starts/ends with a non-alphanumeric, non-Han
/// character is exempt from the boundary check on that side.
///
/// # Examples
/// ```ignore
/// let finder = MatchesTermFinder::new("cat");
/// assert!(finder.find("cat!"));      // Term at end with punctuation
/// assert!(finder.find("dog,cat"));   // Term preceded by comma
/// assert!(!finder.find("category")); // Partial match rejected
///
/// let finder = MatchesTermFinder::new("手机");
/// assert!(finder.find("登录手机号18888888888的动态key"));
/// ```
#[derive(Clone, Debug)]
pub struct MatchesTermFinder {
    /// Substring searcher built from the term and converted to an owned
    /// (`'static`) finder.
    finder: memmem::Finder<'static>,
    /// The search term exactly as passed to `new`.
    term: String,
    /// Kind of the term as computed by `classify_term`; selects the boundary rule.
    term_kind: TermKind,
    /// True when the first char of the term is classified as `CharClass::Other`;
    /// the left boundary is then not checked.
    starts_with_other: bool,
    /// True when the last char of the term is classified as `CharClass::Other`;
    /// the right boundary is then not checked.
    ends_with_other: bool,
}
233
/// Script-aware class of a single character, as assigned by `classify_char`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CharClass {
    /// ASCII letter or digit.
    AsciiWord,
    /// Character whose Unicode script is Han.
    Han,
    /// Alphanumeric character that is neither ASCII nor Han.
    UnicodeWord,
    /// Everything else: punctuation, whitespace, symbols (including `_`).
    Other,
}
241
/// Kind of a whole search term, as computed by `classify_term`; it selects
/// which boundary rule `boundary_ok` applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TermKind {
    /// No Han and no non-ASCII alphanumeric characters in the term.
    AsciiLike,
    /// At least one non-ASCII, non-Han alphanumeric character and no Han.
    UnicodeWord,
    /// At least one Han character.
    HanContaining,
}
248
249fn classify_char(c: char) -> CharClass {
250    if c.is_ascii_alphanumeric() {
251        CharClass::AsciiWord
252    } else if is_han(c) {
253        CharClass::Han
254    } else if c.is_alphanumeric() {
255        CharClass::UnicodeWord
256    } else {
257        CharClass::Other
258    }
259}
260
261static HAN_SCRIPT_DATA: CodePointMapDataBorrowed<'static, Script> =
262    CodePointMapData::<Script>::new();
263
264fn is_han(c: char) -> bool {
265    HAN_SCRIPT_DATA.get(c) == Script::Han
266}
267
268fn classify_term(term: &str) -> TermKind {
269    let mut has_han = false;
270    let mut has_unicode_word = false;
271    for c in term.chars() {
272        match classify_char(c) {
273            CharClass::AsciiWord => {}
274            CharClass::Han => has_han = true,
275            CharClass::UnicodeWord => has_unicode_word = true,
276            CharClass::Other => {}
277        }
278    }
279
280    if has_han {
281        TermKind::HanContaining
282    } else if has_unicode_word {
283        TermKind::UnicodeWord
284    } else {
285        TermKind::AsciiLike
286    }
287}
288
289fn boundary_ok(term_kind: TermKind, neighbor: Option<char>, term_has_other_boundary: bool) -> bool {
290    if term_has_other_boundary {
291        return true;
292    }
293
294    match term_kind {
295        TermKind::AsciiLike => !matches!(neighbor.map(classify_char), Some(CharClass::AsciiWord)),
296        TermKind::UnicodeWord => !matches!(
297            neighbor.map(classify_char),
298            Some(CharClass::AsciiWord | CharClass::UnicodeWord | CharClass::Han)
299        ),
300        TermKind::HanContaining => true,
301    }
302}
303
304impl MatchesTermFinder {
305    /// Create a new `MatchesTermFinder` for the given term.
306    pub fn new(term: &str) -> Self {
307        let starts_with_other = term
308            .chars()
309            .next()
310            .is_some_and(|c| classify_char(c) == CharClass::Other);
311        let ends_with_other = term
312            .chars()
313            .last()
314            .is_some_and(|c| classify_char(c) == CharClass::Other);
315        Self {
316            finder: memmem::Finder::new(term).into_owned(),
317            term: term.to_string(),
318            term_kind: classify_term(term),
319            starts_with_other,
320            ends_with_other,
321        }
322    }
323
324    /// Find the term in the text.
325    pub fn find(&self, text: &str) -> bool {
326        if self.term.is_empty() {
327            return text.is_empty();
328        }
329
330        if text.len() < self.term.len() {
331            return false;
332        }
333
334        let mut pos = 0;
335        while let Some(found_pos) = self.finder.find(&text.as_bytes()[pos..]) {
336            let actual_pos = pos + found_pos;
337
338            let prev = text[..actual_pos].chars().last();
339            let prev_ok = self.starts_with_other || boundary_ok(self.term_kind, prev, false);
340
341            if prev_ok {
342                if self.term_kind == TermKind::HanContaining {
343                    return true;
344                }
345
346                let next_pos = actual_pos + self.finder.needle().len();
347                let next = text[next_pos..].chars().next();
348                let next_ok = self.ends_with_other || boundary_ok(self.term_kind, next, false);
349
350                if next_ok {
351                    return true;
352                }
353            }
354
355            if let Some(next_char) = text[actual_pos..].chars().next() {
356                pos = actual_pos + next_char.len_utf8();
357            } else {
358                break;
359            }
360        }
361
362        false
363    }
364}
365
/// Unit tests for `MatchesTermFinder` and the script helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Mirrors the SQL examples in the `MatchesTermFunction` documentation.
    #[test]
    fn matches_term_example() {
        let finder = MatchesTermFinder::new("hello world");
        assert!(finder.find("warning:hello world!"));
        assert!(!finder.find("hello-world"));
        assert!(!finder.find("hello world2023"));

        let finder = MatchesTermFinder::new("critical error");
        assert!(finder.find("ERROR:critical error!"));
        assert!(!finder.find("critical_errors"));

        let finder = MatchesTermFinder::new("");
        assert!(finder.find(""));
        assert!(!finder.find("any"));

        let finder = MatchesTermFinder::new("Cat");
        assert!(finder.find("Cat"));
        assert!(!finder.find("cat"));
    }

    #[test]
    fn matches_term_with_punctuation() {
        assert!(MatchesTermFinder::new("cat").find("cat!"));
        assert!(MatchesTermFinder::new("dog").find("!dog"));
    }

    #[test]
    fn matches_phrase_with_boundaries() {
        assert!(MatchesTermFinder::new("hello-world").find("hello-world"));
        assert!(MatchesTermFinder::new("'foo bar'").find("test: 'foo bar'"));
    }

    #[test]
    fn matches_at_text_boundaries() {
        assert!(MatchesTermFinder::new("start").find("start..."));
        assert!(MatchesTermFinder::new("end").find("...end"));
    }

    // Negative cases
    #[test]
    fn rejects_partial_matches() {
        assert!(!MatchesTermFinder::new("cat").find("category"));
        assert!(!MatchesTermFinder::new("boot").find("rebooted"));
    }

    #[test]
    fn rejects_missing_term() {
        assert!(!MatchesTermFinder::new("foo").find("hello world"));
    }

    // Edge cases
    #[test]
    fn handles_empty_inputs() {
        assert!(!MatchesTermFinder::new("test").find(""));
        assert!(!MatchesTermFinder::new("").find("text"));
    }

    #[test]
    fn different_unicode_boundaries() {
        assert!(MatchesTermFinder::new("café").find("café>"));
        assert!(!MatchesTermFinder::new("café").find("口café>"));
        assert!(!MatchesTermFinder::new("café").find("café口"));
        assert!(!MatchesTermFinder::new("café").find("cafémore"));
        assert!(MatchesTermFinder::new("русский").find("русский!"));
        // A non-ASCII letter directly after a Unicode-word term blocks the match.
        assert!(!MatchesTermFinder::new("русский").find("русскийязык"));
    }

    #[test]
    fn case_sensitive_matching() {
        assert!(!MatchesTermFinder::new("cat").find("Cat"));
        assert!(MatchesTermFinder::new("CaT").find("CaT"));
    }

    #[test]
    fn numbers_in_term() {
        assert!(MatchesTermFinder::new("v1.0").find("v1.0!"));
        assert!(!MatchesTermFinder::new("v1.0").find("v1.0a"));
    }

    #[test]
    fn mixed_script_terms_match_inside_chinese_context() {
        let text = "登录手机号18888888888的动态key";
        assert!(MatchesTermFinder::new("手机号").find(text));
        assert!(MatchesTermFinder::new("18888888888").find(text));
        assert!(MatchesTermFinder::new("手机").find(text));
        assert!(MatchesTermFinder::new("机号").find(text));
        assert!(MatchesTermFinder::new("机号1888").find(text));
        assert!(MatchesTermFinder::new("农业").find("中国农业银行"));
        assert!(MatchesTermFinder::new("error").find("错误error日志"));
    }

    #[test]
    fn underscore_still_counts_as_boundary_for_ascii_terms() {
        assert!(MatchesTermFinder::new("world").find("hello_world"));
        assert!(MatchesTermFinder::new("id").find("trace_id=abc"));
        assert!(!MatchesTermFinder::new("error").find("criticalerrors"));
    }

    #[test]
    fn adjacent_alphanumeric_fails() {
        assert!(!MatchesTermFinder::new("cat").find("cat5"));
        assert!(!MatchesTermFinder::new("dog").find("dogcat"));
    }

    #[test]
    fn empty_term_text() {
        assert!(!MatchesTermFinder::new("").find("text"));
        assert!(MatchesTermFinder::new("").find(""));
        assert!(!MatchesTermFinder::new("text").find(""));
    }

    #[test]
    fn leading_non_alphanumeric() {
        assert!(MatchesTermFinder::new("/cat").find("dog/cat"));
        assert!(MatchesTermFinder::new("dog/").find("dog/cat"));
        assert!(MatchesTermFinder::new("dog/cat").find("dog/cat"));
    }

    // A rejected occurrence must not stop the search for a later valid one.
    #[test]
    fn continues_searching_after_boundary_mismatch() {
        assert!(!MatchesTermFinder::new("log").find("bloglog!"));
        assert!(MatchesTermFinder::new("log").find("bloglog log"));
        assert!(MatchesTermFinder::new("log").find("alogblog_log!"));

        assert!(MatchesTermFinder::new("error").find("errorlog_error_case"));
        assert!(MatchesTermFinder::new("test").find("atestbtestc_test_end"));
        assert!(MatchesTermFinder::new("data").find("database_data_store"));
        assert!(!MatchesTermFinder::new("data").find("database_datastore"));
        assert!(MatchesTermFinder::new("log.txt").find("catalog.txt_log.txt!"));
        assert!(!MatchesTermFinder::new("log.txt").find("catalog.txtlog.txt!"));
        assert!(MatchesTermFinder::new("data-set").find("bigdata-set_data-set!"));

        assert!(MatchesTermFinder::new("中文").find("这是中文测试,中文!"));
        assert!(MatchesTermFinder::new("error").find("错误errorerror日志_error!"));
    }

    #[test]
    fn han_terms_match_as_contiguous_substrings() {
        assert!(MatchesTermFinder::new("行账号").find("中国农业银行账号"));
        assert!(MatchesTermFinder::new("登录").find("登录手机号18888888888的动态key"));
    }

    // Han detection is script-based: kana and Hangul are not Han.
    #[test]
    fn han_detection_uses_script_not_all_cjk() {
        assert!(is_han('汉'));
        assert!(is_han('\u{30000}'));
        assert!(!is_han('あ'));
        assert!(!is_han('한'));
    }
}