// common_function/scalars/matches_term.rs
use std::fmt;
16use std::sync::Arc;
17
18use datafusion_common::arrow::array::{Array, AsArray, BooleanArray, BooleanBuilder};
19use datafusion_common::arrow::compute;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_common::{DataFusionError, ScalarValue};
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23use icu_properties::props::Script;
24use icu_properties::{CodePointMapData, CodePointMapDataBorrowed};
25use memchr::memmem;
26
27use crate::function::Function;
28use crate::function_registry::FunctionRegistry;
29
30pub struct MatchesTermFunction {
80 signature: Signature,
81}
82
83impl MatchesTermFunction {
84 pub fn register(registry: &FunctionRegistry) {
85 registry.register_scalar(MatchesTermFunction::default());
86 }
87}
88
89impl Default for MatchesTermFunction {
90 fn default() -> Self {
91 Self {
92 signature: Signature::string(2, Volatility::Immutable),
93 }
94 }
95}
96
97impl fmt::Display for MatchesTermFunction {
98 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99 write!(f, "MATCHES_TERM")
100 }
101}
102
103impl Function for MatchesTermFunction {
104 fn name(&self) -> &str {
105 "matches_term"
106 }
107
108 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
109 Ok(DataType::Boolean)
110 }
111
112 fn signature(&self) -> &Signature {
113 &self.signature
114 }
115
116 fn invoke_with_args(
117 &self,
118 args: ScalarFunctionArgs,
119 ) -> datafusion_common::Result<ColumnarValue> {
120 let [arg0, arg1] = datafusion_common::utils::take_function_args(self.name(), &args.args)?;
121
122 fn as_str(v: &ScalarValue) -> Option<&str> {
123 match v {
124 ScalarValue::Utf8View(Some(x))
125 | ScalarValue::Utf8(Some(x))
126 | ScalarValue::LargeUtf8(Some(x)) => Some(x.as_str()),
127 _ => None,
128 }
129 }
130
131 if let (ColumnarValue::Scalar(text), ColumnarValue::Scalar(term)) = (arg0, arg1) {
132 let text = as_str(text);
133 let term = as_str(term);
134 let result = match (text, term) {
135 (Some(text), Some(term)) => Some(MatchesTermFinder::new(term).find(text)),
136 _ => None,
137 };
138 return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)));
139 };
140
141 let v = match (arg0, arg1) {
142 (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
143 unreachable!()
145 }
146 (ColumnarValue::Scalar(text), ColumnarValue::Array(terms)) => {
147 let text = as_str(text);
148 if let Some(text) = text {
149 let terms = compute::cast(terms, &DataType::Utf8View)?;
150 let terms = terms.as_string_view();
151
152 let mut builder = BooleanBuilder::with_capacity(terms.len());
153 terms.iter().for_each(|term| {
154 builder.append_option(term.map(|x| MatchesTermFinder::new(x).find(text)))
155 });
156 ColumnarValue::Array(Arc::new(builder.finish()))
157 } else {
158 ColumnarValue::Array(Arc::new(BooleanArray::new_null(terms.len())))
159 }
160 }
161 (ColumnarValue::Array(texts), ColumnarValue::Scalar(term)) => {
162 let term = as_str(term);
163 if let Some(term) = term {
164 let finder = MatchesTermFinder::new(term);
165
166 let texts = compute::cast(texts, &DataType::Utf8View)?;
167 let texts = texts.as_string_view();
168
169 let mut builder = BooleanBuilder::with_capacity(texts.len());
170 texts
171 .iter()
172 .for_each(|text| builder.append_option(text.map(|x| finder.find(x))));
173 ColumnarValue::Array(Arc::new(builder.finish()))
174 } else {
175 ColumnarValue::Array(Arc::new(BooleanArray::new_null(texts.len())))
176 }
177 }
178 (ColumnarValue::Array(texts), ColumnarValue::Array(terms)) => {
179 let terms = compute::cast(terms, &DataType::Utf8View)?;
180 let terms = terms.as_string_view();
181 let texts = compute::cast(texts, &DataType::Utf8View)?;
182 let texts = texts.as_string_view();
183
184 let len = texts.len();
185 if terms.len() != len {
186 return Err(DataFusionError::Internal(format!(
187 "input arrays have different lengths: {len}, {}",
188 terms.len()
189 )));
190 }
191
192 let mut builder = BooleanBuilder::with_capacity(len);
193 for (text, term) in texts.iter().zip(terms.iter()) {
194 let result = match (text, term) {
195 (Some(text), Some(term)) => Some(MatchesTermFinder::new(term).find(text)),
196 _ => None,
197 };
198 builder.append_option(result);
199 }
200 ColumnarValue::Array(Arc::new(builder.finish()))
201 }
202 };
203 Ok(v)
204 }
205}
206
207#[derive(Clone, Debug)]
226pub struct MatchesTermFinder {
227 finder: memmem::Finder<'static>,
228 term: String,
229 term_kind: TermKind,
230 starts_with_other: bool,
231 ends_with_other: bool,
232}
233
/// Coarse class of a single character, used to decide where a term may
/// begin and end inside a text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CharClass {
    /// ASCII letter or digit.
    AsciiWord,
    /// Character whose Unicode script is Han.
    Han,
    /// Alphanumeric character that is neither ASCII nor Han.
    UnicodeWord,
    /// Everything else: punctuation, whitespace, symbols, and so on.
    Other,
}
241
/// Classification of a whole search term, derived from the classes of its
/// characters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TermKind {
    /// No Han character and no non-ASCII alphanumeric character.
    AsciiLike,
    /// At least one non-ASCII, non-Han alphanumeric character and no Han
    /// character.
    UnicodeWord,
    /// At least one Han character.
    HanContaining,
}
248
249fn classify_char(c: char) -> CharClass {
250 if c.is_ascii_alphanumeric() {
251 CharClass::AsciiWord
252 } else if is_han(c) {
253 CharClass::Han
254 } else if c.is_alphanumeric() {
255 CharClass::UnicodeWord
256 } else {
257 CharClass::Other
258 }
259}
260
261static HAN_SCRIPT_DATA: CodePointMapDataBorrowed<'static, Script> =
262 CodePointMapData::<Script>::new();
263
264fn is_han(c: char) -> bool {
265 HAN_SCRIPT_DATA.get(c) == Script::Han
266}
267
268fn classify_term(term: &str) -> TermKind {
269 let mut has_han = false;
270 let mut has_unicode_word = false;
271 for c in term.chars() {
272 match classify_char(c) {
273 CharClass::AsciiWord => {}
274 CharClass::Han => has_han = true,
275 CharClass::UnicodeWord => has_unicode_word = true,
276 CharClass::Other => {}
277 }
278 }
279
280 if has_han {
281 TermKind::HanContaining
282 } else if has_unicode_word {
283 TermKind::UnicodeWord
284 } else {
285 TermKind::AsciiLike
286 }
287}
288
289fn boundary_ok(term_kind: TermKind, neighbor: Option<char>, term_has_other_boundary: bool) -> bool {
290 if term_has_other_boundary {
291 return true;
292 }
293
294 match term_kind {
295 TermKind::AsciiLike => !matches!(neighbor.map(classify_char), Some(CharClass::AsciiWord)),
296 TermKind::UnicodeWord => !matches!(
297 neighbor.map(classify_char),
298 Some(CharClass::AsciiWord | CharClass::UnicodeWord | CharClass::Han)
299 ),
300 TermKind::HanContaining => true,
301 }
302}
303
304impl MatchesTermFinder {
305 pub fn new(term: &str) -> Self {
307 let starts_with_other = term
308 .chars()
309 .next()
310 .is_some_and(|c| classify_char(c) == CharClass::Other);
311 let ends_with_other = term
312 .chars()
313 .last()
314 .is_some_and(|c| classify_char(c) == CharClass::Other);
315 Self {
316 finder: memmem::Finder::new(term).into_owned(),
317 term: term.to_string(),
318 term_kind: classify_term(term),
319 starts_with_other,
320 ends_with_other,
321 }
322 }
323
324 pub fn find(&self, text: &str) -> bool {
326 if self.term.is_empty() {
327 return text.is_empty();
328 }
329
330 if text.len() < self.term.len() {
331 return false;
332 }
333
334 let mut pos = 0;
335 while let Some(found_pos) = self.finder.find(&text.as_bytes()[pos..]) {
336 let actual_pos = pos + found_pos;
337
338 let prev = text[..actual_pos].chars().last();
339 let prev_ok = self.starts_with_other || boundary_ok(self.term_kind, prev, false);
340
341 if prev_ok {
342 if self.term_kind == TermKind::HanContaining {
343 return true;
344 }
345
346 let next_pos = actual_pos + self.finder.needle().len();
347 let next = text[next_pos..].chars().next();
348 let next_ok = self.ends_with_other || boundary_ok(self.term_kind, next, false);
349
350 if next_ok {
351 return true;
352 }
353 }
354
355 if let Some(next_char) = text[actual_pos..].chars().next() {
356 pos = actual_pos + next_char.len_utf8();
357 } else {
358 break;
359 }
360 }
361
362 false
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh finder for each `(term, text, expected)` case and
    /// checks its verdict.
    #[track_caller]
    fn check(cases: &[(&str, &str, bool)]) {
        for &(term, text, expected) in cases {
            assert_eq!(
                MatchesTermFinder::new(term).find(text),
                expected,
                "term {term:?} in text {text:?}"
            );
        }
    }

    /// Builds one finder for `term` and reuses it for every
    /// `(text, expected)` case.
    #[track_caller]
    fn check_shared(term: &str, cases: &[(&str, bool)]) {
        let finder = MatchesTermFinder::new(term);
        for &(text, expected) in cases {
            assert_eq!(
                finder.find(text),
                expected,
                "term {term:?} in text {text:?}"
            );
        }
    }

    #[test]
    fn matches_term_example() {
        check_shared(
            "hello world",
            &[
                ("warning:hello world!", true),
                ("hello-world", false),
                ("hello world2023", false),
            ],
        );
        check_shared(
            "critical error",
            &[("ERROR:critical error!", true), ("critical_errors", false)],
        );
        check_shared("", &[("", true), ("any", false)]);
        check_shared("Cat", &[("Cat", true), ("cat", false)]);
    }

    #[test]
    fn matches_term_with_punctuation() {
        check(&[("cat", "cat!", true), ("dog", "!dog", true)]);
    }

    #[test]
    fn matches_phrase_with_boundaries() {
        check(&[
            ("hello-world", "hello-world", true),
            ("'foo bar'", "test: 'foo bar'", true),
        ]);
    }

    #[test]
    fn matches_at_text_boundaries() {
        check(&[("start", "start...", true), ("end", "...end", true)]);
    }

    #[test]
    fn rejects_partial_matches() {
        check(&[("cat", "category", false), ("boot", "rebooted", false)]);
    }

    #[test]
    fn rejects_missing_term() {
        check(&[("foo", "hello world", false)]);
    }

    #[test]
    fn handles_empty_inputs() {
        check(&[("test", "", false), ("", "text", false)]);
    }

    #[test]
    fn different_unicode_boundaries() {
        check(&[
            ("café", "café>", true),
            ("café", "口café>", false),
            ("café", "café口", false),
            ("café", "cafémore", false),
            ("русский", "русский!", true),
            ("русский", "русский!", true),
        ]);
    }

    #[test]
    fn case_sensitive_matching() {
        check(&[("cat", "Cat", false), ("CaT", "CaT", true)]);
    }

    #[test]
    fn numbers_in_term() {
        check(&[("v1.0", "v1.0!", true), ("v1.0", "v1.0a", false)]);
    }

    #[test]
    fn mixed_script_terms_match_inside_chinese_context() {
        let text = "登录手机号18888888888的动态key";
        for term in ["手机号", "18888888888", "手机", "机号", "机号1888"] {
            check(&[(term, text, true)]);
        }
        check(&[
            ("农业", "中国农业银行", true),
            ("error", "错误error日志", true),
        ]);
    }

    #[test]
    fn underscore_still_counts_as_boundary_for_ascii_terms() {
        check(&[
            ("world", "hello_world", true),
            ("id", "trace_id=abc", true),
            ("error", "criticalerrors", false),
        ]);
    }

    #[test]
    fn adjacent_alphanumeric_fails() {
        check(&[("cat", "cat5", false), ("dog", "dogcat", false)]);
    }

    #[test]
    fn empty_term_text() {
        check(&[("", "text", false), ("", "", true), ("text", "", false)]);
    }

    #[test]
    fn leading_non_alphanumeric() {
        check(&[
            ("/cat", "dog/cat", true),
            ("dog/", "dog/cat", true),
            ("dog/cat", "dog/cat", true),
        ]);
    }

    #[test]
    fn continues_searching_after_boundary_mismatch() {
        check(&[
            ("log", "bloglog!", false),
            ("log", "bloglog log", true),
            ("log", "alogblog_log!", true),
        ]);

        check(&[
            ("error", "errorlog_error_case", true),
            ("test", "atestbtestc_test_end", true),
            ("data", "database_data_store", true),
            ("data", "database_datastore", false),
            ("log.txt", "catalog.txt_log.txt!", true),
            ("log.txt", "catalog.txtlog.txt!", false),
            ("data-set", "bigdata-set_data-set!", true),
        ]);

        check(&[
            ("中文", "这是中文测试,中文!", true),
            ("error", "错误errorerror日志_error!", true),
        ]);
    }

    #[test]
    fn han_terms_match_as_contiguous_substrings() {
        check(&[
            ("行账号", "中国农业银行账号", true),
            ("登录", "登录手机号18888888888的动态key", true),
        ]);
    }

    #[test]
    fn han_detection_uses_script_not_all_cjk() {
        for han in ['汉', '\u{30000}'] {
            assert!(is_han(han), "{han:?} should be Han");
        }
        for not_han in ['あ', '한'] {
            assert!(!is_han(not_han), "{not_han:?} should not be Han");
        }
    }
}