common_function/scalars/
matches_term.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::sync::Arc;
17
18use datafusion_common::arrow::array::{Array, AsArray, BooleanArray, BooleanBuilder};
19use datafusion_common::arrow::compute;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_common::{DataFusionError, ScalarValue};
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23use memchr::memmem;
24
25use crate::function::Function;
26use crate::function_registry::FunctionRegistry;
27
28/// Exact term/phrase matching function for text columns.
29///
30/// This function checks if a text column contains exact term/phrase matches
31/// with non-alphanumeric boundaries. Designed for:
32/// - Whole-word matching (e.g. "cat" in "cat!" but not in "category")
33/// - Phrase matching (e.g. "hello world" in "note:hello world!")
34///
35/// # Signature
36/// `matches_term(text: String, term: String) -> Boolean`
37///
38/// # Arguments
39/// * `text` - String column to search
40/// * `term` - Search term/phrase
41///
42/// # Returns
43/// BooleanVector where each element indicates if the corresponding text
44/// contains an exact match of the term, following these rules:
45/// 1. Exact substring match found (case-sensitive)
46/// 2. Match boundaries are either:
47///    - Start/end of text
48///    - Any non-alphanumeric character (including spaces, hyphens, punctuation, etc.)
49///
50/// # Examples
51/// ```
52/// -- SQL examples --
53/// -- Match phrase with space --
54/// SELECT matches_term(column, 'hello world') FROM table;
55/// -- Text: "warning:hello world!" => true
56/// -- Text: "hello-world"          => false (hyphen instead of space)
57/// -- Text: "hello world2023"      => false (ending with numbers)
58///
59/// -- Match multiple words with boundaries --
60/// SELECT matches_term(column, 'critical error') FROM logs;
61/// -- Match in: "ERROR:critical error!"
62/// -- No match: "critical_errors"
63///
64/// -- Empty string handling --
65/// SELECT matches_term(column, '') FROM table;
66/// -- Text: "" => true
67/// -- Text: "any" => false
68///
69/// -- Case sensitivity --
70/// SELECT matches_term(column, 'Cat') FROM table;
71/// -- Text: "Cat" => true
72/// -- Text: "cat" => false
73/// ```
74pub struct MatchesTermFunction {
75    signature: Signature,
76}
77
78impl MatchesTermFunction {
79    pub fn register(registry: &FunctionRegistry) {
80        registry.register_scalar(MatchesTermFunction::default());
81    }
82}
83
84impl Default for MatchesTermFunction {
85    fn default() -> Self {
86        Self {
87            signature: Signature::string(2, Volatility::Immutable),
88        }
89    }
90}
91
92impl fmt::Display for MatchesTermFunction {
93    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
94        write!(f, "MATCHES_TERM")
95    }
96}
97
98impl Function for MatchesTermFunction {
99    fn name(&self) -> &str {
100        "matches_term"
101    }
102
103    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
104        Ok(DataType::Boolean)
105    }
106
107    fn signature(&self) -> &Signature {
108        &self.signature
109    }
110
111    fn invoke_with_args(
112        &self,
113        args: ScalarFunctionArgs,
114    ) -> datafusion_common::Result<ColumnarValue> {
115        let [arg0, arg1] = datafusion_common::utils::take_function_args(self.name(), &args.args)?;
116
117        fn as_str(v: &ScalarValue) -> Option<&str> {
118            match v {
119                ScalarValue::Utf8View(Some(x))
120                | ScalarValue::Utf8(Some(x))
121                | ScalarValue::LargeUtf8(Some(x)) => Some(x.as_str()),
122                _ => None,
123            }
124        }
125
126        if let (ColumnarValue::Scalar(text), ColumnarValue::Scalar(term)) = (arg0, arg1) {
127            let text = as_str(text);
128            let term = as_str(term);
129            let result = match (text, term) {
130                (Some(text), Some(term)) => Some(MatchesTermFinder::new(term).find(text)),
131                _ => None,
132            };
133            return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)));
134        };
135
136        let v = match (arg0, arg1) {
137            (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
138                // Unreachable because we have checked this case above and returned if matched.
139                unreachable!()
140            }
141            (ColumnarValue::Scalar(text), ColumnarValue::Array(terms)) => {
142                let text = as_str(text);
143                if let Some(text) = text {
144                    let terms = compute::cast(terms, &DataType::Utf8View)?;
145                    let terms = terms.as_string_view();
146
147                    let mut builder = BooleanBuilder::with_capacity(terms.len());
148                    terms.iter().for_each(|term| {
149                        builder.append_option(term.map(|x| MatchesTermFinder::new(x).find(text)))
150                    });
151                    ColumnarValue::Array(Arc::new(builder.finish()))
152                } else {
153                    ColumnarValue::Array(Arc::new(BooleanArray::new_null(terms.len())))
154                }
155            }
156            (ColumnarValue::Array(texts), ColumnarValue::Scalar(term)) => {
157                let term = as_str(term);
158                if let Some(term) = term {
159                    let finder = MatchesTermFinder::new(term);
160
161                    let texts = compute::cast(texts, &DataType::Utf8View)?;
162                    let texts = texts.as_string_view();
163
164                    let mut builder = BooleanBuilder::with_capacity(texts.len());
165                    texts
166                        .iter()
167                        .for_each(|text| builder.append_option(text.map(|x| finder.find(x))));
168                    ColumnarValue::Array(Arc::new(builder.finish()))
169                } else {
170                    ColumnarValue::Array(Arc::new(BooleanArray::new_null(texts.len())))
171                }
172            }
173            (ColumnarValue::Array(texts), ColumnarValue::Array(terms)) => {
174                let terms = compute::cast(terms, &DataType::Utf8View)?;
175                let terms = terms.as_string_view();
176                let texts = compute::cast(texts, &DataType::Utf8View)?;
177                let texts = texts.as_string_view();
178
179                let len = texts.len();
180                if terms.len() != len {
181                    return Err(DataFusionError::Internal(format!(
182                        "input arrays have different lengths: {len}, {}",
183                        terms.len()
184                    )));
185                }
186
187                let mut builder = BooleanBuilder::with_capacity(len);
188                for (text, term) in texts.iter().zip(terms.iter()) {
189                    let result = match (text, term) {
190                        (Some(text), Some(term)) => Some(MatchesTermFinder::new(term).find(text)),
191                        _ => None,
192                    };
193                    builder.append_option(result);
194                }
195                ColumnarValue::Array(Arc::new(builder.finish()))
196            }
197        };
198        Ok(v)
199    }
200}
201
202/// A compiled finder for `matches_term` function that holds the compiled term
203/// and its metadata for efficient matching.
204///
205/// A term is considered matched when:
206/// 1. The exact sequence appears in the text
207/// 2. It is either:
208///    - At the start/end of text with adjacent non-alphanumeric character
209///    - Surrounded by non-alphanumeric characters
210///
211/// # Examples
212/// ```
213/// let finder = MatchesTermFinder::new("cat");
214/// assert!(finder.find("cat!"));      // Term at end with punctuation
215/// assert!(finder.find("dog,cat"));   // Term preceded by comma
216/// assert!(!finder.find("category")); // Partial match rejected
217///
218/// let finder = MatchesTermFinder::new("world");
219/// assert!(finder.find("hello-world")); // Hyphen boundary
220/// ```
221#[derive(Clone, Debug)]
222pub struct MatchesTermFinder {
223    finder: memmem::Finder<'static>,
224    term: String,
225    starts_with_non_alnum: bool,
226    ends_with_non_alnum: bool,
227}
228
229impl MatchesTermFinder {
230    /// Create a new `MatchesTermFinder` for the given term.
231    pub fn new(term: &str) -> Self {
232        let starts_with_non_alnum = term.chars().next().is_some_and(|c| !c.is_alphanumeric());
233        let ends_with_non_alnum = term.chars().last().is_some_and(|c| !c.is_alphanumeric());
234
235        Self {
236            finder: memmem::Finder::new(term).into_owned(),
237            term: term.to_string(),
238            starts_with_non_alnum,
239            ends_with_non_alnum,
240        }
241    }
242
243    /// Find the term in the text.
244    pub fn find(&self, text: &str) -> bool {
245        if self.term.is_empty() {
246            return text.is_empty();
247        }
248
249        if text.len() < self.term.len() {
250            return false;
251        }
252
253        let mut pos = 0;
254        while let Some(found_pos) = self.finder.find(&text.as_bytes()[pos..]) {
255            let actual_pos = pos + found_pos;
256
257            let prev_ok = self.starts_with_non_alnum
258                || text[..actual_pos]
259                    .chars()
260                    .last()
261                    .map(|c| !c.is_alphanumeric())
262                    .unwrap_or(true);
263
264            if prev_ok {
265                let next_pos = actual_pos + self.finder.needle().len();
266                let next_ok = self.ends_with_non_alnum
267                    || text[next_pos..]
268                        .chars()
269                        .next()
270                        .map(|c| !c.is_alphanumeric())
271                        .unwrap_or(true);
272
273                if next_ok {
274                    return true;
275                }
276            }
277
278            if let Some(next_char) = text[actual_pos..].chars().next() {
279                pos = actual_pos + next_char.len_utf8();
280            } else {
281                break;
282            }
283        }
284
285        false
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Mirrors the SQL examples documented on `MatchesTermFunction`.
    #[test]
    fn matches_term_example() {
        let finder = MatchesTermFinder::new("hello world");
        assert!(finder.find("warning:hello world!"));
        assert!(!finder.find("hello-world"));
        assert!(!finder.find("hello world2023"));

        let finder = MatchesTermFinder::new("critical error");
        assert!(finder.find("ERROR:critical error!"));
        assert!(!finder.find("critical_errors"));

        let finder = MatchesTermFinder::new("");
        assert!(finder.find(""));
        assert!(!finder.find("any"));

        let finder = MatchesTermFinder::new("Cat");
        assert!(finder.find("Cat"));
        assert!(!finder.find("cat"));
    }

    #[test]
    fn matches_term_with_punctuation() {
        assert!(MatchesTermFinder::new("cat").find("cat!"));
        assert!(MatchesTermFinder::new("dog").find("!dog"));
    }

    #[test]
    fn matches_phrase_with_boundaries() {
        assert!(MatchesTermFinder::new("hello-world").find("hello-world"));
        assert!(MatchesTermFinder::new("'foo bar'").find("test: 'foo bar'"));
    }

    #[test]
    fn matches_at_text_boundaries() {
        assert!(MatchesTermFinder::new("start").find("start..."));
        assert!(MatchesTermFinder::new("end").find("...end"));
    }

    // Negative cases
    #[test]
    fn rejects_partial_matches() {
        assert!(!MatchesTermFinder::new("cat").find("category"));
        assert!(!MatchesTermFinder::new("boot").find("rebooted"));
    }

    #[test]
    fn rejects_missing_term() {
        assert!(!MatchesTermFinder::new("foo").find("hello world"));
    }

    // Edge cases
    #[test]
    fn handles_empty_inputs() {
        assert!(!MatchesTermFinder::new("test").find(""));
        assert!(!MatchesTermFinder::new("").find("text"));
    }

    #[test]
    fn different_unicode_boundaries() {
        assert!(MatchesTermFinder::new("café").find("café>"));
        assert!(!MatchesTermFinder::new("café").find("口café>"));
        assert!(!MatchesTermFinder::new("café").find("café口"));
        assert!(!MatchesTermFinder::new("café").find("cafémore"));
        assert!(MatchesTermFinder::new("русский").find("русский!"));
        // A following Cyrillic letter is alphanumeric, so this is only a partial match.
        assert!(!MatchesTermFinder::new("русский").find("русскийязык"));
    }

    #[test]
    fn case_sensitive_matching() {
        assert!(!MatchesTermFinder::new("cat").find("Cat"));
        assert!(MatchesTermFinder::new("CaT").find("CaT"));
    }

    #[test]
    fn numbers_in_term() {
        assert!(MatchesTermFinder::new("v1.0").find("v1.0!"));
        assert!(!MatchesTermFinder::new("v1.0").find("v1.0a"));
    }

    #[test]
    fn adjacent_alphanumeric_fails() {
        assert!(!MatchesTermFinder::new("cat").find("cat5"));
        assert!(!MatchesTermFinder::new("dog").find("dogcat"));
    }

    #[test]
    fn empty_term_text() {
        assert!(!MatchesTermFinder::new("").find("text"));
        assert!(MatchesTermFinder::new("").find(""));
        assert!(!MatchesTermFinder::new("text").find(""));
    }

    /// A term that starts or ends with a non-alphanumeric character needs no
    /// delimiter on that side.
    #[test]
    fn leading_non_alphanumeric() {
        assert!(MatchesTermFinder::new("/cat").find("dog/cat"));
        assert!(MatchesTermFinder::new("dog/").find("dog/cat"));
        assert!(MatchesTermFinder::new("dog/cat").find("dog/cat"));
    }

    /// A rejected occurrence must not stop the search for later ones.
    #[test]
    fn continues_searching_after_boundary_mismatch() {
        assert!(!MatchesTermFinder::new("log").find("bloglog!"));
        assert!(MatchesTermFinder::new("log").find("bloglog log"));
        assert!(MatchesTermFinder::new("log").find("alogblog_log!"));

        assert!(MatchesTermFinder::new("error").find("errorlog_error_case"));
        assert!(MatchesTermFinder::new("test").find("atestbtestc_test_end"));
        assert!(MatchesTermFinder::new("data").find("database_data_store"));
        assert!(!MatchesTermFinder::new("data").find("database_datastore"));
        assert!(MatchesTermFinder::new("log.txt").find("catalog.txt_log.txt!"));
        assert!(!MatchesTermFinder::new("log.txt").find("catalog.txtlog.txt!"));
        assert!(MatchesTermFinder::new("data-set").find("bigdata-set_data-set!"));

        assert!(MatchesTermFinder::new("中文").find("这是中文测试,中文!"));
        assert!(MatchesTermFinder::new("error").find("错误errorerror日志_error!"));
    }
}