Skip to main content

common_function/scalars/
matches.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt;
17use std::sync::Arc;
18
19use common_query::error::{InvalidFuncArgsSnafu, Result};
20use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
21use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
22use datafusion::common::{DFSchema, Result as DfResult};
23use datafusion::execution::SessionStateBuilder;
24use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
25use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
26use datafusion_common::DataFusionError;
27use datafusion_expr::{ScalarFunctionArgs, Signature};
28use datatypes::arrow::array::RecordBatch;
29use datatypes::arrow::datatypes::{DataType, Field};
30use snafu::{OptionExt, ensure};
31
32use crate::function::{Function, extract_args};
33use crate::function_registry::FunctionRegistry;
34
/// `matches` for full text search.
///
/// Usage: matches(`<col>`, `<pattern>`) -> boolean
///
/// The pattern is tokenized into phrases, the unary prefixes `+` and `-`,
/// the keywords `AND` / `OR` and parentheses, and is finally evaluated as a
/// combination of `LIKE '%phrase%'` predicates over the column.
#[derive(Clone, Debug)]
pub struct MatchesFunction {
    /// Argument signature handed out by `signature()`.
    signature: Signature,
}
42
43impl MatchesFunction {
44    pub fn register(registry: &FunctionRegistry) {
45        registry.register_scalar(MatchesFunction::default());
46    }
47}
48
49impl Default for MatchesFunction {
50    fn default() -> Self {
51        Self {
52            signature: Signature::string(2, Volatility::Immutable),
53        }
54    }
55}
56
57impl fmt::Display for MatchesFunction {
58    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59        write!(f, "MATCHES")
60    }
61}
62
63impl Function for MatchesFunction {
64    fn name(&self) -> &str {
65        "matches"
66    }
67
68    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
69        Ok(DataType::Boolean)
70    }
71
72    fn signature(&self) -> &Signature {
73        &self.signature
74    }
75
76    // TODO: read case-sensitive config
77    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
78        let [data_column, patterns] = extract_args(self.name(), &args)?;
79
80        if data_column.is_empty() {
81            return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
82                Vec::<bool>::with_capacity(0),
83            ))));
84        }
85
86        // Safety: both length and type are checked before
87        let pattern = match patterns.data_type() {
88            DataType::Utf8View => patterns.as_string_view().value(0),
89            DataType::Utf8 => patterns.as_string::<i32>().value(0),
90            DataType::LargeUtf8 => patterns.as_string::<i64>().value(0),
91            t => {
92                return Err(DataFusionError::Execution(format!(
93                    "unsupported datatype {t}"
94                )));
95            }
96        };
97        self.eval(data_column, pattern)
98    }
99}
100
101impl MatchesFunction {
102    fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult<ColumnarValue> {
103        let col_name = "data";
104        let parser_context = ParserContext::default();
105        let raw_ast = parser_context.parse_pattern(pattern)?;
106        let ast = raw_ast.transform_ast()?;
107
108        let like_expr = ast.into_like_expr(col_name);
109
110        let input_schema = Self::input_schema();
111        let session_state = SessionStateBuilder::new().with_default_features().build();
112        let planner = DefaultPhysicalPlanner::default();
113        let physical_expr =
114            planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
115
116        let arrow_schema = Arc::new(input_schema.as_arrow().clone());
117        let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
118
119        let num_rows = input_record_batch.num_rows();
120        let result = physical_expr.evaluate(&input_record_batch)?;
121        let result_array = result.into_array(num_rows)?;
122
123        Ok(ColumnarValue::Array(Arc::new(result_array)))
124    }
125
126    fn input_schema() -> DFSchema {
127        DFSchema::from_unqualified_fields(
128            [Arc::new(Field::new("data", DataType::Utf8, true))].into(),
129            HashMap::new(),
130        )
131        .unwrap()
132    }
133}
134
/// Syntax tree of a parsed `matches` pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PatternAst {
    // Distinguish this with `Group` for simplicity
    /// A leaf node that matches a column with `pattern`
    Literal { op: UnaryOp, pattern: String },
    /// Flattened binary chains: all `children` are joined with the same `op`
    Binary {
        op: BinaryOp,
        children: Vec<PatternAst>,
    },
    /// A sub-tree enclosed by parenthesis
    Group { op: UnaryOp, child: Box<PatternAst> },
}
148
/// Unary modifier attached to a literal or a group.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum UnaryOp {
    /// Produced by the "+" prefix
    Must,
    /// Produced for a bare phrase or group that has no prefix
    Optional,
    /// Produced by the "-" prefix; the operand is negated when converted to an expression
    Negative,
}
155
/// Binary connective joining the children of a [`PatternAst::Binary`] node.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum BinaryOp {
    /// Produced by the "AND" keyword
    And,
    /// Produced by the "OR" keyword, and implicitly between adjacent terms
    Or,
}
161
162impl PatternAst {
163    fn into_like_expr(self, column: &str) -> Expr {
164        match self {
165            PatternAst::Literal { op, pattern } => {
166                let expr = Self::convert_literal(column, &pattern);
167                match op {
168                    UnaryOp::Must => expr,
169                    UnaryOp::Optional => expr,
170                    UnaryOp::Negative => logical_expr::not(expr),
171                }
172            }
173            PatternAst::Binary { op, children } => {
174                if children.is_empty() {
175                    return logical_expr::lit(true);
176                }
177                let exprs = children
178                    .into_iter()
179                    .map(|child| child.into_like_expr(column));
180                // safety: children is not empty
181                match op {
182                    BinaryOp::And => exprs.reduce(Expr::and).unwrap(),
183                    BinaryOp::Or => exprs.reduce(Expr::or).unwrap(),
184                }
185            }
186            PatternAst::Group { op, child } => {
187                let child = child.into_like_expr(column);
188                match op {
189                    UnaryOp::Must => child,
190                    UnaryOp::Optional => child,
191                    UnaryOp::Negative => logical_expr::not(child),
192                }
193            }
194        }
195    }
196
197    fn convert_literal(column: &str, pattern: &str) -> Expr {
198        logical_expr::col(column).like(logical_expr::lit(format!(
199            "%{}%",
200            crate::utils::escape_like_pattern(pattern)
201        )))
202    }
203
204    /// Transform this AST with preset rules to make it correct.
205    fn transform_ast(self) -> Result<Self> {
206        self.transform_up(Self::collapse_binary_branch_fn)
207            .map(|data| data.data)?
208            .transform_up(Self::eliminate_optional_fn)
209            .map(|data| data.data)?
210            .transform_down(Self::eliminate_single_child_fn)
211            .map(|data| data.data)
212            .map_err(Into::into)
213    }
214
215    /// Collapse binary branch with the same operator. I.e., this transformer
216    /// changes the binary-tree AST into a multiple branching AST.
217    ///
218    /// This function is expected to be called in a bottom-up manner as
219    /// it won't recursion.
220    fn collapse_binary_branch_fn(self) -> DfResult<Transformed<Self>> {
221        let PatternAst::Binary {
222            op: parent_op,
223            children,
224        } = self
225        else {
226            return Ok(Transformed::no(self));
227        };
228
229        let mut collapsed = vec![];
230        let mut remains = vec![];
231
232        for child in children {
233            match child {
234                PatternAst::Literal { .. } | PatternAst::Group { .. } => {
235                    collapsed.push(child);
236                }
237                PatternAst::Binary { op, children } => {
238                    // no need to recursion because this function is expected to be called
239                    // in a bottom-up manner
240                    if op == parent_op {
241                        collapsed.extend(children);
242                    } else {
243                        remains.push(PatternAst::Binary { op, children });
244                    }
245                }
246            }
247        }
248
249        if collapsed.is_empty() {
250            Ok(Transformed::no(PatternAst::Binary {
251                op: parent_op,
252                children: remains,
253            }))
254        } else {
255            collapsed.extend(remains);
256            Ok(Transformed::yes(PatternAst::Binary {
257                op: parent_op,
258                children: collapsed,
259            }))
260        }
261    }
262
263    /// Eliminate optional pattern. An optional pattern can always be
264    /// omitted or transformed into a must pattern follows the following rules:
265    /// - If there is only one pattern and it's optional, change it to must
266    /// - If there is any must pattern, remove all other optional patterns
267    fn eliminate_optional_fn(self) -> DfResult<Transformed<Self>> {
268        let PatternAst::Binary {
269            op: parent_op,
270            children,
271        } = self
272        else {
273            return Ok(Transformed::no(self));
274        };
275
276        if parent_op == BinaryOp::Or {
277            let mut must_list = vec![];
278            let mut must_not_list = vec![];
279            let mut optional_list = vec![];
280            let mut compound_list = vec![];
281
282            for child in children {
283                match child {
284                    PatternAst::Literal { op, .. } | PatternAst::Group { op, .. } => match op {
285                        UnaryOp::Must => must_list.push(child),
286                        UnaryOp::Optional => optional_list.push(child),
287                        UnaryOp::Negative => must_not_list.push(child),
288                    },
289                    PatternAst::Binary { .. } => {
290                        compound_list.push(child);
291                    }
292                }
293            }
294
295            // Eliminate optional list if there is MUST.
296            if !must_list.is_empty() {
297                optional_list.clear();
298            }
299
300            let children_this_level = optional_list.into_iter().chain(compound_list).collect();
301            let new_node = if !must_list.is_empty() || !must_not_list.is_empty() {
302                let new_children = must_list
303                    .into_iter()
304                    .chain(must_not_list)
305                    .chain(Some(PatternAst::Binary {
306                        op: BinaryOp::Or,
307                        children: children_this_level,
308                    }))
309                    .collect();
310                PatternAst::Binary {
311                    op: BinaryOp::And,
312                    children: new_children,
313                }
314            } else {
315                PatternAst::Binary {
316                    op: BinaryOp::Or,
317                    children: children_this_level,
318                }
319            };
320
321            return Ok(Transformed::yes(new_node));
322        }
323
324        Ok(Transformed::no(PatternAst::Binary {
325            op: parent_op,
326            children,
327        }))
328    }
329
330    /// Eliminate single child [`PatternAst::Binary`] node. If a binary node has only one child, it can be
331    /// replaced by its only child.
332    ///
333    /// This function prefers to be applied in a top-down manner. But it's not required.
334    fn eliminate_single_child_fn(self) -> DfResult<Transformed<Self>> {
335        let PatternAst::Binary { op, mut children } = self else {
336            return Ok(Transformed::no(self));
337        };
338
339        // remove empty grand children
340        children.retain(|child| match child {
341            PatternAst::Binary {
342                children: grand_children,
343                ..
344            } => !grand_children.is_empty(),
345            PatternAst::Literal { .. } | PatternAst::Group { .. } => true,
346        });
347
348        if children.len() == 1 {
349            Ok(Transformed::yes(children.into_iter().next().unwrap()))
350        } else {
351            Ok(Transformed::no(PatternAst::Binary { op, children }))
352        }
353    }
354}
355
356impl TreeNode for PatternAst {
357    fn apply_children<'n, F: FnMut(&'n Self) -> DfResult<TreeNodeRecursion>>(
358        &'n self,
359        mut f: F,
360    ) -> DfResult<TreeNodeRecursion> {
361        match self {
362            PatternAst::Literal { .. } => Ok(TreeNodeRecursion::Continue),
363            PatternAst::Binary { op: _, children } => {
364                for child in children {
365                    if TreeNodeRecursion::Stop == f(child)? {
366                        return Ok(TreeNodeRecursion::Stop);
367                    }
368                }
369                Ok(TreeNodeRecursion::Continue)
370            }
371            PatternAst::Group { op: _, child } => f(child),
372        }
373    }
374
375    fn map_children<F: FnMut(Self) -> DfResult<Transformed<Self>>>(
376        self,
377        mut f: F,
378    ) -> DfResult<Transformed<Self>> {
379        match self {
380            PatternAst::Literal { .. } => Ok(Transformed::no(self)),
381            PatternAst::Binary { op, children } => children
382                .into_iter()
383                .map_until_stop_and_collect(&mut f)?
384                .map_data(|new_children| {
385                    Ok(PatternAst::Binary {
386                        op,
387                        children: new_children,
388                    })
389                }),
390            PatternAst::Group { op, child } => f(*child)?.map_data(|new_child| {
391                Ok(PatternAst::Group {
392                    op,
393                    child: Box::new(new_child),
394                })
395            }),
396        }
397    }
398}
399
/// Parser that turns a pattern string into a [`PatternAst`].
#[derive(Default)]
struct ParserContext {
    /// Nodes parsed so far; operators pop their operands from here and push
    /// the combined node back.
    stack: Vec<PatternAst>,
}
404
405impl ParserContext {
406    pub fn parse_pattern(mut self, pattern: &str) -> Result<PatternAst> {
407        let tokenizer = Tokenizer::default();
408        let raw_tokens = tokenizer.tokenize(pattern)?;
409        let raw_tokens = Self::accomplish_optional_unary_op(raw_tokens)?;
410        let mut tokens = Self::to_rpn(raw_tokens)?;
411
412        while !tokens.is_empty() {
413            self.parse_one_impl(&mut tokens)?;
414        }
415
416        ensure!(
417            !self.stack.is_empty(),
418            InvalidFuncArgsSnafu {
419                err_msg: "Empty pattern",
420            }
421        );
422
423        // conjoin them together
424        if self.stack.len() == 1 {
425            Ok(self.stack.pop().unwrap())
426        } else {
427            Ok(PatternAst::Binary {
428                op: BinaryOp::Or,
429                children: self.stack,
430            })
431        }
432    }
433
434    /// Add [`Token::Optional`] for all bare [`Token::Phase`] and [`Token::Or`]
435    /// for all adjacent [`Token::Phase`]s.
436    ///
437    /// This function also does some checks by the way. Like if two unary ops are
438    /// adjacent.
439    fn accomplish_optional_unary_op(raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
440        let mut is_prev_unary_op = false;
441        // The first one doesn't need binary op
442        let mut is_binary_op_before = true;
443        let mut is_unary_op_before = false;
444        let mut new_tokens = Vec::with_capacity(raw_tokens.len());
445        for token in raw_tokens {
446            // fill `Token::Or`
447            if !is_binary_op_before
448                && matches!(
449                    token,
450                    Token::Phase(_)
451                        | Token::OpenParen
452                        | Token::Must
453                        | Token::Optional
454                        | Token::Negative
455                )
456            {
457                is_binary_op_before = true;
458                new_tokens.push(Token::Or);
459            }
460            if matches!(
461                token,
462                Token::OpenParen // treat open paren as begin of new group
463                | Token::And | Token::Or
464            ) {
465                is_binary_op_before = true;
466            } else if matches!(token, Token::Phase(_) | Token::CloseParen) {
467                // need binary op next time
468                is_binary_op_before = false;
469            }
470
471            // fill `Token::Optional`
472            if !is_prev_unary_op && matches!(token, Token::Phase(_) | Token::OpenParen) {
473                new_tokens.push(Token::Optional);
474            } else {
475                is_prev_unary_op = matches!(token, Token::Must | Token::Negative);
476            }
477
478            // check if unary ops are adjacent by the way
479            if matches!(token, Token::Must | Token::Optional | Token::Negative) {
480                if is_unary_op_before {
481                    return InvalidFuncArgsSnafu {
482                        err_msg: "Invalid pattern, unary operators should not be adjacent",
483                    }
484                    .fail();
485                }
486                is_unary_op_before = true;
487            } else {
488                is_unary_op_before = false;
489            }
490
491            new_tokens.push(token);
492        }
493
494        Ok(new_tokens)
495    }
496
497    /// Convert infix token stream to RPN
498    fn to_rpn(mut raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
499        let mut operator_stack = vec![];
500        let mut result = vec![];
501        raw_tokens.reverse();
502
503        while let Some(token) = raw_tokens.pop() {
504            match token {
505                Token::Phase(_) => result.push(token),
506                Token::Must | Token::Negative | Token::Optional => {
507                    operator_stack.push(token);
508                }
509                Token::OpenParen => operator_stack.push(token),
510                Token::And | Token::Or => {
511                    // - Or has lower priority than And
512                    // - Binary op have lower priority than unary op
513                    while let Some(stack_top) = operator_stack.last()
514                        && ((*stack_top == Token::And && token == Token::Or)
515                            || matches!(
516                                *stack_top,
517                                Token::Must | Token::Optional | Token::Negative
518                            ))
519                    {
520                        result.push(operator_stack.pop().unwrap());
521                    }
522                    operator_stack.push(token);
523                }
524                Token::CloseParen => {
525                    let mut is_open_paren_found = false;
526                    while let Some(op) = operator_stack.pop() {
527                        if op == Token::OpenParen {
528                            is_open_paren_found = true;
529                            break;
530                        }
531                        result.push(op);
532                    }
533                    if !is_open_paren_found {
534                        return InvalidFuncArgsSnafu {
535                            err_msg: "Unmatched close parentheses",
536                        }
537                        .fail();
538                    }
539                }
540            }
541        }
542
543        while let Some(operator) = operator_stack.pop() {
544            if operator == Token::OpenParen {
545                return InvalidFuncArgsSnafu {
546                    err_msg: "Unmatched parentheses",
547                }
548                .fail();
549            }
550            result.push(operator);
551        }
552
553        Ok(result)
554    }
555
556    fn parse_one_impl(&mut self, tokens: &mut Vec<Token>) -> Result<()> {
557        if let Some(token) = tokens.pop() {
558            match token {
559                Token::Must => {
560                    if self.stack.is_empty() {
561                        self.parse_one_impl(tokens)?;
562                    }
563                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
564                        err_msg: "Invalid pattern, \"+\" operator should have one operand",
565                    })?;
566                    match phase_or_group {
567                        PatternAst::Literal { op: _, pattern } => {
568                            self.stack.push(PatternAst::Literal {
569                                op: UnaryOp::Must,
570                                pattern,
571                            });
572                        }
573                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
574                            self.stack.push(PatternAst::Group {
575                                op: UnaryOp::Must,
576                                child: Box::new(phase_or_group),
577                            })
578                        }
579                    }
580                    return Ok(());
581                }
582                Token::Negative => {
583                    if self.stack.is_empty() {
584                        self.parse_one_impl(tokens)?;
585                    }
586                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
587                        err_msg: "Invalid pattern, \"-\" operator should have one operand",
588                    })?;
589                    match phase_or_group {
590                        PatternAst::Literal { op: _, pattern } => {
591                            self.stack.push(PatternAst::Literal {
592                                op: UnaryOp::Negative,
593                                pattern,
594                            });
595                        }
596                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
597                            self.stack.push(PatternAst::Group {
598                                op: UnaryOp::Negative,
599                                child: Box::new(phase_or_group),
600                            })
601                        }
602                    }
603                    return Ok(());
604                }
605                Token::Optional => {
606                    if self.stack.is_empty() {
607                        self.parse_one_impl(tokens)?;
608                    }
609                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
610                        err_msg:
611                            "Invalid pattern, OPTIONAL(space) operator should have one operand",
612                    })?;
613                    match phase_or_group {
614                        PatternAst::Literal { op: _, pattern } => {
615                            self.stack.push(PatternAst::Literal {
616                                op: UnaryOp::Optional,
617                                pattern,
618                            });
619                        }
620                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
621                            self.stack.push(PatternAst::Group {
622                                op: UnaryOp::Optional,
623                                child: Box::new(phase_or_group),
624                            })
625                        }
626                    }
627                    return Ok(());
628                }
629                Token::Phase(pattern) => {
630                    self.stack.push(PatternAst::Literal {
631                        // Op here is a placeholder
632                        op: UnaryOp::Optional,
633                        pattern,
634                    })
635                }
636                Token::And => {
637                    if self.stack.is_empty() {
638                        self.parse_one_impl(tokens)?;
639                    };
640                    let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
641                        err_msg: "Invalid pattern, \"AND\" operator should have two operands",
642                    })?;
643                    if self.stack.is_empty() {
644                        self.parse_one_impl(tokens)?
645                    };
646                    let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
647                        err_msg: "Invalid pattern, \"AND\" operator should have two operands",
648                    })?;
649                    self.stack.push(PatternAst::Binary {
650                        op: BinaryOp::And,
651                        children: vec![lhs, rhs],
652                    });
653                    return Ok(());
654                }
655                Token::Or => {
656                    if self.stack.is_empty() {
657                        self.parse_one_impl(tokens)?
658                    };
659                    let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
660                        err_msg: "Invalid pattern, \"OR\" operator should have two operands",
661                    })?;
662                    if self.stack.is_empty() {
663                        self.parse_one_impl(tokens)?
664                    };
665                    let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
666                        err_msg: "Invalid pattern, \"OR\" operator should have two operands",
667                    })?;
668                    self.stack.push(PatternAst::Binary {
669                        op: BinaryOp::Or,
670                        children: vec![lhs, rhs],
671                    });
672                    return Ok(());
673                }
674                Token::OpenParen | Token::CloseParen => {
675                    return InvalidFuncArgsSnafu {
676                        err_msg: "Unexpected parentheses",
677                    }
678                    .fail();
679                }
680            }
681        }
682
683        Ok(())
684    }
685}
686
/// Lexical token of a `matches` pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Token {
    /// "+"
    Must,
    /// "-"
    Negative,
    /// "AND"
    And,
    /// "OR"
    Or,
    /// "("
    OpenParen,
    /// ")"
    CloseParen,
    /// Any other phase
    Phase(String),

    /// This is not a token from user input, but a placeholder for internal use.
    /// It completes the unary operator class alongside Must and Negative.
    /// In a user provided pattern, optional is expressed by a bare phase or group
    /// (simply nothing or whitespace).
    Optional,
}
710
/// Splits a pattern string into [`Token`]s.
#[derive(Default)]
struct Tokenizer {
    /// Position of the character currently being examined, counted in
    /// characters rather than bytes.
    cursor: usize,
}
715
716impl Tokenizer {
717    pub fn tokenize(mut self, pattern: &str) -> Result<Vec<Token>> {
718        let mut tokens = vec![];
719        let char_len = pattern.chars().count();
720        while self.cursor < char_len {
721            // TODO: collect pattern into Vec<char> if this tokenizer is bottleneck in the future
722            let c = pattern.chars().nth(self.cursor).unwrap();
723            match c {
724                '+' => tokens.push(Token::Must),
725                '-' => tokens.push(Token::Negative),
726                '(' => tokens.push(Token::OpenParen),
727                ')' => tokens.push(Token::CloseParen),
728                ' ' => {
729                    if let Some(last_token) = tokens.last() {
730                        match last_token {
731                            Token::Must | Token::Negative => {
732                                return InvalidFuncArgsSnafu {
733                                    err_msg: format!("Unexpected space after {:?}", last_token),
734                                }
735                                .fail();
736                            }
737                            _ => {}
738                        }
739                    }
740                }
741                '\"' => {
742                    self.step_next();
743                    let phase = self.consume_next_phase(true, pattern)?;
744                    tokens.push(Token::Phase(phase));
745                    // consume a writespace (or EOF) after quotes
746                    if let Some(ending_separator) = self.consume_next(pattern)
747                        && ending_separator != ' '
748                    {
749                        return InvalidFuncArgsSnafu {
750                            err_msg: "Expect a space after quotes ('\"')",
751                        }
752                        .fail();
753                    }
754                }
755                _ => {
756                    let phase = self.consume_next_phase(false, pattern)?;
757                    match phase.to_uppercase().as_str() {
758                        "AND" => tokens.push(Token::And),
759                        "OR" => tokens.push(Token::Or),
760                        _ => tokens.push(Token::Phase(phase)),
761                    }
762                }
763            }
764            self.cursor += 1;
765        }
766        Ok(tokens)
767    }
768
769    fn consume_next(&mut self, pattern: &str) -> Option<char> {
770        self.cursor += 1;
771        pattern.chars().nth(self.cursor)
772    }
773
774    fn step_next(&mut self) {
775        self.cursor += 1;
776    }
777
778    fn rewind_one(&mut self) {
779        self.cursor -= 1;
780    }
781
782    /// Current `cursor` points to the first character of the phase.
783    /// If the phase is enclosed by double quotes, consume the start quote before calling this.
784    fn consume_next_phase(&mut self, is_quoted: bool, pattern: &str) -> Result<String> {
785        let mut phase = String::new();
786        let mut is_quote_present = false;
787
788        let char_len = pattern.chars().count();
789        while self.cursor < char_len {
790            let mut c = pattern.chars().nth(self.cursor).unwrap();
791
792            match c {
793                '\"' => {
794                    is_quote_present = true;
795                    break;
796                }
797                ' ' if !is_quoted => {
798                    break;
799                }
800                '(' | ')' | '+' | '-' if !is_quoted => {
801                    self.rewind_one();
802                    break;
803                }
804                '\\' => {
805                    let Some(next) = self.consume_next(pattern) else {
806                        return InvalidFuncArgsSnafu {
807                            err_msg: "Unexpected end of pattern, expected a character after escape ('\\')",
808                        }.fail();
809                    };
810                    // it doesn't check whether the escaped character is valid or not
811                    c = next;
812                }
813                _ => {}
814            }
815
816            phase.push(c);
817            self.cursor += 1;
818        }
819
820        if is_quoted ^ is_quote_present {
821            return InvalidFuncArgsSnafu {
822                err_msg: "Unclosed quotes ('\"')",
823            }
824            .fail();
825        }
826
827        Ok(phase)
828    }
829}
830
831#[cfg(test)]
832mod test {
833    use datafusion::arrow::array::StringArray;
834    use datafusion_common::ScalarValue;
835    use datafusion_common::config::ConfigOptions;
836
837    use super::*;
838
839    #[test]
840    fn valid_matches_tokenizer() {
841        use Token::*;
842        let cases = [
843            (
844                "a +b -c",
845                vec![
846                    Phase("a".to_string()),
847                    Must,
848                    Phase("b".to_string()),
849                    Negative,
850                    Phase("c".to_string()),
851                ],
852            ),
853            (
854                "+a(b-c)",
855                vec![
856                    Must,
857                    Phase("a".to_string()),
858                    OpenParen,
859                    Phase("b".to_string()),
860                    Negative,
861                    Phase("c".to_string()),
862                    CloseParen,
863                ],
864            ),
865            (
866                r#"Barack Obama"#,
867                vec![Phase("Barack".to_string()), Phase("Obama".to_string())],
868            ),
869            (
870                r#"+apple +fruit"#,
871                vec![
872                    Must,
873                    Phase("apple".to_string()),
874                    Must,
875                    Phase("fruit".to_string()),
876                ],
877            ),
878            (
879                r#""He said \"hello\"""#,
880                vec![Phase("He said \"hello\"".to_string())],
881            ),
882            (
883                r#"a AND b OR c"#,
884                vec![
885                    Phase("a".to_string()),
886                    And,
887                    Phase("b".to_string()),
888                    Or,
889                    Phase("c".to_string()),
890                ],
891            ),
892            (
893                r#"中文 测试"#,
894                vec![Phase("中文".to_string()), Phase("测试".to_string())],
895            ),
896            (
897                r#"中文 AND 测试"#,
898                vec![Phase("中文".to_string()), And, Phase("测试".to_string())],
899            ),
900            (
901                r#"中文 +测试"#,
902                vec![Phase("中文".to_string()), Must, Phase("测试".to_string())],
903            ),
904            (
905                r#"中文 -测试"#,
906                vec![
907                    Phase("中文".to_string()),
908                    Negative,
909                    Phase("测试".to_string()),
910                ],
911            ),
912        ];
913
914        for (query, expected) in cases {
915            let tokenizer = Tokenizer::default();
916            let tokens = tokenizer.tokenize(query).unwrap();
917            assert_eq!(expected, tokens, "{query}");
918        }
919    }
920
921    #[test]
922    fn invalid_matches_tokenizer() {
923        let cases = [
924            (r#""He said "hello""#, "Expect a space after quotes"),
925            (r#""He said hello"#, "Unclosed quotes"),
926            (r#"a + b - c"#, "Unexpected space after"),
927            (r#"ab "c"def"#, "Expect a space after quotes"),
928        ];
929
930        for (query, expected) in cases {
931            let tokenizer = Tokenizer::default();
932            let result = tokenizer.tokenize(query);
933            assert!(result.is_err(), "{query}");
934            let actual_error = result.unwrap_err().to_string();
935            assert!(actual_error.contains(expected), "{query}, {actual_error}");
936        }
937    }
938
939    #[test]
940    fn valid_ast_transformer() {
941        let cases = [
942            (
943                "a AND b OR c",
944                PatternAst::Binary {
945                    op: BinaryOp::Or,
946                    children: vec![
947                        PatternAst::Literal {
948                            op: UnaryOp::Optional,
949                            pattern: "c".to_string(),
950                        },
951                        PatternAst::Binary {
952                            op: BinaryOp::And,
953                            children: vec![
954                                PatternAst::Literal {
955                                    op: UnaryOp::Optional,
956                                    pattern: "a".to_string(),
957                                },
958                                PatternAst::Literal {
959                                    op: UnaryOp::Optional,
960                                    pattern: "b".to_string(),
961                                },
962                            ],
963                        },
964                    ],
965                },
966            ),
967            (
968                "a -b",
969                PatternAst::Binary {
970                    op: BinaryOp::And,
971                    children: vec![
972                        PatternAst::Literal {
973                            op: UnaryOp::Negative,
974                            pattern: "b".to_string(),
975                        },
976                        PatternAst::Literal {
977                            op: UnaryOp::Optional,
978                            pattern: "a".to_string(),
979                        },
980                    ],
981                },
982            ),
983            (
984                "a +b",
985                PatternAst::Literal {
986                    op: UnaryOp::Must,
987                    pattern: "b".to_string(),
988                },
989            ),
990            (
991                "a b c d",
992                PatternAst::Binary {
993                    op: BinaryOp::Or,
994                    children: vec![
995                        PatternAst::Literal {
996                            op: UnaryOp::Optional,
997                            pattern: "a".to_string(),
998                        },
999                        PatternAst::Literal {
1000                            op: UnaryOp::Optional,
1001                            pattern: "b".to_string(),
1002                        },
1003                        PatternAst::Literal {
1004                            op: UnaryOp::Optional,
1005                            pattern: "c".to_string(),
1006                        },
1007                        PatternAst::Literal {
1008                            op: UnaryOp::Optional,
1009                            pattern: "d".to_string(),
1010                        },
1011                    ],
1012                },
1013            ),
1014            (
1015                "a b c AND d",
1016                PatternAst::Binary {
1017                    op: BinaryOp::Or,
1018                    children: vec![
1019                        PatternAst::Literal {
1020                            op: UnaryOp::Optional,
1021                            pattern: "a".to_string(),
1022                        },
1023                        PatternAst::Literal {
1024                            op: UnaryOp::Optional,
1025                            pattern: "b".to_string(),
1026                        },
1027                        PatternAst::Binary {
1028                            op: BinaryOp::And,
1029                            children: vec![
1030                                PatternAst::Literal {
1031                                    op: UnaryOp::Optional,
1032                                    pattern: "c".to_string(),
1033                                },
1034                                PatternAst::Literal {
1035                                    op: UnaryOp::Optional,
1036                                    pattern: "d".to_string(),
1037                                },
1038                            ],
1039                        },
1040                    ],
1041                },
1042            ),
1043            (
1044                r#"中文 测试"#,
1045                PatternAst::Binary {
1046                    op: BinaryOp::Or,
1047                    children: vec![
1048                        PatternAst::Literal {
1049                            op: UnaryOp::Optional,
1050                            pattern: "中文".to_string(),
1051                        },
1052                        PatternAst::Literal {
1053                            op: UnaryOp::Optional,
1054                            pattern: "测试".to_string(),
1055                        },
1056                    ],
1057                },
1058            ),
1059            (
1060                r#"中文 AND 测试"#,
1061                PatternAst::Binary {
1062                    op: BinaryOp::And,
1063                    children: vec![
1064                        PatternAst::Literal {
1065                            op: UnaryOp::Optional,
1066                            pattern: "中文".to_string(),
1067                        },
1068                        PatternAst::Literal {
1069                            op: UnaryOp::Optional,
1070                            pattern: "测试".to_string(),
1071                        },
1072                    ],
1073                },
1074            ),
1075            (
1076                r#"中文 +测试"#,
1077                PatternAst::Literal {
1078                    op: UnaryOp::Must,
1079                    pattern: "测试".to_string(),
1080                },
1081            ),
1082            (
1083                r#"中文 -测试"#,
1084                PatternAst::Binary {
1085                    op: BinaryOp::And,
1086                    children: vec![
1087                        PatternAst::Literal {
1088                            op: UnaryOp::Negative,
1089                            pattern: "测试".to_string(),
1090                        },
1091                        PatternAst::Literal {
1092                            op: UnaryOp::Optional,
1093                            pattern: "中文".to_string(),
1094                        },
1095                    ],
1096                },
1097            ),
1098        ];
1099
1100        for (query, expected) in cases {
1101            let parser = ParserContext { stack: vec![] };
1102            let ast = parser.parse_pattern(query).unwrap();
1103            let ast = ast.transform_ast().unwrap();
1104            assert_eq!(expected, ast, "{query}");
1105        }
1106    }
1107
1108    #[test]
1109    fn invalid_ast() {
1110        let cases = [
1111            (r#"a b (c"#, "Unmatched parentheses"),
1112            (r#"a b) c"#, "Unmatched close parentheses"),
1113            (r#"a +-b"#, "unary operators should not be adjacent"),
1114        ];
1115
1116        for (query, expected) in cases {
1117            let result: Result<()> = try {
1118                let parser = ParserContext { stack: vec![] };
1119                let ast = parser.parse_pattern(query)?;
1120                let _ast = ast.transform_ast()?;
1121            };
1122
1123            assert!(result.is_err(), "{query}");
1124            let actual_error = result.unwrap_err().to_string();
1125            assert!(actual_error.contains(expected), "{query}, {actual_error}");
1126        }
1127    }
1128
1129    #[test]
1130    fn valid_matches_parser() {
1131        let cases = [
1132            (
1133                "a AND b OR c",
1134                PatternAst::Binary {
1135                    op: BinaryOp::Or,
1136                    children: vec![
1137                        PatternAst::Binary {
1138                            op: BinaryOp::And,
1139                            children: vec![
1140                                PatternAst::Literal {
1141                                    op: UnaryOp::Optional,
1142                                    pattern: "a".to_string(),
1143                                },
1144                                PatternAst::Literal {
1145                                    op: UnaryOp::Optional,
1146                                    pattern: "b".to_string(),
1147                                },
1148                            ],
1149                        },
1150                        PatternAst::Literal {
1151                            op: UnaryOp::Optional,
1152                            pattern: "c".to_string(),
1153                        },
1154                    ],
1155                },
1156            ),
1157            (
1158                "(a AND b) OR c",
1159                PatternAst::Binary {
1160                    op: BinaryOp::Or,
1161                    children: vec![
1162                        PatternAst::Group {
1163                            op: UnaryOp::Optional,
1164                            child: Box::new(PatternAst::Binary {
1165                                op: BinaryOp::And,
1166                                children: vec![
1167                                    PatternAst::Literal {
1168                                        op: UnaryOp::Optional,
1169                                        pattern: "a".to_string(),
1170                                    },
1171                                    PatternAst::Literal {
1172                                        op: UnaryOp::Optional,
1173                                        pattern: "b".to_string(),
1174                                    },
1175                                ],
1176                            }),
1177                        },
1178                        PatternAst::Literal {
1179                            op: UnaryOp::Optional,
1180                            pattern: "c".to_string(),
1181                        },
1182                    ],
1183                },
1184            ),
1185            (
1186                "a AND (b OR c)",
1187                PatternAst::Binary {
1188                    op: BinaryOp::And,
1189                    children: vec![
1190                        PatternAst::Literal {
1191                            op: UnaryOp::Optional,
1192                            pattern: "a".to_string(),
1193                        },
1194                        PatternAst::Group {
1195                            op: UnaryOp::Optional,
1196                            child: Box::new(PatternAst::Binary {
1197                                op: BinaryOp::Or,
1198                                children: vec![
1199                                    PatternAst::Literal {
1200                                        op: UnaryOp::Optional,
1201                                        pattern: "b".to_string(),
1202                                    },
1203                                    PatternAst::Literal {
1204                                        op: UnaryOp::Optional,
1205                                        pattern: "c".to_string(),
1206                                    },
1207                                ],
1208                            }),
1209                        },
1210                    ],
1211                },
1212            ),
1213            (
1214                "a +b -c",
1215                PatternAst::Binary {
1216                    op: BinaryOp::Or,
1217                    children: vec![
1218                        PatternAst::Literal {
1219                            op: UnaryOp::Optional,
1220                            pattern: "a".to_string(),
1221                        },
1222                        PatternAst::Binary {
1223                            op: BinaryOp::Or,
1224                            children: vec![
1225                                PatternAst::Literal {
1226                                    op: UnaryOp::Must,
1227                                    pattern: "b".to_string(),
1228                                },
1229                                PatternAst::Literal {
1230                                    op: UnaryOp::Negative,
1231                                    pattern: "c".to_string(),
1232                                },
1233                            ],
1234                        },
1235                    ],
1236                },
1237            ),
1238            (
1239                "(+a +b) c",
1240                PatternAst::Binary {
1241                    op: BinaryOp::Or,
1242                    children: vec![
1243                        PatternAst::Group {
1244                            op: UnaryOp::Optional,
1245                            child: Box::new(PatternAst::Binary {
1246                                op: BinaryOp::Or,
1247                                children: vec![
1248                                    PatternAst::Literal {
1249                                        op: UnaryOp::Must,
1250                                        pattern: "a".to_string(),
1251                                    },
1252                                    PatternAst::Literal {
1253                                        op: UnaryOp::Must,
1254                                        pattern: "b".to_string(),
1255                                    },
1256                                ],
1257                            }),
1258                        },
1259                        PatternAst::Literal {
1260                            op: UnaryOp::Optional,
1261                            pattern: "c".to_string(),
1262                        },
1263                    ],
1264                },
1265            ),
1266            (
1267                "\"AND\" AnD \"OR\"",
1268                PatternAst::Binary {
1269                    op: BinaryOp::And,
1270                    children: vec![
1271                        PatternAst::Literal {
1272                            op: UnaryOp::Optional,
1273                            pattern: "AND".to_string(),
1274                        },
1275                        PatternAst::Literal {
1276                            op: UnaryOp::Optional,
1277                            pattern: "OR".to_string(),
1278                        },
1279                    ],
1280                },
1281            ),
1282        ];
1283
1284        for (query, expected) in cases {
1285            let parser = ParserContext { stack: vec![] };
1286            let ast = parser.parse_pattern(query).unwrap();
1287            assert_eq!(expected, ast, "{query}");
1288        }
1289    }
1290
1291    #[test]
1292    fn evaluate_matches() {
1293        let input_data = vec![
1294            "The quick brown fox jumps over the lazy dog",
1295            "The             fox jumps over the lazy dog",
1296            "The quick brown     jumps over the lazy dog",
1297            "The quick brown fox       over the lazy dog",
1298            "The quick brown fox jumps      the lazy dog",
1299            "The quick brown fox jumps over          dog",
1300            "The quick brown fox jumps over the      dog",
1301        ];
1302        let col: ArrayRef = Arc::new(StringArray::from(input_data));
1303        let cases = [
1304            // basic cases
1305            ("quick", vec![true, false, true, true, true, true, true]),
1306            (
1307                "\"quick brown\"",
1308                vec![true, false, true, true, true, true, true],
1309            ),
1310            (
1311                "\"fox jumps\"",
1312                vec![true, true, false, false, true, true, true],
1313            ),
1314            (
1315                "fox OR lazy",
1316                vec![true, true, true, true, true, true, true],
1317            ),
1318            (
1319                "fox AND lazy",
1320                vec![true, true, false, true, true, false, false],
1321            ),
1322            (
1323                "-over -lazy",
1324                vec![false, false, false, false, false, false, false],
1325            ),
1326            (
1327                "-over AND -lazy",
1328                vec![false, false, false, false, false, false, false],
1329            ),
1330            // priority between AND & OR
1331            (
1332                "fox AND jumps OR over",
1333                vec![true, true, true, true, true, true, true],
1334            ),
1335            (
1336                "fox OR brown AND quick",
1337                vec![true, true, true, true, true, true, true],
1338            ),
1339            (
1340                "(fox OR brown) AND quick",
1341                vec![true, false, true, true, true, true, true],
1342            ),
1343            (
1344                "brown AND quick OR fox",
1345                vec![true, true, true, true, true, true, true],
1346            ),
1347            (
1348                "brown AND (quick OR fox)",
1349                vec![true, false, true, true, true, true, true],
1350            ),
1351            (
1352                "brown AND quick AND fox  OR  jumps AND over AND lazy",
1353                vec![true, true, true, true, true, true, true],
1354            ),
1355            // optional & must conversion
1356            (
1357                "quick brown fox +jumps",
1358                vec![true, true, true, false, true, true, true],
1359            ),
1360            (
1361                "fox +jumps -over",
1362                vec![false, false, false, false, true, false, false],
1363            ),
1364            (
1365                "fox AND +jumps AND -over",
1366                vec![false, false, false, false, true, false, false],
1367            ),
1368            // weird parentheses cases
1369            (
1370                "(+fox +jumps) over",
1371                vec![true, true, true, true, true, true, true],
1372            ),
1373            (
1374                "+(fox jumps) AND over",
1375                vec![true, true, true, true, false, true, true],
1376            ),
1377            (
1378                "over -(fox jumps)",
1379                vec![false, false, false, false, false, false, false],
1380            ),
1381            (
1382                "over -(fox AND jumps)",
1383                vec![false, false, true, true, false, false, false],
1384            ),
1385            (
1386                "over AND -(-(fox OR jumps))",
1387                vec![true, true, true, true, false, true, true],
1388            ),
1389        ];
1390
1391        let f = MatchesFunction::default();
1392        for (pattern, expected) in cases {
1393            let args = ScalarFunctionArgs {
1394                args: vec![
1395                    ColumnarValue::Array(col.clone()),
1396                    ColumnarValue::Scalar(ScalarValue::Utf8View(Some(pattern.to_string()))),
1397                ],
1398                arg_fields: vec![],
1399                number_rows: col.len(),
1400                return_field: Arc::new(Field::new("x", col.data_type().clone(), true)),
1401                config_options: Arc::new(ConfigOptions::new()),
1402            };
1403            let actual = f
1404                .invoke_with_args(args)
1405                .and_then(|x| x.to_array(col.len()))
1406                .unwrap();
1407            let expected: ArrayRef = Arc::new(BooleanArray::from(expected));
1408            assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}");
1409        }
1410    }
1411}