common_function/scalars/
matches_term.rs1use std::fmt;
16use std::iter::repeat_n;
17use std::sync::Arc;
18
19use common_query::error::{InvalidFuncArgsSnafu, Result};
20use datafusion::arrow::datatypes::DataType;
21use datafusion_expr::{Signature, Volatility};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BooleanVector, BooleanVectorBuilder, MutableVector, VectorRef};
24use memchr::memmem;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28use crate::function_registry::FunctionRegistry;
29
/// Scalar function `matches_term(text, term)`.
///
/// Takes two string arguments and returns a boolean telling whether `term` occurs in `text`
/// as a standalone term, i.e. not directly adjacent to other alphanumeric characters.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatchesTermFunction;
77
78impl MatchesTermFunction {
79 pub fn register(registry: &FunctionRegistry) {
80 registry.register_scalar(MatchesTermFunction);
81 }
82}
83
84impl fmt::Display for MatchesTermFunction {
85 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
86 write!(f, "MATCHES_TERM")
87 }
88}
89
90impl Function for MatchesTermFunction {
91 fn name(&self) -> &str {
92 "matches_term"
93 }
94
95 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
96 Ok(DataType::Boolean)
97 }
98
99 fn signature(&self) -> Signature {
100 Signature::string(2, Volatility::Immutable)
101 }
102
103 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
104 ensure!(
105 columns.len() == 2,
106 InvalidFuncArgsSnafu {
107 err_msg: format!(
108 "The length of the args is not correct, expect exactly 2, have: {}",
109 columns.len()
110 ),
111 }
112 );
113
114 let text_column = &columns[0];
115 if text_column.is_empty() {
116 return Ok(Arc::new(BooleanVector::from(Vec::<bool>::with_capacity(0))));
117 }
118
119 let term_column = &columns[1];
120 let compiled_finder = if term_column.is_const() {
121 let term = term_column.get_ref(0).as_string().unwrap();
122 match term {
123 None => {
124 return Ok(Arc::new(BooleanVector::from_iter(repeat_n(
125 None,
126 text_column.len(),
127 ))));
128 }
129 Some(term) => Some(MatchesTermFinder::new(term)),
130 }
131 } else {
132 None
133 };
134
135 let len = text_column.len();
136 let mut result = BooleanVectorBuilder::with_capacity(len);
137 for i in 0..len {
138 let text = text_column.get_ref(i).as_string().unwrap();
139 let Some(text) = text else {
140 result.push_null();
141 continue;
142 };
143
144 let contains = match &compiled_finder {
145 Some(finder) => finder.find(text),
146 None => {
147 let term = match term_column.get_ref(i).as_string().unwrap() {
148 None => {
149 result.push_null();
150 continue;
151 }
152 Some(term) => term,
153 };
154 MatchesTermFinder::new(term).find(text)
155 }
156 };
157 result.push(Some(contains));
158 }
159
160 Ok(result.to_vector())
161 }
162}
163
164#[derive(Clone, Debug)]
184pub struct MatchesTermFinder {
185 finder: memmem::Finder<'static>,
186 term: String,
187 starts_with_non_alnum: bool,
188 ends_with_non_alnum: bool,
189}
190
191impl MatchesTermFinder {
192 pub fn new(term: &str) -> Self {
194 let starts_with_non_alnum = term.chars().next().is_some_and(|c| !c.is_alphanumeric());
195 let ends_with_non_alnum = term.chars().last().is_some_and(|c| !c.is_alphanumeric());
196
197 Self {
198 finder: memmem::Finder::new(term).into_owned(),
199 term: term.to_string(),
200 starts_with_non_alnum,
201 ends_with_non_alnum,
202 }
203 }
204
205 pub fn find(&self, text: &str) -> bool {
207 if self.term.is_empty() {
208 return text.is_empty();
209 }
210
211 if text.len() < self.term.len() {
212 return false;
213 }
214
215 let mut pos = 0;
216 while let Some(found_pos) = self.finder.find(&text.as_bytes()[pos..]) {
217 let actual_pos = pos + found_pos;
218
219 let prev_ok = self.starts_with_non_alnum
220 || text[..actual_pos]
221 .chars()
222 .last()
223 .map(|c| !c.is_alphanumeric())
224 .unwrap_or(true);
225
226 if prev_ok {
227 let next_pos = actual_pos + self.finder.needle().len();
228 let next_ok = self.ends_with_non_alnum
229 || text[next_pos..]
230 .chars()
231 .next()
232 .map(|c| !c.is_alphanumeric())
233 .unwrap_or(true);
234
235 if next_ok {
236 return true;
237 }
238 }
239
240 if let Some(next_char) = text[actual_pos..].chars().next() {
241 pos = actual_pos + next_char.len_utf8();
242 } else {
243 break;
244 }
245 }
246
247 false
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic phrase, empty-term and case-sensitivity examples.
    #[test]
    fn matches_term_example() {
        let finder = MatchesTermFinder::new("hello world");
        assert!(finder.find("warning:hello world!"));
        assert!(!finder.find("hello-world"));
        assert!(!finder.find("hello world2023"));

        let finder = MatchesTermFinder::new("critical error");
        assert!(finder.find("ERROR:critical error!"));
        assert!(!finder.find("critical_errors"));

        let finder = MatchesTermFinder::new("");
        assert!(finder.find(""));
        assert!(!finder.find("any"));

        let finder = MatchesTermFinder::new("Cat");
        assert!(finder.find("Cat"));
        assert!(!finder.find("cat"));
    }

    #[test]
    fn matches_term_with_punctuation() {
        assert!(MatchesTermFinder::new("cat").find("cat!"));
        assert!(MatchesTermFinder::new("dog").find("!dog"));
    }

    #[test]
    fn matches_phrase_with_boundaries() {
        assert!(MatchesTermFinder::new("hello-world").find("hello-world"));
        assert!(MatchesTermFinder::new("'foo bar'").find("test: 'foo bar'"));
    }

    #[test]
    fn matches_at_text_boundaries() {
        assert!(MatchesTermFinder::new("start").find("start..."));
        assert!(MatchesTermFinder::new("end").find("...end"));
    }

    #[test]
    fn rejects_partial_matches() {
        assert!(!MatchesTermFinder::new("cat").find("category"));
        assert!(!MatchesTermFinder::new("boot").find("rebooted"));
    }

    #[test]
    fn rejects_missing_term() {
        assert!(!MatchesTermFinder::new("foo").find("hello world"));
    }

    #[test]
    fn handles_empty_inputs() {
        assert!(!MatchesTermFinder::new("test").find(""));
        assert!(!MatchesTermFinder::new("").find("text"));
    }

    /// Non-ASCII letters count as alphanumeric neighbours, on either side of the term.
    #[test]
    fn different_unicode_boundaries() {
        assert!(MatchesTermFinder::new("café").find("café>"));
        assert!(!MatchesTermFinder::new("café").find("口café>"));
        assert!(!MatchesTermFinder::new("café").find("café口"));
        assert!(!MatchesTermFinder::new("café").find("cafémore"));
        assert!(MatchesTermFinder::new("русский").find("русский!"));
        // A Cyrillic letter right after the term must reject the match.
        assert!(!MatchesTermFinder::new("русский").find("русскийязык"));
    }

    #[test]
    fn case_sensitive_matching() {
        assert!(!MatchesTermFinder::new("cat").find("Cat"));
        assert!(MatchesTermFinder::new("CaT").find("CaT"));
    }

    #[test]
    fn numbers_in_term() {
        assert!(MatchesTermFinder::new("v1.0").find("v1.0!"));
        assert!(!MatchesTermFinder::new("v1.0").find("v1.0a"));
    }

    #[test]
    fn adjacent_alphanumeric_fails() {
        assert!(!MatchesTermFinder::new("cat").find("cat5"));
        assert!(!MatchesTermFinder::new("dog").find("dogcat"));
    }

    #[test]
    fn empty_term_text() {
        assert!(!MatchesTermFinder::new("").find("text"));
        assert!(MatchesTermFinder::new("").find(""));
        assert!(!MatchesTermFinder::new("text").find(""));
    }

    /// A term that starts or ends with a non-alphanumeric character skips the check on that side.
    #[test]
    fn leading_non_alphanumeric() {
        assert!(MatchesTermFinder::new("/cat").find("dog/cat"));
        assert!(MatchesTermFinder::new("dog/").find("dog/cat"));
        assert!(MatchesTermFinder::new("dog/cat").find("dog/cat"));
    }

    /// A rejected occurrence must not stop the search for a later, valid one.
    #[test]
    fn continues_searching_after_boundary_mismatch() {
        assert!(!MatchesTermFinder::new("log").find("bloglog!"));
        assert!(MatchesTermFinder::new("log").find("bloglog log"));
        assert!(MatchesTermFinder::new("log").find("alogblog_log!"));

        assert!(MatchesTermFinder::new("error").find("errorlog_error_case"));
        assert!(MatchesTermFinder::new("test").find("atestbtestc_test_end"));
        assert!(MatchesTermFinder::new("data").find("database_data_store"));
        assert!(!MatchesTermFinder::new("data").find("database_datastore"));
        assert!(MatchesTermFinder::new("log.txt").find("catalog.txt_log.txt!"));
        assert!(!MatchesTermFinder::new("log.txt").find("catalog.txtlog.txt!"));
        assert!(MatchesTermFinder::new("data-set").find("bigdata-set_data-set!"));

        assert!(MatchesTermFinder::new("中文").find("这是中文测试,中文!"));
        assert!(MatchesTermFinder::new("error").find("错误errorerror日志_error!"));
    }
}