1use std::collections::HashMap;
16use std::fmt;
17use std::sync::Arc;
18
19use common_query::error::{InvalidFuncArgsSnafu, Result};
20use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
21use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
22use datafusion::common::{DFSchema, Result as DfResult};
23use datafusion::execution::SessionStateBuilder;
24use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
25use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
26use datafusion_common::DataFusionError;
27use datafusion_expr::{ScalarFunctionArgs, Signature};
28use datatypes::arrow::array::RecordBatch;
29use datatypes::arrow::datatypes::{DataType, Field};
30use snafu::{OptionExt, ensure};
31
32use crate::function::{Function, extract_args};
33use crate::function_registry::FunctionRegistry;
34
/// Scalar function named `matches`.
///
/// It takes two string arguments (a data column and a pattern) and returns a
/// boolean value per row; see the `Function` implementation in this file.
#[derive(Clone, Debug)]
pub struct MatchesFunction {
    /// Argument signature handed out by `Function::signature`.
    signature: Signature,
}
42
43impl MatchesFunction {
44 pub fn register(registry: &FunctionRegistry) {
45 registry.register_scalar(MatchesFunction::default());
46 }
47}
48
49impl Default for MatchesFunction {
50 fn default() -> Self {
51 Self {
52 signature: Signature::string(2, Volatility::Immutable),
53 }
54 }
55}
56
57impl fmt::Display for MatchesFunction {
58 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59 write!(f, "MATCHES")
60 }
61}
62
63impl Function for MatchesFunction {
64 fn name(&self) -> &str {
65 "matches"
66 }
67
68 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
69 Ok(DataType::Boolean)
70 }
71
72 fn signature(&self) -> &Signature {
73 &self.signature
74 }
75
76 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
78 let [data_column, patterns] = extract_args(self.name(), &args)?;
79
80 if data_column.is_empty() {
81 return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
82 Vec::<bool>::with_capacity(0),
83 ))));
84 }
85
86 let pattern = match patterns.data_type() {
88 DataType::Utf8View => patterns.as_string_view().value(0),
89 DataType::Utf8 => patterns.as_string::<i32>().value(0),
90 DataType::LargeUtf8 => patterns.as_string::<i64>().value(0),
91 t => {
92 return Err(DataFusionError::Execution(format!(
93 "unsupported datatype {t}"
94 )));
95 }
96 };
97 self.eval(data_column, pattern)
98 }
99}
100
101impl MatchesFunction {
102 fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult<ColumnarValue> {
103 let col_name = "data";
104 let parser_context = ParserContext::default();
105 let raw_ast = parser_context.parse_pattern(pattern)?;
106 let ast = raw_ast.transform_ast()?;
107
108 let like_expr = ast.into_like_expr(col_name);
109
110 let input_schema = Self::input_schema();
111 let session_state = SessionStateBuilder::new().with_default_features().build();
112 let planner = DefaultPhysicalPlanner::default();
113 let physical_expr =
114 planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
115
116 let arrow_schema = Arc::new(input_schema.as_arrow().clone());
117 let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
118
119 let num_rows = input_record_batch.num_rows();
120 let result = physical_expr.evaluate(&input_record_batch)?;
121 let result_array = result.into_array(num_rows)?;
122
123 Ok(ColumnarValue::Array(Arc::new(result_array)))
124 }
125
126 fn input_schema() -> DFSchema {
127 DFSchema::from_unqualified_fields(
128 [Arc::new(Field::new("data", DataType::Utf8, true))].into(),
129 HashMap::new(),
130 )
131 .unwrap()
132 }
133}
134
/// Syntax tree of a parsed match pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PatternAst {
    /// A single piece of text to look for, qualified by a unary operator.
    Literal { op: UnaryOp, pattern: String },
    /// Several sub-patterns combined with the same binary operator.
    Binary {
        op: BinaryOp,
        children: Vec<PatternAst>,
    },
    /// A sub-tree qualified by a unary operator.
    Group { op: UnaryOp, child: Box<PatternAst> },
}
148
/// Unary qualifier attached to a literal or a group.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum UnaryOp {
    /// Built from the `+` prefix token.
    Must,
    /// Used when no explicit `+` / `-` prefix was given.
    Optional,
    /// Built from the `-` prefix token; the operand is negated when converted to an expression.
    Negative,
}
155
/// Operator that combines the children of a `PatternAst::Binary` node.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum BinaryOp {
    /// All children are joined with logical AND.
    And,
    /// All children are joined with logical OR.
    Or,
}
161
162impl PatternAst {
163 fn into_like_expr(self, column: &str) -> Expr {
164 match self {
165 PatternAst::Literal { op, pattern } => {
166 let expr = Self::convert_literal(column, &pattern);
167 match op {
168 UnaryOp::Must => expr,
169 UnaryOp::Optional => expr,
170 UnaryOp::Negative => logical_expr::not(expr),
171 }
172 }
173 PatternAst::Binary { op, children } => {
174 if children.is_empty() {
175 return logical_expr::lit(true);
176 }
177 let exprs = children
178 .into_iter()
179 .map(|child| child.into_like_expr(column));
180 match op {
182 BinaryOp::And => exprs.reduce(Expr::and).unwrap(),
183 BinaryOp::Or => exprs.reduce(Expr::or).unwrap(),
184 }
185 }
186 PatternAst::Group { op, child } => {
187 let child = child.into_like_expr(column);
188 match op {
189 UnaryOp::Must => child,
190 UnaryOp::Optional => child,
191 UnaryOp::Negative => logical_expr::not(child),
192 }
193 }
194 }
195 }
196
197 fn convert_literal(column: &str, pattern: &str) -> Expr {
198 logical_expr::col(column).like(logical_expr::lit(format!(
199 "%{}%",
200 crate::utils::escape_like_pattern(pattern)
201 )))
202 }
203
204 fn transform_ast(self) -> Result<Self> {
206 self.transform_up(Self::collapse_binary_branch_fn)
207 .map(|data| data.data)?
208 .transform_up(Self::eliminate_optional_fn)
209 .map(|data| data.data)?
210 .transform_down(Self::eliminate_single_child_fn)
211 .map(|data| data.data)
212 .map_err(Into::into)
213 }
214
215 fn collapse_binary_branch_fn(self) -> DfResult<Transformed<Self>> {
221 let PatternAst::Binary {
222 op: parent_op,
223 children,
224 } = self
225 else {
226 return Ok(Transformed::no(self));
227 };
228
229 let mut collapsed = vec![];
230 let mut remains = vec![];
231
232 for child in children {
233 match child {
234 PatternAst::Literal { .. } | PatternAst::Group { .. } => {
235 collapsed.push(child);
236 }
237 PatternAst::Binary { op, children } => {
238 if op == parent_op {
241 collapsed.extend(children);
242 } else {
243 remains.push(PatternAst::Binary { op, children });
244 }
245 }
246 }
247 }
248
249 if collapsed.is_empty() {
250 Ok(Transformed::no(PatternAst::Binary {
251 op: parent_op,
252 children: remains,
253 }))
254 } else {
255 collapsed.extend(remains);
256 Ok(Transformed::yes(PatternAst::Binary {
257 op: parent_op,
258 children: collapsed,
259 }))
260 }
261 }
262
263 fn eliminate_optional_fn(self) -> DfResult<Transformed<Self>> {
268 let PatternAst::Binary {
269 op: parent_op,
270 children,
271 } = self
272 else {
273 return Ok(Transformed::no(self));
274 };
275
276 if parent_op == BinaryOp::Or {
277 let mut must_list = vec![];
278 let mut must_not_list = vec![];
279 let mut optional_list = vec![];
280 let mut compound_list = vec![];
281
282 for child in children {
283 match child {
284 PatternAst::Literal { op, .. } | PatternAst::Group { op, .. } => match op {
285 UnaryOp::Must => must_list.push(child),
286 UnaryOp::Optional => optional_list.push(child),
287 UnaryOp::Negative => must_not_list.push(child),
288 },
289 PatternAst::Binary { .. } => {
290 compound_list.push(child);
291 }
292 }
293 }
294
295 if !must_list.is_empty() {
297 optional_list.clear();
298 }
299
300 let children_this_level = optional_list.into_iter().chain(compound_list).collect();
301 let new_node = if !must_list.is_empty() || !must_not_list.is_empty() {
302 let new_children = must_list
303 .into_iter()
304 .chain(must_not_list)
305 .chain(Some(PatternAst::Binary {
306 op: BinaryOp::Or,
307 children: children_this_level,
308 }))
309 .collect();
310 PatternAst::Binary {
311 op: BinaryOp::And,
312 children: new_children,
313 }
314 } else {
315 PatternAst::Binary {
316 op: BinaryOp::Or,
317 children: children_this_level,
318 }
319 };
320
321 return Ok(Transformed::yes(new_node));
322 }
323
324 Ok(Transformed::no(PatternAst::Binary {
325 op: parent_op,
326 children,
327 }))
328 }
329
330 fn eliminate_single_child_fn(self) -> DfResult<Transformed<Self>> {
335 let PatternAst::Binary { op, mut children } = self else {
336 return Ok(Transformed::no(self));
337 };
338
339 children.retain(|child| match child {
341 PatternAst::Binary {
342 children: grand_children,
343 ..
344 } => !grand_children.is_empty(),
345 PatternAst::Literal { .. } | PatternAst::Group { .. } => true,
346 });
347
348 if children.len() == 1 {
349 Ok(Transformed::yes(children.into_iter().next().unwrap()))
350 } else {
351 Ok(Transformed::no(PatternAst::Binary { op, children }))
352 }
353 }
354}
355
356impl TreeNode for PatternAst {
357 fn apply_children<'n, F: FnMut(&'n Self) -> DfResult<TreeNodeRecursion>>(
358 &'n self,
359 mut f: F,
360 ) -> DfResult<TreeNodeRecursion> {
361 match self {
362 PatternAst::Literal { .. } => Ok(TreeNodeRecursion::Continue),
363 PatternAst::Binary { op: _, children } => {
364 for child in children {
365 if TreeNodeRecursion::Stop == f(child)? {
366 return Ok(TreeNodeRecursion::Stop);
367 }
368 }
369 Ok(TreeNodeRecursion::Continue)
370 }
371 PatternAst::Group { op: _, child } => f(child),
372 }
373 }
374
375 fn map_children<F: FnMut(Self) -> DfResult<Transformed<Self>>>(
376 self,
377 mut f: F,
378 ) -> DfResult<Transformed<Self>> {
379 match self {
380 PatternAst::Literal { .. } => Ok(Transformed::no(self)),
381 PatternAst::Binary { op, children } => children
382 .into_iter()
383 .map_until_stop_and_collect(&mut f)?
384 .map_data(|new_children| {
385 Ok(PatternAst::Binary {
386 op,
387 children: new_children,
388 })
389 }),
390 PatternAst::Group { op, child } => f(*child)?.map_data(|new_child| {
391 Ok(PatternAst::Group {
392 op,
393 child: Box::new(new_child),
394 })
395 }),
396 }
397 }
398}
399
/// Parser state used to turn a pattern string into a [`PatternAst`].
#[derive(Default)]
struct ParserContext {
    /// Operands built so far; operators pop their operands from here and push
    /// the combined node back.
    stack: Vec<PatternAst>,
}
404
405impl ParserContext {
406 pub fn parse_pattern(mut self, pattern: &str) -> Result<PatternAst> {
407 let tokenizer = Tokenizer::default();
408 let raw_tokens = tokenizer.tokenize(pattern)?;
409 let raw_tokens = Self::accomplish_optional_unary_op(raw_tokens)?;
410 let mut tokens = Self::to_rpn(raw_tokens)?;
411
412 while !tokens.is_empty() {
413 self.parse_one_impl(&mut tokens)?;
414 }
415
416 ensure!(
417 !self.stack.is_empty(),
418 InvalidFuncArgsSnafu {
419 err_msg: "Empty pattern",
420 }
421 );
422
423 if self.stack.len() == 1 {
425 Ok(self.stack.pop().unwrap())
426 } else {
427 Ok(PatternAst::Binary {
428 op: BinaryOp::Or,
429 children: self.stack,
430 })
431 }
432 }
433
434 fn accomplish_optional_unary_op(raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
440 let mut is_prev_unary_op = false;
441 let mut is_binary_op_before = true;
443 let mut is_unary_op_before = false;
444 let mut new_tokens = Vec::with_capacity(raw_tokens.len());
445 for token in raw_tokens {
446 if !is_binary_op_before
448 && matches!(
449 token,
450 Token::Phase(_)
451 | Token::OpenParen
452 | Token::Must
453 | Token::Optional
454 | Token::Negative
455 )
456 {
457 is_binary_op_before = true;
458 new_tokens.push(Token::Or);
459 }
460 if matches!(
461 token,
462 Token::OpenParen | Token::And | Token::Or
464 ) {
465 is_binary_op_before = true;
466 } else if matches!(token, Token::Phase(_) | Token::CloseParen) {
467 is_binary_op_before = false;
469 }
470
471 if !is_prev_unary_op && matches!(token, Token::Phase(_) | Token::OpenParen) {
473 new_tokens.push(Token::Optional);
474 } else {
475 is_prev_unary_op = matches!(token, Token::Must | Token::Negative);
476 }
477
478 if matches!(token, Token::Must | Token::Optional | Token::Negative) {
480 if is_unary_op_before {
481 return InvalidFuncArgsSnafu {
482 err_msg: "Invalid pattern, unary operators should not be adjacent",
483 }
484 .fail();
485 }
486 is_unary_op_before = true;
487 } else {
488 is_unary_op_before = false;
489 }
490
491 new_tokens.push(token);
492 }
493
494 Ok(new_tokens)
495 }
496
497 fn to_rpn(mut raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
499 let mut operator_stack = vec![];
500 let mut result = vec![];
501 raw_tokens.reverse();
502
503 while let Some(token) = raw_tokens.pop() {
504 match token {
505 Token::Phase(_) => result.push(token),
506 Token::Must | Token::Negative | Token::Optional => {
507 operator_stack.push(token);
508 }
509 Token::OpenParen => operator_stack.push(token),
510 Token::And | Token::Or => {
511 while let Some(stack_top) = operator_stack.last()
514 && ((*stack_top == Token::And && token == Token::Or)
515 || matches!(
516 *stack_top,
517 Token::Must | Token::Optional | Token::Negative
518 ))
519 {
520 result.push(operator_stack.pop().unwrap());
521 }
522 operator_stack.push(token);
523 }
524 Token::CloseParen => {
525 let mut is_open_paren_found = false;
526 while let Some(op) = operator_stack.pop() {
527 if op == Token::OpenParen {
528 is_open_paren_found = true;
529 break;
530 }
531 result.push(op);
532 }
533 if !is_open_paren_found {
534 return InvalidFuncArgsSnafu {
535 err_msg: "Unmatched close parentheses",
536 }
537 .fail();
538 }
539 }
540 }
541 }
542
543 while let Some(operator) = operator_stack.pop() {
544 if operator == Token::OpenParen {
545 return InvalidFuncArgsSnafu {
546 err_msg: "Unmatched parentheses",
547 }
548 .fail();
549 }
550 result.push(operator);
551 }
552
553 Ok(result)
554 }
555
556 fn parse_one_impl(&mut self, tokens: &mut Vec<Token>) -> Result<()> {
557 if let Some(token) = tokens.pop() {
558 match token {
559 Token::Must => {
560 if self.stack.is_empty() {
561 self.parse_one_impl(tokens)?;
562 }
563 let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
564 err_msg: "Invalid pattern, \"+\" operator should have one operand",
565 })?;
566 match phase_or_group {
567 PatternAst::Literal { op: _, pattern } => {
568 self.stack.push(PatternAst::Literal {
569 op: UnaryOp::Must,
570 pattern,
571 });
572 }
573 PatternAst::Binary { .. } | PatternAst::Group { .. } => {
574 self.stack.push(PatternAst::Group {
575 op: UnaryOp::Must,
576 child: Box::new(phase_or_group),
577 })
578 }
579 }
580 return Ok(());
581 }
582 Token::Negative => {
583 if self.stack.is_empty() {
584 self.parse_one_impl(tokens)?;
585 }
586 let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
587 err_msg: "Invalid pattern, \"-\" operator should have one operand",
588 })?;
589 match phase_or_group {
590 PatternAst::Literal { op: _, pattern } => {
591 self.stack.push(PatternAst::Literal {
592 op: UnaryOp::Negative,
593 pattern,
594 });
595 }
596 PatternAst::Binary { .. } | PatternAst::Group { .. } => {
597 self.stack.push(PatternAst::Group {
598 op: UnaryOp::Negative,
599 child: Box::new(phase_or_group),
600 })
601 }
602 }
603 return Ok(());
604 }
605 Token::Optional => {
606 if self.stack.is_empty() {
607 self.parse_one_impl(tokens)?;
608 }
609 let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
610 err_msg:
611 "Invalid pattern, OPTIONAL(space) operator should have one operand",
612 })?;
613 match phase_or_group {
614 PatternAst::Literal { op: _, pattern } => {
615 self.stack.push(PatternAst::Literal {
616 op: UnaryOp::Optional,
617 pattern,
618 });
619 }
620 PatternAst::Binary { .. } | PatternAst::Group { .. } => {
621 self.stack.push(PatternAst::Group {
622 op: UnaryOp::Optional,
623 child: Box::new(phase_or_group),
624 })
625 }
626 }
627 return Ok(());
628 }
629 Token::Phase(pattern) => {
630 self.stack.push(PatternAst::Literal {
631 op: UnaryOp::Optional,
633 pattern,
634 })
635 }
636 Token::And => {
637 if self.stack.is_empty() {
638 self.parse_one_impl(tokens)?;
639 };
640 let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
641 err_msg: "Invalid pattern, \"AND\" operator should have two operands",
642 })?;
643 if self.stack.is_empty() {
644 self.parse_one_impl(tokens)?
645 };
646 let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
647 err_msg: "Invalid pattern, \"AND\" operator should have two operands",
648 })?;
649 self.stack.push(PatternAst::Binary {
650 op: BinaryOp::And,
651 children: vec![lhs, rhs],
652 });
653 return Ok(());
654 }
655 Token::Or => {
656 if self.stack.is_empty() {
657 self.parse_one_impl(tokens)?
658 };
659 let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
660 err_msg: "Invalid pattern, \"OR\" operator should have two operands",
661 })?;
662 if self.stack.is_empty() {
663 self.parse_one_impl(tokens)?
664 };
665 let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
666 err_msg: "Invalid pattern, \"OR\" operator should have two operands",
667 })?;
668 self.stack.push(PatternAst::Binary {
669 op: BinaryOp::Or,
670 children: vec![lhs, rhs],
671 });
672 return Ok(());
673 }
674 Token::OpenParen | Token::CloseParen => {
675 return InvalidFuncArgsSnafu {
676 err_msg: "Unexpected parentheses",
677 }
678 .fail();
679 }
680 }
681 }
682
683 Ok(())
684 }
685}
686
/// Lexical token of a match pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Token {
    /// The `+` character.
    Must,
    /// The `-` character.
    Negative,
    /// An unquoted word equal to `AND`, compared case-insensitively.
    And,
    /// An unquoted word equal to `OR`, compared case-insensitively.
    Or,
    /// The `(` character.
    OpenParen,
    /// The `)` character.
    CloseParen,
    /// A quoted or unquoted piece of text.
    Phase(String),

    /// Never produced by the tokenizer; the parser inserts it in front of a
    /// phase or an opening parenthesis that has no explicit `+` / `-` prefix.
    Optional,
}
710
/// Splits a pattern string into [`Token`]s.
#[derive(Default)]
struct Tokenizer {
    /// Position of the current character, counted in `char`s (not bytes).
    cursor: usize,
}
715
716impl Tokenizer {
717 pub fn tokenize(mut self, pattern: &str) -> Result<Vec<Token>> {
718 let mut tokens = vec![];
719 let char_len = pattern.chars().count();
720 while self.cursor < char_len {
721 let c = pattern.chars().nth(self.cursor).unwrap();
723 match c {
724 '+' => tokens.push(Token::Must),
725 '-' => tokens.push(Token::Negative),
726 '(' => tokens.push(Token::OpenParen),
727 ')' => tokens.push(Token::CloseParen),
728 ' ' => {
729 if let Some(last_token) = tokens.last() {
730 match last_token {
731 Token::Must | Token::Negative => {
732 return InvalidFuncArgsSnafu {
733 err_msg: format!("Unexpected space after {:?}", last_token),
734 }
735 .fail();
736 }
737 _ => {}
738 }
739 }
740 }
741 '\"' => {
742 self.step_next();
743 let phase = self.consume_next_phase(true, pattern)?;
744 tokens.push(Token::Phase(phase));
745 if let Some(ending_separator) = self.consume_next(pattern)
747 && ending_separator != ' '
748 {
749 return InvalidFuncArgsSnafu {
750 err_msg: "Expect a space after quotes ('\"')",
751 }
752 .fail();
753 }
754 }
755 _ => {
756 let phase = self.consume_next_phase(false, pattern)?;
757 match phase.to_uppercase().as_str() {
758 "AND" => tokens.push(Token::And),
759 "OR" => tokens.push(Token::Or),
760 _ => tokens.push(Token::Phase(phase)),
761 }
762 }
763 }
764 self.cursor += 1;
765 }
766 Ok(tokens)
767 }
768
769 fn consume_next(&mut self, pattern: &str) -> Option<char> {
770 self.cursor += 1;
771 pattern.chars().nth(self.cursor)
772 }
773
774 fn step_next(&mut self) {
775 self.cursor += 1;
776 }
777
778 fn rewind_one(&mut self) {
779 self.cursor -= 1;
780 }
781
782 fn consume_next_phase(&mut self, is_quoted: bool, pattern: &str) -> Result<String> {
785 let mut phase = String::new();
786 let mut is_quote_present = false;
787
788 let char_len = pattern.chars().count();
789 while self.cursor < char_len {
790 let mut c = pattern.chars().nth(self.cursor).unwrap();
791
792 match c {
793 '\"' => {
794 is_quote_present = true;
795 break;
796 }
797 ' ' if !is_quoted => {
798 break;
799 }
800 '(' | ')' | '+' | '-' if !is_quoted => {
801 self.rewind_one();
802 break;
803 }
804 '\\' => {
805 let Some(next) = self.consume_next(pattern) else {
806 return InvalidFuncArgsSnafu {
807 err_msg: "Unexpected end of pattern, expected a character after escape ('\\')",
808 }.fail();
809 };
810 c = next;
812 }
813 _ => {}
814 }
815
816 phase.push(c);
817 self.cursor += 1;
818 }
819
820 if is_quoted ^ is_quote_present {
821 return InvalidFuncArgsSnafu {
822 err_msg: "Unclosed quotes ('\"')",
823 }
824 .fail();
825 }
826
827 Ok(phase)
828 }
829}
830
831#[cfg(test)]
832mod test {
833 use datafusion::arrow::array::StringArray;
834 use datafusion_common::ScalarValue;
835 use datafusion_common::config::ConfigOptions;
836
837 use super::*;
838
839 #[test]
840 fn valid_matches_tokenizer() {
841 use Token::*;
842 let cases = [
843 (
844 "a +b -c",
845 vec![
846 Phase("a".to_string()),
847 Must,
848 Phase("b".to_string()),
849 Negative,
850 Phase("c".to_string()),
851 ],
852 ),
853 (
854 "+a(b-c)",
855 vec![
856 Must,
857 Phase("a".to_string()),
858 OpenParen,
859 Phase("b".to_string()),
860 Negative,
861 Phase("c".to_string()),
862 CloseParen,
863 ],
864 ),
865 (
866 r#"Barack Obama"#,
867 vec![Phase("Barack".to_string()), Phase("Obama".to_string())],
868 ),
869 (
870 r#"+apple +fruit"#,
871 vec![
872 Must,
873 Phase("apple".to_string()),
874 Must,
875 Phase("fruit".to_string()),
876 ],
877 ),
878 (
879 r#""He said \"hello\"""#,
880 vec![Phase("He said \"hello\"".to_string())],
881 ),
882 (
883 r#"a AND b OR c"#,
884 vec![
885 Phase("a".to_string()),
886 And,
887 Phase("b".to_string()),
888 Or,
889 Phase("c".to_string()),
890 ],
891 ),
892 (
893 r#"中文 测试"#,
894 vec![Phase("中文".to_string()), Phase("测试".to_string())],
895 ),
896 (
897 r#"中文 AND 测试"#,
898 vec![Phase("中文".to_string()), And, Phase("测试".to_string())],
899 ),
900 (
901 r#"中文 +测试"#,
902 vec![Phase("中文".to_string()), Must, Phase("测试".to_string())],
903 ),
904 (
905 r#"中文 -测试"#,
906 vec![
907 Phase("中文".to_string()),
908 Negative,
909 Phase("测试".to_string()),
910 ],
911 ),
912 ];
913
914 for (query, expected) in cases {
915 let tokenizer = Tokenizer::default();
916 let tokens = tokenizer.tokenize(query).unwrap();
917 assert_eq!(expected, tokens, "{query}");
918 }
919 }
920
921 #[test]
922 fn invalid_matches_tokenizer() {
923 let cases = [
924 (r#""He said "hello""#, "Expect a space after quotes"),
925 (r#""He said hello"#, "Unclosed quotes"),
926 (r#"a + b - c"#, "Unexpected space after"),
927 (r#"ab "c"def"#, "Expect a space after quotes"),
928 ];
929
930 for (query, expected) in cases {
931 let tokenizer = Tokenizer::default();
932 let result = tokenizer.tokenize(query);
933 assert!(result.is_err(), "{query}");
934 let actual_error = result.unwrap_err().to_string();
935 assert!(actual_error.contains(expected), "{query}, {actual_error}");
936 }
937 }
938
939 #[test]
940 fn valid_ast_transformer() {
941 let cases = [
942 (
943 "a AND b OR c",
944 PatternAst::Binary {
945 op: BinaryOp::Or,
946 children: vec![
947 PatternAst::Literal {
948 op: UnaryOp::Optional,
949 pattern: "c".to_string(),
950 },
951 PatternAst::Binary {
952 op: BinaryOp::And,
953 children: vec![
954 PatternAst::Literal {
955 op: UnaryOp::Optional,
956 pattern: "a".to_string(),
957 },
958 PatternAst::Literal {
959 op: UnaryOp::Optional,
960 pattern: "b".to_string(),
961 },
962 ],
963 },
964 ],
965 },
966 ),
967 (
968 "a -b",
969 PatternAst::Binary {
970 op: BinaryOp::And,
971 children: vec![
972 PatternAst::Literal {
973 op: UnaryOp::Negative,
974 pattern: "b".to_string(),
975 },
976 PatternAst::Literal {
977 op: UnaryOp::Optional,
978 pattern: "a".to_string(),
979 },
980 ],
981 },
982 ),
983 (
984 "a +b",
985 PatternAst::Literal {
986 op: UnaryOp::Must,
987 pattern: "b".to_string(),
988 },
989 ),
990 (
991 "a b c d",
992 PatternAst::Binary {
993 op: BinaryOp::Or,
994 children: vec![
995 PatternAst::Literal {
996 op: UnaryOp::Optional,
997 pattern: "a".to_string(),
998 },
999 PatternAst::Literal {
1000 op: UnaryOp::Optional,
1001 pattern: "b".to_string(),
1002 },
1003 PatternAst::Literal {
1004 op: UnaryOp::Optional,
1005 pattern: "c".to_string(),
1006 },
1007 PatternAst::Literal {
1008 op: UnaryOp::Optional,
1009 pattern: "d".to_string(),
1010 },
1011 ],
1012 },
1013 ),
1014 (
1015 "a b c AND d",
1016 PatternAst::Binary {
1017 op: BinaryOp::Or,
1018 children: vec![
1019 PatternAst::Literal {
1020 op: UnaryOp::Optional,
1021 pattern: "a".to_string(),
1022 },
1023 PatternAst::Literal {
1024 op: UnaryOp::Optional,
1025 pattern: "b".to_string(),
1026 },
1027 PatternAst::Binary {
1028 op: BinaryOp::And,
1029 children: vec![
1030 PatternAst::Literal {
1031 op: UnaryOp::Optional,
1032 pattern: "c".to_string(),
1033 },
1034 PatternAst::Literal {
1035 op: UnaryOp::Optional,
1036 pattern: "d".to_string(),
1037 },
1038 ],
1039 },
1040 ],
1041 },
1042 ),
1043 (
1044 r#"中文 测试"#,
1045 PatternAst::Binary {
1046 op: BinaryOp::Or,
1047 children: vec![
1048 PatternAst::Literal {
1049 op: UnaryOp::Optional,
1050 pattern: "中文".to_string(),
1051 },
1052 PatternAst::Literal {
1053 op: UnaryOp::Optional,
1054 pattern: "测试".to_string(),
1055 },
1056 ],
1057 },
1058 ),
1059 (
1060 r#"中文 AND 测试"#,
1061 PatternAst::Binary {
1062 op: BinaryOp::And,
1063 children: vec![
1064 PatternAst::Literal {
1065 op: UnaryOp::Optional,
1066 pattern: "中文".to_string(),
1067 },
1068 PatternAst::Literal {
1069 op: UnaryOp::Optional,
1070 pattern: "测试".to_string(),
1071 },
1072 ],
1073 },
1074 ),
1075 (
1076 r#"中文 +测试"#,
1077 PatternAst::Literal {
1078 op: UnaryOp::Must,
1079 pattern: "测试".to_string(),
1080 },
1081 ),
1082 (
1083 r#"中文 -测试"#,
1084 PatternAst::Binary {
1085 op: BinaryOp::And,
1086 children: vec![
1087 PatternAst::Literal {
1088 op: UnaryOp::Negative,
1089 pattern: "测试".to_string(),
1090 },
1091 PatternAst::Literal {
1092 op: UnaryOp::Optional,
1093 pattern: "中文".to_string(),
1094 },
1095 ],
1096 },
1097 ),
1098 ];
1099
1100 for (query, expected) in cases {
1101 let parser = ParserContext { stack: vec![] };
1102 let ast = parser.parse_pattern(query).unwrap();
1103 let ast = ast.transform_ast().unwrap();
1104 assert_eq!(expected, ast, "{query}");
1105 }
1106 }
1107
1108 #[test]
1109 fn invalid_ast() {
1110 let cases = [
1111 (r#"a b (c"#, "Unmatched parentheses"),
1112 (r#"a b) c"#, "Unmatched close parentheses"),
1113 (r#"a +-b"#, "unary operators should not be adjacent"),
1114 ];
1115
1116 for (query, expected) in cases {
1117 let result: Result<()> = try {
1118 let parser = ParserContext { stack: vec![] };
1119 let ast = parser.parse_pattern(query)?;
1120 let _ast = ast.transform_ast()?;
1121 };
1122
1123 assert!(result.is_err(), "{query}");
1124 let actual_error = result.unwrap_err().to_string();
1125 assert!(actual_error.contains(expected), "{query}, {actual_error}");
1126 }
1127 }
1128
1129 #[test]
1130 fn valid_matches_parser() {
1131 let cases = [
1132 (
1133 "a AND b OR c",
1134 PatternAst::Binary {
1135 op: BinaryOp::Or,
1136 children: vec![
1137 PatternAst::Binary {
1138 op: BinaryOp::And,
1139 children: vec![
1140 PatternAst::Literal {
1141 op: UnaryOp::Optional,
1142 pattern: "a".to_string(),
1143 },
1144 PatternAst::Literal {
1145 op: UnaryOp::Optional,
1146 pattern: "b".to_string(),
1147 },
1148 ],
1149 },
1150 PatternAst::Literal {
1151 op: UnaryOp::Optional,
1152 pattern: "c".to_string(),
1153 },
1154 ],
1155 },
1156 ),
1157 (
1158 "(a AND b) OR c",
1159 PatternAst::Binary {
1160 op: BinaryOp::Or,
1161 children: vec![
1162 PatternAst::Group {
1163 op: UnaryOp::Optional,
1164 child: Box::new(PatternAst::Binary {
1165 op: BinaryOp::And,
1166 children: vec![
1167 PatternAst::Literal {
1168 op: UnaryOp::Optional,
1169 pattern: "a".to_string(),
1170 },
1171 PatternAst::Literal {
1172 op: UnaryOp::Optional,
1173 pattern: "b".to_string(),
1174 },
1175 ],
1176 }),
1177 },
1178 PatternAst::Literal {
1179 op: UnaryOp::Optional,
1180 pattern: "c".to_string(),
1181 },
1182 ],
1183 },
1184 ),
1185 (
1186 "a AND (b OR c)",
1187 PatternAst::Binary {
1188 op: BinaryOp::And,
1189 children: vec![
1190 PatternAst::Literal {
1191 op: UnaryOp::Optional,
1192 pattern: "a".to_string(),
1193 },
1194 PatternAst::Group {
1195 op: UnaryOp::Optional,
1196 child: Box::new(PatternAst::Binary {
1197 op: BinaryOp::Or,
1198 children: vec![
1199 PatternAst::Literal {
1200 op: UnaryOp::Optional,
1201 pattern: "b".to_string(),
1202 },
1203 PatternAst::Literal {
1204 op: UnaryOp::Optional,
1205 pattern: "c".to_string(),
1206 },
1207 ],
1208 }),
1209 },
1210 ],
1211 },
1212 ),
1213 (
1214 "a +b -c",
1215 PatternAst::Binary {
1216 op: BinaryOp::Or,
1217 children: vec![
1218 PatternAst::Literal {
1219 op: UnaryOp::Optional,
1220 pattern: "a".to_string(),
1221 },
1222 PatternAst::Binary {
1223 op: BinaryOp::Or,
1224 children: vec![
1225 PatternAst::Literal {
1226 op: UnaryOp::Must,
1227 pattern: "b".to_string(),
1228 },
1229 PatternAst::Literal {
1230 op: UnaryOp::Negative,
1231 pattern: "c".to_string(),
1232 },
1233 ],
1234 },
1235 ],
1236 },
1237 ),
1238 (
1239 "(+a +b) c",
1240 PatternAst::Binary {
1241 op: BinaryOp::Or,
1242 children: vec![
1243 PatternAst::Group {
1244 op: UnaryOp::Optional,
1245 child: Box::new(PatternAst::Binary {
1246 op: BinaryOp::Or,
1247 children: vec![
1248 PatternAst::Literal {
1249 op: UnaryOp::Must,
1250 pattern: "a".to_string(),
1251 },
1252 PatternAst::Literal {
1253 op: UnaryOp::Must,
1254 pattern: "b".to_string(),
1255 },
1256 ],
1257 }),
1258 },
1259 PatternAst::Literal {
1260 op: UnaryOp::Optional,
1261 pattern: "c".to_string(),
1262 },
1263 ],
1264 },
1265 ),
1266 (
1267 "\"AND\" AnD \"OR\"",
1268 PatternAst::Binary {
1269 op: BinaryOp::And,
1270 children: vec![
1271 PatternAst::Literal {
1272 op: UnaryOp::Optional,
1273 pattern: "AND".to_string(),
1274 },
1275 PatternAst::Literal {
1276 op: UnaryOp::Optional,
1277 pattern: "OR".to_string(),
1278 },
1279 ],
1280 },
1281 ),
1282 ];
1283
1284 for (query, expected) in cases {
1285 let parser = ParserContext { stack: vec![] };
1286 let ast = parser.parse_pattern(query).unwrap();
1287 assert_eq!(expected, ast, "{query}");
1288 }
1289 }
1290
1291 #[test]
1292 fn evaluate_matches() {
1293 let input_data = vec![
1294 "The quick brown fox jumps over the lazy dog",
1295 "The fox jumps over the lazy dog",
1296 "The quick brown jumps over the lazy dog",
1297 "The quick brown fox over the lazy dog",
1298 "The quick brown fox jumps the lazy dog",
1299 "The quick brown fox jumps over dog",
1300 "The quick brown fox jumps over the dog",
1301 ];
1302 let col: ArrayRef = Arc::new(StringArray::from(input_data));
1303 let cases = [
1304 ("quick", vec![true, false, true, true, true, true, true]),
1306 (
1307 "\"quick brown\"",
1308 vec![true, false, true, true, true, true, true],
1309 ),
1310 (
1311 "\"fox jumps\"",
1312 vec![true, true, false, false, true, true, true],
1313 ),
1314 (
1315 "fox OR lazy",
1316 vec![true, true, true, true, true, true, true],
1317 ),
1318 (
1319 "fox AND lazy",
1320 vec![true, true, false, true, true, false, false],
1321 ),
1322 (
1323 "-over -lazy",
1324 vec![false, false, false, false, false, false, false],
1325 ),
1326 (
1327 "-over AND -lazy",
1328 vec![false, false, false, false, false, false, false],
1329 ),
1330 (
1332 "fox AND jumps OR over",
1333 vec![true, true, true, true, true, true, true],
1334 ),
1335 (
1336 "fox OR brown AND quick",
1337 vec![true, true, true, true, true, true, true],
1338 ),
1339 (
1340 "(fox OR brown) AND quick",
1341 vec![true, false, true, true, true, true, true],
1342 ),
1343 (
1344 "brown AND quick OR fox",
1345 vec![true, true, true, true, true, true, true],
1346 ),
1347 (
1348 "brown AND (quick OR fox)",
1349 vec![true, false, true, true, true, true, true],
1350 ),
1351 (
1352 "brown AND quick AND fox OR jumps AND over AND lazy",
1353 vec![true, true, true, true, true, true, true],
1354 ),
1355 (
1357 "quick brown fox +jumps",
1358 vec![true, true, true, false, true, true, true],
1359 ),
1360 (
1361 "fox +jumps -over",
1362 vec![false, false, false, false, true, false, false],
1363 ),
1364 (
1365 "fox AND +jumps AND -over",
1366 vec![false, false, false, false, true, false, false],
1367 ),
1368 (
1370 "(+fox +jumps) over",
1371 vec![true, true, true, true, true, true, true],
1372 ),
1373 (
1374 "+(fox jumps) AND over",
1375 vec![true, true, true, true, false, true, true],
1376 ),
1377 (
1378 "over -(fox jumps)",
1379 vec![false, false, false, false, false, false, false],
1380 ),
1381 (
1382 "over -(fox AND jumps)",
1383 vec![false, false, true, true, false, false, false],
1384 ),
1385 (
1386 "over AND -(-(fox OR jumps))",
1387 vec![true, true, true, true, false, true, true],
1388 ),
1389 ];
1390
1391 let f = MatchesFunction::default();
1392 for (pattern, expected) in cases {
1393 let args = ScalarFunctionArgs {
1394 args: vec![
1395 ColumnarValue::Array(col.clone()),
1396 ColumnarValue::Scalar(ScalarValue::Utf8View(Some(pattern.to_string()))),
1397 ],
1398 arg_fields: vec![],
1399 number_rows: col.len(),
1400 return_field: Arc::new(Field::new("x", col.data_type().clone(), true)),
1401 config_options: Arc::new(ConfigOptions::new()),
1402 };
1403 let actual = f
1404 .invoke_with_args(args)
1405 .and_then(|x| x.to_array(col.len()))
1406 .unwrap();
1407 let expected: ArrayRef = Arc::new(BooleanArray::from(expected));
1408 assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}");
1409 }
1410 }
1411}