common_function/scalars/
matches_term.rs1use std::fmt;
16use std::sync::Arc;
17
18use datafusion_common::arrow::array::{Array, AsArray, BooleanArray, BooleanBuilder};
19use datafusion_common::arrow::compute;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_common::{DataFusionError, ScalarValue};
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23use memchr::memmem;
24
25use crate::function::Function;
26use crate::function_registry::FunctionRegistry;
27
28pub struct MatchesTermFunction {
75 signature: Signature,
76}
77
78impl MatchesTermFunction {
79 pub fn register(registry: &FunctionRegistry) {
80 registry.register_scalar(MatchesTermFunction::default());
81 }
82}
83
84impl Default for MatchesTermFunction {
85 fn default() -> Self {
86 Self {
87 signature: Signature::string(2, Volatility::Immutable),
88 }
89 }
90}
91
92impl fmt::Display for MatchesTermFunction {
93 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
94 write!(f, "MATCHES_TERM")
95 }
96}
97
98impl Function for MatchesTermFunction {
99 fn name(&self) -> &str {
100 "matches_term"
101 }
102
103 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
104 Ok(DataType::Boolean)
105 }
106
107 fn signature(&self) -> &Signature {
108 &self.signature
109 }
110
111 fn invoke_with_args(
112 &self,
113 args: ScalarFunctionArgs,
114 ) -> datafusion_common::Result<ColumnarValue> {
115 let [arg0, arg1] = datafusion_common::utils::take_function_args(self.name(), &args.args)?;
116
117 fn as_str(v: &ScalarValue) -> Option<&str> {
118 match v {
119 ScalarValue::Utf8View(Some(x))
120 | ScalarValue::Utf8(Some(x))
121 | ScalarValue::LargeUtf8(Some(x)) => Some(x.as_str()),
122 _ => None,
123 }
124 }
125
126 if let (ColumnarValue::Scalar(text), ColumnarValue::Scalar(term)) = (arg0, arg1) {
127 let text = as_str(text);
128 let term = as_str(term);
129 let result = match (text, term) {
130 (Some(text), Some(term)) => Some(MatchesTermFinder::new(term).find(text)),
131 _ => None,
132 };
133 return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)));
134 };
135
136 let v = match (arg0, arg1) {
137 (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
138 unreachable!()
140 }
141 (ColumnarValue::Scalar(text), ColumnarValue::Array(terms)) => {
142 let text = as_str(text);
143 if let Some(text) = text {
144 let terms = compute::cast(terms, &DataType::Utf8View)?;
145 let terms = terms.as_string_view();
146
147 let mut builder = BooleanBuilder::with_capacity(terms.len());
148 terms.iter().for_each(|term| {
149 builder.append_option(term.map(|x| MatchesTermFinder::new(x).find(text)))
150 });
151 ColumnarValue::Array(Arc::new(builder.finish()))
152 } else {
153 ColumnarValue::Array(Arc::new(BooleanArray::new_null(terms.len())))
154 }
155 }
156 (ColumnarValue::Array(texts), ColumnarValue::Scalar(term)) => {
157 let term = as_str(term);
158 if let Some(term) = term {
159 let finder = MatchesTermFinder::new(term);
160
161 let texts = compute::cast(texts, &DataType::Utf8View)?;
162 let texts = texts.as_string_view();
163
164 let mut builder = BooleanBuilder::with_capacity(texts.len());
165 texts
166 .iter()
167 .for_each(|text| builder.append_option(text.map(|x| finder.find(x))));
168 ColumnarValue::Array(Arc::new(builder.finish()))
169 } else {
170 ColumnarValue::Array(Arc::new(BooleanArray::new_null(texts.len())))
171 }
172 }
173 (ColumnarValue::Array(texts), ColumnarValue::Array(terms)) => {
174 let terms = compute::cast(terms, &DataType::Utf8View)?;
175 let terms = terms.as_string_view();
176 let texts = compute::cast(texts, &DataType::Utf8View)?;
177 let texts = texts.as_string_view();
178
179 let len = texts.len();
180 if terms.len() != len {
181 return Err(DataFusionError::Internal(format!(
182 "input arrays have different lengths: {len}, {}",
183 terms.len()
184 )));
185 }
186
187 let mut builder = BooleanBuilder::with_capacity(len);
188 for (text, term) in texts.iter().zip(terms.iter()) {
189 let result = match (text, term) {
190 (Some(text), Some(term)) => Some(MatchesTermFinder::new(term).find(text)),
191 _ => None,
192 };
193 builder.append_option(result);
194 }
195 ColumnarValue::Array(Arc::new(builder.finish()))
196 }
197 };
198 Ok(v)
199 }
200}
201
202#[derive(Clone, Debug)]
222pub struct MatchesTermFinder {
223 finder: memmem::Finder<'static>,
224 term: String,
225 starts_with_non_alnum: bool,
226 ends_with_non_alnum: bool,
227}
228
229impl MatchesTermFinder {
230 pub fn new(term: &str) -> Self {
232 let starts_with_non_alnum = term.chars().next().is_some_and(|c| !c.is_alphanumeric());
233 let ends_with_non_alnum = term.chars().last().is_some_and(|c| !c.is_alphanumeric());
234
235 Self {
236 finder: memmem::Finder::new(term).into_owned(),
237 term: term.to_string(),
238 starts_with_non_alnum,
239 ends_with_non_alnum,
240 }
241 }
242
243 pub fn find(&self, text: &str) -> bool {
245 if self.term.is_empty() {
246 return text.is_empty();
247 }
248
249 if text.len() < self.term.len() {
250 return false;
251 }
252
253 let mut pos = 0;
254 while let Some(found_pos) = self.finder.find(&text.as_bytes()[pos..]) {
255 let actual_pos = pos + found_pos;
256
257 let prev_ok = self.starts_with_non_alnum
258 || text[..actual_pos]
259 .chars()
260 .last()
261 .map(|c| !c.is_alphanumeric())
262 .unwrap_or(true);
263
264 if prev_ok {
265 let next_pos = actual_pos + self.finder.needle().len();
266 let next_ok = self.ends_with_non_alnum
267 || text[next_pos..]
268 .chars()
269 .next()
270 .map(|c| !c.is_alphanumeric())
271 .unwrap_or(true);
272
273 if next_ok {
274 return true;
275 }
276 }
277
278 if let Some(next_char) = text[actual_pos..].chars().next() {
279 pos = actual_pos + next_char.len_utf8();
280 } else {
281 break;
282 }
283 }
284
285 false
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn matches_term_example() {
        let finder = MatchesTermFinder::new("hello world");
        assert!(finder.find("warning:hello world!"));
        assert!(!finder.find("hello-world"));
        assert!(!finder.find("hello world2023"));

        let finder = MatchesTermFinder::new("critical error");
        assert!(finder.find("ERROR:critical error!"));
        assert!(!finder.find("critical_errors"));

        let finder = MatchesTermFinder::new("");
        assert!(finder.find(""));
        assert!(!finder.find("any"));

        let finder = MatchesTermFinder::new("Cat");
        assert!(finder.find("Cat"));
        assert!(!finder.find("cat"));
    }

    #[test]
    fn matches_term_with_punctuation() {
        assert!(MatchesTermFinder::new("cat").find("cat!"));
        assert!(MatchesTermFinder::new("dog").find("!dog"));
    }

    #[test]
    fn matches_phrase_with_boundaries() {
        assert!(MatchesTermFinder::new("hello-world").find("hello-world"));
        assert!(MatchesTermFinder::new("'foo bar'").find("test: 'foo bar'"));
    }

    #[test]
    fn matches_at_text_boundaries() {
        assert!(MatchesTermFinder::new("start").find("start..."));
        assert!(MatchesTermFinder::new("end").find("...end"));
    }

    #[test]
    fn rejects_partial_matches() {
        assert!(!MatchesTermFinder::new("cat").find("category"));
        assert!(!MatchesTermFinder::new("boot").find("rebooted"));
    }

    #[test]
    fn rejects_missing_term() {
        assert!(!MatchesTermFinder::new("foo").find("hello world"));
    }

    #[test]
    fn handles_empty_inputs() {
        assert!(!MatchesTermFinder::new("test").find(""));
        assert!(!MatchesTermFinder::new("").find("text"));
    }

    #[test]
    fn different_unicode_boundaries() {
        assert!(MatchesTermFinder::new("café").find("café>"));
        assert!(!MatchesTermFinder::new("café").find("口café>"));
        assert!(!MatchesTermFinder::new("café").find("café口"));
        assert!(!MatchesTermFinder::new("café").find("cafémore"));
        assert!(MatchesTermFinder::new("русский").find("русский!"));
        // Cyrillic letters are alphanumeric, so an adjacent one does not form a boundary.
        assert!(!MatchesTermFinder::new("русский").find("русскийязык"));
    }

    #[test]
    fn case_sensitive_matching() {
        assert!(!MatchesTermFinder::new("cat").find("Cat"));
        assert!(MatchesTermFinder::new("CaT").find("CaT"));
    }

    #[test]
    fn numbers_in_term() {
        assert!(MatchesTermFinder::new("v1.0").find("v1.0!"));
        assert!(!MatchesTermFinder::new("v1.0").find("v1.0a"));
    }

    #[test]
    fn adjacent_alphanumeric_fails() {
        assert!(!MatchesTermFinder::new("cat").find("cat5"));
        assert!(!MatchesTermFinder::new("dog").find("dogcat"));
    }

    #[test]
    fn empty_term_text() {
        assert!(!MatchesTermFinder::new("").find("text"));
        assert!(MatchesTermFinder::new("").find(""));
        assert!(!MatchesTermFinder::new("text").find(""));
    }

    #[test]
    fn leading_non_alphanumeric() {
        assert!(MatchesTermFinder::new("/cat").find("dog/cat"));
        assert!(MatchesTermFinder::new("dog/").find("dog/cat"));
        assert!(MatchesTermFinder::new("dog/cat").find("dog/cat"));
    }

    #[test]
    fn continues_searching_after_boundary_mismatch() {
        assert!(!MatchesTermFinder::new("log").find("bloglog!"));
        assert!(MatchesTermFinder::new("log").find("bloglog log"));
        assert!(MatchesTermFinder::new("log").find("alogblog_log!"));

        assert!(MatchesTermFinder::new("error").find("errorlog_error_case"));
        assert!(MatchesTermFinder::new("test").find("atestbtestc_test_end"));
        assert!(MatchesTermFinder::new("data").find("database_data_store"));
        assert!(!MatchesTermFinder::new("data").find("database_datastore"));
        assert!(MatchesTermFinder::new("log.txt").find("catalog.txt_log.txt!"));
        assert!(!MatchesTermFinder::new("log.txt").find("catalog.txtlog.txt!"));
        assert!(MatchesTermFinder::new("data-set").find("bigdata-set_data-set!"));

        assert!(MatchesTermFinder::new("中文").find("这是中文测试,中文!"));
        assert!(MatchesTermFinder::new("error").find("错误errorerror日志_error!"));
    }

    #[test]
    fn considers_overlapping_occurrences() {
        // The first candidate (preceded by 'x') is rejected. The qualifying occurrence
        // overlaps it, so the search must resume inside the rejected candidate rather than
        // after its end.
        assert!(MatchesTermFinder::new("ab ab").find("xab ab ab"));
    }
}