1use std::collections::HashMap;
16use std::fmt;
17use std::sync::Arc;
18
19use common_query::error::{InvalidFuncArgsSnafu, Result};
20use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
21use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
22use datafusion::common::{DFSchema, Result as DfResult};
23use datafusion::execution::SessionStateBuilder;
24use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
25use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
26use datafusion_common::DataFusionError;
27use datafusion_expr::{ScalarFunctionArgs, Signature};
28use datatypes::arrow::array::RecordBatch;
29use datatypes::arrow::datatypes::{DataType, Field};
30use snafu::{OptionExt, ensure};
31
32use crate::function::{Function, extract_args};
33use crate::function_registry::FunctionRegistry;
34
/// Scalar function `matches(column, pattern)`.
///
/// The pattern is parsed into a [`PatternAst`], simplified, lowered to a
/// combination of `LIKE '%…%'` predicates and evaluated against the input
/// string column, producing a boolean array.
#[derive(Clone, Debug)]
pub struct MatchesFunction {
    /// DataFusion signature of the function (two string arguments, immutable;
    /// see the `Default` implementation).
    signature: Signature,
}
42
43impl MatchesFunction {
44 pub fn register(registry: &FunctionRegistry) {
45 registry.register_scalar(MatchesFunction::default());
46 }
47}
48
49impl Default for MatchesFunction {
50 fn default() -> Self {
51 Self {
52 signature: Signature::string(2, Volatility::Immutable),
53 }
54 }
55}
56
57impl fmt::Display for MatchesFunction {
58 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59 write!(f, "MATCHES")
60 }
61}
62
63impl Function for MatchesFunction {
64 fn name(&self) -> &str {
65 "matches"
66 }
67
68 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
69 Ok(DataType::Boolean)
70 }
71
72 fn signature(&self) -> &Signature {
73 &self.signature
74 }
75
76 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
78 let [data_column, patterns] = extract_args(self.name(), &args)?;
79
80 if data_column.is_empty() {
81 return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
82 Vec::<bool>::with_capacity(0),
83 ))));
84 }
85
86 let pattern = match patterns.data_type() {
88 DataType::Utf8View => patterns.as_string_view().value(0),
89 DataType::Utf8 => patterns.as_string::<i32>().value(0),
90 DataType::LargeUtf8 => patterns.as_string::<i64>().value(0),
91 t => {
92 return Err(DataFusionError::Execution(format!(
93 "unsupported datatype {t}"
94 )));
95 }
96 };
97 self.eval(data_column, pattern)
98 }
99}
100
101impl MatchesFunction {
102 fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult<ColumnarValue> {
103 let col_name = "data";
104 let parser_context = ParserContext::default();
105 let raw_ast = parser_context.parse_pattern(pattern)?;
106 let ast = raw_ast.transform_ast()?;
107
108 let like_expr = ast.into_like_expr(col_name);
109
110 let input_schema = Self::input_schema();
111 let session_state = SessionStateBuilder::new().with_default_features().build();
112 let planner = DefaultPhysicalPlanner::default();
113 let physical_expr =
114 planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
115
116 let arrow_schema = Arc::new(input_schema.as_arrow().clone());
117 let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
118
119 let num_rows = input_record_batch.num_rows();
120 let result = physical_expr.evaluate(&input_record_batch)?;
121 let result_array = result.into_array(num_rows)?;
122
123 Ok(ColumnarValue::Array(Arc::new(result_array)))
124 }
125
126 fn input_schema() -> DFSchema {
127 DFSchema::from_unqualified_fields(
128 [Arc::new(Field::new("data", DataType::Utf8, true))].into(),
129 HashMap::new(),
130 )
131 .unwrap()
132 }
133}
134
/// Abstract syntax tree of a `matches` pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PatternAst {
    /// A single phrase together with the unary operator applied to it.
    /// Lowered to `column LIKE '%pattern%'` (negated for `UnaryOp::Negative`).
    Literal { op: UnaryOp, pattern: String },
    /// Several sub-patterns joined by the same binary operator.
    Binary {
        /// Operator joining all `children`.
        op: BinaryOp,
        /// Operands, in evaluation order.
        children: Vec<PatternAst>,
    },
    /// A compound sub-pattern with a unary operator applied to it as a whole.
    Group { op: UnaryOp, child: Box<PatternAst> },
}
148
/// Unary operator attached to a phrase or a group.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum UnaryOp {
    /// Produced by the `+` prefix.
    Must,
    /// Used when no explicit prefix is given.
    Optional,
    /// Produced by the `-` prefix; the operand is negated when lowered.
    Negative,
}
155
/// Binary operator joining the children of a `PatternAst::Binary` node.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum BinaryOp {
    /// Logical conjunction (`AND` keyword).
    And,
    /// Logical disjunction (`OR` keyword, also inserted between adjacent operands).
    Or,
}
161
162impl PatternAst {
163 fn into_like_expr(self, column: &str) -> Expr {
164 match self {
165 PatternAst::Literal { op, pattern } => {
166 let expr = Self::convert_literal(column, &pattern);
167 match op {
168 UnaryOp::Must => expr,
169 UnaryOp::Optional => expr,
170 UnaryOp::Negative => logical_expr::not(expr),
171 }
172 }
173 PatternAst::Binary { op, children } => {
174 if children.is_empty() {
175 return logical_expr::lit(true);
176 }
177 let exprs = children
178 .into_iter()
179 .map(|child| child.into_like_expr(column));
180 match op {
182 BinaryOp::And => exprs.reduce(Expr::and).unwrap(),
183 BinaryOp::Or => exprs.reduce(Expr::or).unwrap(),
184 }
185 }
186 PatternAst::Group { op, child } => {
187 let child = child.into_like_expr(column);
188 match op {
189 UnaryOp::Must => child,
190 UnaryOp::Optional => child,
191 UnaryOp::Negative => logical_expr::not(child),
192 }
193 }
194 }
195 }
196
197 fn convert_literal(column: &str, pattern: &str) -> Expr {
198 logical_expr::col(column).like(logical_expr::lit(format!(
199 "%{}%",
200 crate::utils::escape_like_pattern(pattern)
201 )))
202 }
203
204 fn transform_ast(self) -> Result<Self> {
206 self.transform_up(Self::collapse_binary_branch_fn)
207 .map(|data| data.data)?
208 .transform_up(Self::eliminate_optional_fn)
209 .map(|data| data.data)?
210 .transform_down(Self::eliminate_single_child_fn)
211 .map(|data| data.data)
212 .map_err(Into::into)
213 }
214
215 fn collapse_binary_branch_fn(self) -> DfResult<Transformed<Self>> {
221 let PatternAst::Binary {
222 op: parent_op,
223 children,
224 } = self
225 else {
226 return Ok(Transformed::no(self));
227 };
228
229 let mut collapsed = vec![];
230 let mut remains = vec![];
231
232 for child in children {
233 match child {
234 PatternAst::Literal { .. } | PatternAst::Group { .. } => {
235 collapsed.push(child);
236 }
237 PatternAst::Binary { op, children } => {
238 if op == parent_op {
241 collapsed.extend(children);
242 } else {
243 remains.push(PatternAst::Binary { op, children });
244 }
245 }
246 }
247 }
248
249 if collapsed.is_empty() {
250 Ok(Transformed::no(PatternAst::Binary {
251 op: parent_op,
252 children: remains,
253 }))
254 } else {
255 collapsed.extend(remains);
256 Ok(Transformed::yes(PatternAst::Binary {
257 op: parent_op,
258 children: collapsed,
259 }))
260 }
261 }
262
263 fn eliminate_optional_fn(self) -> DfResult<Transformed<Self>> {
268 let PatternAst::Binary {
269 op: parent_op,
270 children,
271 } = self
272 else {
273 return Ok(Transformed::no(self));
274 };
275
276 if parent_op == BinaryOp::Or {
277 let mut must_list = vec![];
278 let mut must_not_list = vec![];
279 let mut optional_list = vec![];
280 let mut compound_list = vec![];
281
282 for child in children {
283 match child {
284 PatternAst::Literal { op, .. } | PatternAst::Group { op, .. } => match op {
285 UnaryOp::Must => must_list.push(child),
286 UnaryOp::Optional => optional_list.push(child),
287 UnaryOp::Negative => must_not_list.push(child),
288 },
289 PatternAst::Binary { .. } => {
290 compound_list.push(child);
291 }
292 }
293 }
294
295 if !must_list.is_empty() {
297 optional_list.clear();
298 }
299
300 let children_this_level = optional_list.into_iter().chain(compound_list).collect();
301 let new_node = if !must_list.is_empty() || !must_not_list.is_empty() {
302 let new_children = must_list
303 .into_iter()
304 .chain(must_not_list)
305 .chain(Some(PatternAst::Binary {
306 op: BinaryOp::Or,
307 children: children_this_level,
308 }))
309 .collect();
310 PatternAst::Binary {
311 op: BinaryOp::And,
312 children: new_children,
313 }
314 } else {
315 PatternAst::Binary {
316 op: BinaryOp::Or,
317 children: children_this_level,
318 }
319 };
320
321 return Ok(Transformed::yes(new_node));
322 }
323
324 Ok(Transformed::no(PatternAst::Binary {
325 op: parent_op,
326 children,
327 }))
328 }
329
330 fn eliminate_single_child_fn(self) -> DfResult<Transformed<Self>> {
335 let PatternAst::Binary { op, mut children } = self else {
336 return Ok(Transformed::no(self));
337 };
338
339 children.retain(|child| match child {
341 PatternAst::Binary {
342 children: grand_children,
343 ..
344 } => !grand_children.is_empty(),
345 PatternAst::Literal { .. } | PatternAst::Group { .. } => true,
346 });
347
348 if children.len() == 1 {
349 Ok(Transformed::yes(children.into_iter().next().unwrap()))
350 } else {
351 Ok(Transformed::no(PatternAst::Binary { op, children }))
352 }
353 }
354}
355
356impl TreeNode for PatternAst {
357 fn apply_children<'n, F: FnMut(&'n Self) -> DfResult<TreeNodeRecursion>>(
358 &'n self,
359 mut f: F,
360 ) -> DfResult<TreeNodeRecursion> {
361 match self {
362 PatternAst::Literal { .. } => Ok(TreeNodeRecursion::Continue),
363 PatternAst::Binary { op: _, children } => {
364 for child in children {
365 if TreeNodeRecursion::Stop == f(child)? {
366 return Ok(TreeNodeRecursion::Stop);
367 }
368 }
369 Ok(TreeNodeRecursion::Continue)
370 }
371 PatternAst::Group { op: _, child } => f(child),
372 }
373 }
374
375 fn map_children<F: FnMut(Self) -> DfResult<Transformed<Self>>>(
376 self,
377 mut f: F,
378 ) -> DfResult<Transformed<Self>> {
379 match self {
380 PatternAst::Literal { .. } => Ok(Transformed::no(self)),
381 PatternAst::Binary { op, children } => children
382 .into_iter()
383 .map_until_stop_and_collect(&mut f)?
384 .map_data(|new_children| {
385 Ok(PatternAst::Binary {
386 op,
387 children: new_children,
388 })
389 }),
390 PatternAst::Group { op, child } => f(*child)?.map_data(|new_child| {
391 Ok(PatternAst::Group {
392 op,
393 child: Box::new(new_child),
394 })
395 }),
396 }
397 }
398}
399
/// Parser state used to turn a pattern string into a [`PatternAst`].
#[derive(Default)]
struct ParserContext {
    /// Operand stack: AST nodes built so far and not yet consumed by an operator.
    stack: Vec<PatternAst>,
}
404
405impl ParserContext {
406 pub fn parse_pattern(mut self, pattern: &str) -> Result<PatternAst> {
407 let tokenizer = Tokenizer::default();
408 let raw_tokens = tokenizer.tokenize(pattern)?;
409 let raw_tokens = Self::accomplish_optional_unary_op(raw_tokens)?;
410 let mut tokens = Self::to_rpn(raw_tokens)?;
411
412 while !tokens.is_empty() {
413 self.parse_one_impl(&mut tokens)?;
414 }
415
416 ensure!(
417 !self.stack.is_empty(),
418 InvalidFuncArgsSnafu {
419 err_msg: "Empty pattern",
420 }
421 );
422
423 if self.stack.len() == 1 {
425 Ok(self.stack.pop().unwrap())
426 } else {
427 Ok(PatternAst::Binary {
428 op: BinaryOp::Or,
429 children: self.stack,
430 })
431 }
432 }
433
434 fn accomplish_optional_unary_op(raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
440 let mut is_prev_unary_op = false;
441 let mut is_binary_op_before = true;
443 let mut is_unary_op_before = false;
444 let mut new_tokens = Vec::with_capacity(raw_tokens.len());
445 for token in raw_tokens {
446 if !is_binary_op_before
448 && matches!(
449 token,
450 Token::Phase(_)
451 | Token::OpenParen
452 | Token::Must
453 | Token::Optional
454 | Token::Negative
455 )
456 {
457 is_binary_op_before = true;
458 new_tokens.push(Token::Or);
459 }
460 if matches!(
461 token,
462 Token::OpenParen | Token::And | Token::Or
464 ) {
465 is_binary_op_before = true;
466 } else if matches!(token, Token::Phase(_) | Token::CloseParen) {
467 is_binary_op_before = false;
469 }
470
471 if !is_prev_unary_op && matches!(token, Token::Phase(_) | Token::OpenParen) {
473 new_tokens.push(Token::Optional);
474 } else {
475 is_prev_unary_op = matches!(token, Token::Must | Token::Negative);
476 }
477
478 if matches!(token, Token::Must | Token::Optional | Token::Negative) {
480 if is_unary_op_before {
481 return InvalidFuncArgsSnafu {
482 err_msg: "Invalid pattern, unary operators should not be adjacent",
483 }
484 .fail();
485 }
486 is_unary_op_before = true;
487 } else {
488 is_unary_op_before = false;
489 }
490
491 new_tokens.push(token);
492 }
493
494 Ok(new_tokens)
495 }
496
497 fn to_rpn(mut raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
499 let mut operator_stack = vec![];
500 let mut result = vec![];
501 raw_tokens.reverse();
502
503 while let Some(token) = raw_tokens.pop() {
504 match token {
505 Token::Phase(_) => result.push(token),
506 Token::Must | Token::Negative | Token::Optional => {
507 operator_stack.push(token);
508 }
509 Token::OpenParen => operator_stack.push(token),
510 Token::And | Token::Or => {
511 while let Some(stack_top) = operator_stack.last()
514 && ((*stack_top == Token::And && token == Token::Or)
515 || matches!(
516 *stack_top,
517 Token::Must | Token::Optional | Token::Negative
518 ))
519 {
520 result.push(operator_stack.pop().unwrap());
521 }
522 operator_stack.push(token);
523 }
524 Token::CloseParen => {
525 let mut is_open_paren_found = false;
526 while let Some(op) = operator_stack.pop() {
527 if op == Token::OpenParen {
528 is_open_paren_found = true;
529 break;
530 }
531 result.push(op);
532 }
533 if !is_open_paren_found {
534 return InvalidFuncArgsSnafu {
535 err_msg: "Unmatched close parentheses",
536 }
537 .fail();
538 }
539 }
540 }
541 }
542
543 while let Some(operator) = operator_stack.pop() {
544 if operator == Token::OpenParen {
545 return InvalidFuncArgsSnafu {
546 err_msg: "Unmatched parentheses",
547 }
548 .fail();
549 }
550 result.push(operator);
551 }
552
553 Ok(result)
554 }
555
556 fn parse_one_impl(&mut self, tokens: &mut Vec<Token>) -> Result<()> {
557 if let Some(token) = tokens.pop() {
558 match token {
559 Token::Must => {
560 if self.stack.is_empty() {
561 self.parse_one_impl(tokens)?;
562 }
563 let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
564 err_msg: "Invalid pattern, \"+\" operator should have one operand",
565 })?;
566 match phase_or_group {
567 PatternAst::Literal { op: _, pattern } => {
568 self.stack.push(PatternAst::Literal {
569 op: UnaryOp::Must,
570 pattern,
571 });
572 }
573 PatternAst::Binary { .. } | PatternAst::Group { .. } => {
574 self.stack.push(PatternAst::Group {
575 op: UnaryOp::Must,
576 child: Box::new(phase_or_group),
577 })
578 }
579 }
580 return Ok(());
581 }
582 Token::Negative => {
583 if self.stack.is_empty() {
584 self.parse_one_impl(tokens)?;
585 }
586 let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
587 err_msg: "Invalid pattern, \"-\" operator should have one operand",
588 })?;
589 match phase_or_group {
590 PatternAst::Literal { op: _, pattern } => {
591 self.stack.push(PatternAst::Literal {
592 op: UnaryOp::Negative,
593 pattern,
594 });
595 }
596 PatternAst::Binary { .. } | PatternAst::Group { .. } => {
597 self.stack.push(PatternAst::Group {
598 op: UnaryOp::Negative,
599 child: Box::new(phase_or_group),
600 })
601 }
602 }
603 return Ok(());
604 }
605 Token::Optional => {
606 if self.stack.is_empty() {
607 self.parse_one_impl(tokens)?;
608 }
609 let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
610 err_msg:
611 "Invalid pattern, OPTIONAL(space) operator should have one operand",
612 })?;
613 match phase_or_group {
614 PatternAst::Literal { op: _, pattern } => {
615 self.stack.push(PatternAst::Literal {
616 op: UnaryOp::Optional,
617 pattern,
618 });
619 }
620 PatternAst::Binary { .. } | PatternAst::Group { .. } => {
621 self.stack.push(PatternAst::Group {
622 op: UnaryOp::Optional,
623 child: Box::new(phase_or_group),
624 })
625 }
626 }
627 return Ok(());
628 }
629 Token::Phase(pattern) => {
630 self.stack.push(PatternAst::Literal {
631 op: UnaryOp::Optional,
633 pattern,
634 })
635 }
636 Token::And => {
637 if self.stack.is_empty() {
638 self.parse_one_impl(tokens)?;
639 };
640 let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
641 err_msg: "Invalid pattern, \"AND\" operator should have two operands",
642 })?;
643 if self.stack.is_empty() {
644 self.parse_one_impl(tokens)?
645 };
646 let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
647 err_msg: "Invalid pattern, \"AND\" operator should have two operands",
648 })?;
649 self.stack.push(PatternAst::Binary {
650 op: BinaryOp::And,
651 children: vec![lhs, rhs],
652 });
653 return Ok(());
654 }
655 Token::Or => {
656 if self.stack.is_empty() {
657 self.parse_one_impl(tokens)?
658 };
659 let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
660 err_msg: "Invalid pattern, \"OR\" operator should have two operands",
661 })?;
662 if self.stack.is_empty() {
663 self.parse_one_impl(tokens)?
664 };
665 let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
666 err_msg: "Invalid pattern, \"OR\" operator should have two operands",
667 })?;
668 self.stack.push(PatternAst::Binary {
669 op: BinaryOp::Or,
670 children: vec![lhs, rhs],
671 });
672 return Ok(());
673 }
674 Token::OpenParen | Token::CloseParen => {
675 return InvalidFuncArgsSnafu {
676 err_msg: "Unexpected parentheses",
677 }
678 .fail();
679 }
680 }
681 }
682
683 Ok(())
684 }
685}
686
/// Lexical token of a `matches` pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Token {
    /// The `+` prefix.
    Must,
    /// The `-` prefix.
    Negative,
    /// The `AND` keyword (matched case-insensitively by the tokenizer).
    And,
    /// The `OR` keyword (matched case-insensitively by the tokenizer).
    Or,
    /// `(`
    OpenParen,
    /// `)`
    CloseParen,
    /// A quoted or unquoted phrase.
    Phase(String),

    /// Implicit unary operator. The tokenizer never emits it; the parser
    /// inserts it in front of phrases and groups that carry no `+`/`-` prefix.
    Optional,
}
710
/// Splits a pattern string into [`Token`]s.
#[derive(Default)]
struct Tokenizer {
    /// Position of the next character to read, counted in characters (not bytes).
    cursor: usize,
}
715
716impl Tokenizer {
717 pub fn tokenize(mut self, pattern: &str) -> Result<Vec<Token>> {
718 let mut tokens = vec![];
719 let char_len = pattern.chars().count();
720 while self.cursor < char_len {
721 let c = pattern.chars().nth(self.cursor).unwrap();
723 match c {
724 '+' => tokens.push(Token::Must),
725 '-' => tokens.push(Token::Negative),
726 '(' => tokens.push(Token::OpenParen),
727 ')' => tokens.push(Token::CloseParen),
728 ' ' => {
729 if let Some(last_token) = tokens.last() {
730 match last_token {
731 Token::Must | Token::Negative => {
732 return InvalidFuncArgsSnafu {
733 err_msg: format!("Unexpected space after {:?}", last_token),
734 }
735 .fail();
736 }
737 _ => {}
738 }
739 }
740 }
741 '\"' => {
742 self.step_next();
743 let phase = self.consume_next_phase(true, pattern)?;
744 tokens.push(Token::Phase(phase));
745 if let Some(ending_separator) = self.consume_next(pattern)
747 && ending_separator != ' '
748 {
749 return InvalidFuncArgsSnafu {
750 err_msg: "Expect a space after quotes ('\"')",
751 }
752 .fail();
753 }
754 }
755 _ => {
756 let phase = self.consume_next_phase(false, pattern)?;
757 match phase.to_uppercase().as_str() {
758 "AND" => tokens.push(Token::And),
759 "OR" => tokens.push(Token::Or),
760 _ => tokens.push(Token::Phase(phase)),
761 }
762 }
763 }
764 self.cursor += 1;
765 }
766 Ok(tokens)
767 }
768
769 fn consume_next(&mut self, pattern: &str) -> Option<char> {
770 self.cursor += 1;
771 pattern.chars().nth(self.cursor)
772 }
773
774 fn step_next(&mut self) {
775 self.cursor += 1;
776 }
777
778 fn rewind_one(&mut self) {
779 self.cursor -= 1;
780 }
781
782 fn consume_next_phase(&mut self, is_quoted: bool, pattern: &str) -> Result<String> {
785 let mut phase = String::new();
786 let mut is_quote_present = false;
787
788 let char_len = pattern.chars().count();
789 while self.cursor < char_len {
790 let mut c = pattern.chars().nth(self.cursor).unwrap();
791
792 match c {
793 '\"' => {
794 is_quote_present = true;
795 break;
796 }
797 ' ' => {
798 if !is_quoted {
799 break;
800 }
801 }
802 '(' | ')' | '+' | '-' => {
803 if !is_quoted {
804 self.rewind_one();
805 break;
806 }
807 }
808 '\\' => {
809 let Some(next) = self.consume_next(pattern) else {
810 return InvalidFuncArgsSnafu {
811 err_msg: "Unexpected end of pattern, expected a character after escape ('\\')",
812 }.fail();
813 };
814 c = next;
816 }
817 _ => {}
818 }
819
820 phase.push(c);
821 self.cursor += 1;
822 }
823
824 if is_quoted ^ is_quote_present {
825 return InvalidFuncArgsSnafu {
826 err_msg: "Unclosed quotes ('\"')",
827 }
828 .fail();
829 }
830
831 Ok(phase)
832 }
833}
834
#[cfg(test)]
mod test {
    use datafusion::arrow::array::StringArray;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Phrase token with the given text.
    fn phase(text: &str) -> Token {
        Token::Phase(text.to_string())
    }

    /// Literal node with an explicit unary operator.
    fn literal(op: UnaryOp, pattern: &str) -> PatternAst {
        PatternAst::Literal {
            op,
            pattern: pattern.to_string(),
        }
    }

    fn optional(pattern: &str) -> PatternAst {
        literal(UnaryOp::Optional, pattern)
    }

    fn must(pattern: &str) -> PatternAst {
        literal(UnaryOp::Must, pattern)
    }

    fn negative(pattern: &str) -> PatternAst {
        literal(UnaryOp::Negative, pattern)
    }

    fn and(children: Vec<PatternAst>) -> PatternAst {
        PatternAst::Binary {
            op: BinaryOp::And,
            children,
        }
    }

    fn or(children: Vec<PatternAst>) -> PatternAst {
        PatternAst::Binary {
            op: BinaryOp::Or,
            children,
        }
    }

    /// Group node carrying the implicit `Optional` operator.
    fn optional_group(child: PatternAst) -> PatternAst {
        PatternAst::Group {
            op: UnaryOp::Optional,
            child: Box::new(child),
        }
    }

    /// Parses `query` with a fresh parser, without simplification.
    fn parse(query: &str) -> Result<PatternAst> {
        ParserContext::default().parse_pattern(query)
    }

    #[test]
    fn valid_matches_tokenizer() {
        use Token::*;
        let cases = [
            (
                "a +b -c",
                vec![phase("a"), Must, phase("b"), Negative, phase("c")],
            ),
            (
                "+a(b-c)",
                vec![
                    Must,
                    phase("a"),
                    OpenParen,
                    phase("b"),
                    Negative,
                    phase("c"),
                    CloseParen,
                ],
            ),
            (r#"Barack Obama"#, vec![phase("Barack"), phase("Obama")]),
            (
                r#"+apple +fruit"#,
                vec![Must, phase("apple"), Must, phase("fruit")],
            ),
            (
                r#""He said \"hello\"""#,
                vec![phase("He said \"hello\"")],
            ),
            (
                r#"a AND b OR c"#,
                vec![phase("a"), And, phase("b"), Or, phase("c")],
            ),
            (r#"中文 测试"#, vec![phase("中文"), phase("测试")]),
            (r#"中文 AND 测试"#, vec![phase("中文"), And, phase("测试")]),
            (r#"中文 +测试"#, vec![phase("中文"), Must, phase("测试")]),
            (
                r#"中文 -测试"#,
                vec![phase("中文"), Negative, phase("测试")],
            ),
        ];

        for (query, expected) in cases {
            let tokens = Tokenizer::default().tokenize(query).unwrap();
            assert_eq!(expected, tokens, "{query}");
        }
    }

    #[test]
    fn invalid_matches_tokenizer() {
        let cases = [
            (r#""He said "hello""#, "Expect a space after quotes"),
            (r#""He said hello"#, "Unclosed quotes"),
            (r#"a + b - c"#, "Unexpected space after"),
            (r#"ab "c"def"#, "Expect a space after quotes"),
        ];

        for (query, expected) in cases {
            let result = Tokenizer::default().tokenize(query);
            assert!(result.is_err(), "{query}");
            let actual_error = result.unwrap_err().to_string();
            assert!(actual_error.contains(expected), "{query}, {actual_error}");
        }
    }

    #[test]
    fn valid_ast_transformer() {
        let cases = [
            (
                "a AND b OR c",
                or(vec![
                    optional("c"),
                    and(vec![optional("a"), optional("b")]),
                ]),
            ),
            ("a -b", and(vec![negative("b"), optional("a")])),
            ("a +b", must("b")),
            (
                "a b c d",
                or(vec![
                    optional("a"),
                    optional("b"),
                    optional("c"),
                    optional("d"),
                ]),
            ),
            (
                "a b c AND d",
                or(vec![
                    optional("a"),
                    optional("b"),
                    and(vec![optional("c"), optional("d")]),
                ]),
            ),
            (
                r#"中文 测试"#,
                or(vec![optional("中文"), optional("测试")]),
            ),
            (
                r#"中文 AND 测试"#,
                and(vec![optional("中文"), optional("测试")]),
            ),
            (r#"中文 +测试"#, must("测试")),
            (
                r#"中文 -测试"#,
                and(vec![negative("测试"), optional("中文")]),
            ),
        ];

        for (query, expected) in cases {
            let ast = parse(query).unwrap().transform_ast().unwrap();
            assert_eq!(expected, ast, "{query}");
        }
    }

    #[test]
    fn invalid_ast() {
        let cases = [
            (r#"a b (c"#, "Unmatched parentheses"),
            (r#"a b) c"#, "Unmatched close parentheses"),
            (r#"a +-b"#, "unary operators should not be adjacent"),
        ];

        for (query, expected) in cases {
            let result: Result<()> = parse(query)
                .and_then(|ast| ast.transform_ast())
                .map(|_ast| ());

            assert!(result.is_err(), "{query}");
            let actual_error = result.unwrap_err().to_string();
            assert!(actual_error.contains(expected), "{query}, {actual_error}");
        }
    }

    #[test]
    fn valid_matches_parser() {
        let cases = [
            (
                "a AND b OR c",
                or(vec![
                    and(vec![optional("a"), optional("b")]),
                    optional("c"),
                ]),
            ),
            (
                "(a AND b) OR c",
                or(vec![
                    optional_group(and(vec![optional("a"), optional("b")])),
                    optional("c"),
                ]),
            ),
            (
                "a AND (b OR c)",
                and(vec![
                    optional("a"),
                    optional_group(or(vec![optional("b"), optional("c")])),
                ]),
            ),
            (
                "a +b -c",
                or(vec![optional("a"), or(vec![must("b"), negative("c")])]),
            ),
            (
                "(+a +b) c",
                or(vec![
                    optional_group(or(vec![must("a"), must("b")])),
                    optional("c"),
                ]),
            ),
            (
                "\"AND\" AnD \"OR\"",
                and(vec![optional("AND"), optional("OR")]),
            ),
        ];

        for (query, expected) in cases {
            let ast = parse(query).unwrap();
            assert_eq!(expected, ast, "{query}");
        }
    }

    #[test]
    fn evaluate_matches() {
        let input_data = vec![
            "The quick brown fox jumps over the lazy dog",
            "The fox jumps over the lazy dog",
            "The quick brown jumps over the lazy dog",
            "The quick brown fox over the lazy dog",
            "The quick brown fox jumps the lazy dog",
            "The quick brown fox jumps over dog",
            "The quick brown fox jumps over the dog",
        ];
        let col: ArrayRef = Arc::new(StringArray::from(input_data));
        let cases = [
            // Single phrases, bare and quoted.
            ("quick", vec![true, false, true, true, true, true, true]),
            (
                "\"quick brown\"",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "\"fox jumps\"",
                vec![true, true, false, false, true, true, true],
            ),
            // Explicit binary operators and negation.
            (
                "fox OR lazy",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "fox AND lazy",
                vec![true, true, false, true, true, false, false],
            ),
            (
                "-over -lazy",
                vec![false, false, false, false, false, false, false],
            ),
            (
                "-over AND -lazy",
                vec![false, false, false, false, false, false, false],
            ),
            // Operator precedence and parentheses.
            (
                "fox AND jumps OR over",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "fox OR brown AND quick",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "(fox OR brown) AND quick",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "brown AND quick OR fox",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "brown AND (quick OR fox)",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "brown AND quick AND fox OR jumps AND over AND lazy",
                vec![true, true, true, true, true, true, true],
            ),
            // Must / negative prefixes mixed with optional phrases.
            (
                "quick brown fox +jumps",
                vec![true, true, true, false, true, true, true],
            ),
            (
                "fox +jumps -over",
                vec![false, false, false, false, true, false, false],
            ),
            (
                "fox AND +jumps AND -over",
                vec![false, false, false, false, true, false, false],
            ),
            // Unary operators applied to groups.
            (
                "(+fox +jumps) over",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "+(fox jumps) AND over",
                vec![true, true, true, true, false, true, true],
            ),
            (
                "over -(fox jumps)",
                vec![false, false, false, false, false, false, false],
            ),
            (
                "over -(fox AND jumps)",
                vec![false, false, true, true, false, false, false],
            ),
            (
                "over AND -(-(fox OR jumps))",
                vec![true, true, true, true, false, true, true],
            ),
        ];

        let f = MatchesFunction::default();
        for (pattern, expected) in cases {
            let pattern_arg = ScalarValue::Utf8View(Some(pattern.to_string()));
            let args = ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(col.clone()),
                    ColumnarValue::Scalar(pattern_arg),
                ],
                arg_fields: vec![],
                number_rows: col.len(),
                return_field: Arc::new(Field::new("x", col.data_type().clone(), true)),
                config_options: Arc::new(ConfigOptions::new()),
            };
            let actual = f
                .invoke_with_args(args)
                .and_then(|x| x.to_array(col.len()))
                .unwrap();
            let expected: ArrayRef = Arc::new(BooleanArray::from(expected));
            assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}");
        }
    }
}