common_function/scalars/
matches.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt;
17use std::sync::Arc;
18
19use common_query::error::{InvalidFuncArgsSnafu, Result};
20use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
21use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
22use datafusion::common::{DFSchema, Result as DfResult};
23use datafusion::execution::SessionStateBuilder;
24use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
25use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
26use datafusion_common::DataFusionError;
27use datafusion_expr::{ScalarFunctionArgs, Signature};
28use datatypes::arrow::array::RecordBatch;
29use datatypes::arrow::datatypes::{DataType, Field};
30use snafu::{OptionExt, ensure};
31
32use crate::function::{Function, extract_args};
33use crate::function_registry::FunctionRegistry;
34
35/// `matches` for full text search.
36///
37/// Usage: matches(`<col>`, `<pattern>`) -> boolean
38#[derive(Clone, Debug)]
39pub struct MatchesFunction {
40    signature: Signature,
41}
42
43impl MatchesFunction {
44    pub fn register(registry: &FunctionRegistry) {
45        registry.register_scalar(MatchesFunction::default());
46    }
47}
48
49impl Default for MatchesFunction {
50    fn default() -> Self {
51        Self {
52            signature: Signature::string(2, Volatility::Immutable),
53        }
54    }
55}
56
57impl fmt::Display for MatchesFunction {
58    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59        write!(f, "MATCHES")
60    }
61}
62
63impl Function for MatchesFunction {
64    fn name(&self) -> &str {
65        "matches"
66    }
67
68    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
69        Ok(DataType::Boolean)
70    }
71
72    fn signature(&self) -> &Signature {
73        &self.signature
74    }
75
76    // TODO: read case-sensitive config
77    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
78        let [data_column, patterns] = extract_args(self.name(), &args)?;
79
80        if data_column.is_empty() {
81            return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
82                Vec::<bool>::with_capacity(0),
83            ))));
84        }
85
86        // Safety: both length and type are checked before
87        let pattern = match patterns.data_type() {
88            DataType::Utf8View => patterns.as_string_view().value(0),
89            DataType::Utf8 => patterns.as_string::<i32>().value(0),
90            DataType::LargeUtf8 => patterns.as_string::<i64>().value(0),
91            t => {
92                return Err(DataFusionError::Execution(format!(
93                    "unsupported datatype {t}"
94                )));
95            }
96        };
97        self.eval(data_column, pattern)
98    }
99}
100
101impl MatchesFunction {
102    fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult<ColumnarValue> {
103        let col_name = "data";
104        let parser_context = ParserContext::default();
105        let raw_ast = parser_context.parse_pattern(pattern)?;
106        let ast = raw_ast.transform_ast()?;
107
108        let like_expr = ast.into_like_expr(col_name);
109
110        let input_schema = Self::input_schema();
111        let session_state = SessionStateBuilder::new().with_default_features().build();
112        let planner = DefaultPhysicalPlanner::default();
113        let physical_expr =
114            planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
115
116        let arrow_schema = Arc::new(input_schema.as_arrow().clone());
117        let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
118
119        let num_rows = input_record_batch.num_rows();
120        let result = physical_expr.evaluate(&input_record_batch)?;
121        let result_array = result.into_array(num_rows)?;
122
123        Ok(ColumnarValue::Array(Arc::new(result_array)))
124    }
125
126    fn input_schema() -> DFSchema {
127        DFSchema::from_unqualified_fields(
128            [Arc::new(Field::new("data", DataType::Utf8, true))].into(),
129            HashMap::new(),
130        )
131        .unwrap()
132    }
133}
134
135#[derive(Debug, Clone, PartialEq, Eq)]
136enum PatternAst {
137    // Distinguish this with `Group` for simplicity
138    /// A leaf node that matches a column with `pattern`
139    Literal { op: UnaryOp, pattern: String },
140    /// Flattened binary chains
141    Binary {
142        op: BinaryOp,
143        children: Vec<PatternAst>,
144    },
145    /// A sub-tree enclosed by parenthesis
146    Group { op: UnaryOp, child: Box<PatternAst> },
147}
148
/// Prefix operator applied to a literal or a group.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum UnaryOp {
    /// `+`: the operand is required.
    Must,
    /// No explicit prefix: the operand is optional.
    Optional,
    /// `-`: the operand must not match.
    Negative,
}
155
/// Infix operator joining sub-patterns.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum BinaryOp {
    /// `AND`: every child has to match.
    And,
    /// `OR`: any child may match. It is also implied between adjacent terms.
    Or,
}
161
162impl PatternAst {
163    fn into_like_expr(self, column: &str) -> Expr {
164        match self {
165            PatternAst::Literal { op, pattern } => {
166                let expr = Self::convert_literal(column, &pattern);
167                match op {
168                    UnaryOp::Must => expr,
169                    UnaryOp::Optional => expr,
170                    UnaryOp::Negative => logical_expr::not(expr),
171                }
172            }
173            PatternAst::Binary { op, children } => {
174                if children.is_empty() {
175                    return logical_expr::lit(true);
176                }
177                let exprs = children
178                    .into_iter()
179                    .map(|child| child.into_like_expr(column));
180                // safety: children is not empty
181                match op {
182                    BinaryOp::And => exprs.reduce(Expr::and).unwrap(),
183                    BinaryOp::Or => exprs.reduce(Expr::or).unwrap(),
184                }
185            }
186            PatternAst::Group { op, child } => {
187                let child = child.into_like_expr(column);
188                match op {
189                    UnaryOp::Must => child,
190                    UnaryOp::Optional => child,
191                    UnaryOp::Negative => logical_expr::not(child),
192                }
193            }
194        }
195    }
196
197    fn convert_literal(column: &str, pattern: &str) -> Expr {
198        logical_expr::col(column).like(logical_expr::lit(format!(
199            "%{}%",
200            crate::utils::escape_like_pattern(pattern)
201        )))
202    }
203
204    /// Transform this AST with preset rules to make it correct.
205    fn transform_ast(self) -> Result<Self> {
206        self.transform_up(Self::collapse_binary_branch_fn)
207            .map(|data| data.data)?
208            .transform_up(Self::eliminate_optional_fn)
209            .map(|data| data.data)?
210            .transform_down(Self::eliminate_single_child_fn)
211            .map(|data| data.data)
212            .map_err(Into::into)
213    }
214
215    /// Collapse binary branch with the same operator. I.e., this transformer
216    /// changes the binary-tree AST into a multiple branching AST.
217    ///
218    /// This function is expected to be called in a bottom-up manner as
219    /// it won't recursion.
220    fn collapse_binary_branch_fn(self) -> DfResult<Transformed<Self>> {
221        let PatternAst::Binary {
222            op: parent_op,
223            children,
224        } = self
225        else {
226            return Ok(Transformed::no(self));
227        };
228
229        let mut collapsed = vec![];
230        let mut remains = vec![];
231
232        for child in children {
233            match child {
234                PatternAst::Literal { .. } | PatternAst::Group { .. } => {
235                    collapsed.push(child);
236                }
237                PatternAst::Binary { op, children } => {
238                    // no need to recursion because this function is expected to be called
239                    // in a bottom-up manner
240                    if op == parent_op {
241                        collapsed.extend(children);
242                    } else {
243                        remains.push(PatternAst::Binary { op, children });
244                    }
245                }
246            }
247        }
248
249        if collapsed.is_empty() {
250            Ok(Transformed::no(PatternAst::Binary {
251                op: parent_op,
252                children: remains,
253            }))
254        } else {
255            collapsed.extend(remains);
256            Ok(Transformed::yes(PatternAst::Binary {
257                op: parent_op,
258                children: collapsed,
259            }))
260        }
261    }
262
263    /// Eliminate optional pattern. An optional pattern can always be
264    /// omitted or transformed into a must pattern follows the following rules:
265    /// - If there is only one pattern and it's optional, change it to must
266    /// - If there is any must pattern, remove all other optional patterns
267    fn eliminate_optional_fn(self) -> DfResult<Transformed<Self>> {
268        let PatternAst::Binary {
269            op: parent_op,
270            children,
271        } = self
272        else {
273            return Ok(Transformed::no(self));
274        };
275
276        if parent_op == BinaryOp::Or {
277            let mut must_list = vec![];
278            let mut must_not_list = vec![];
279            let mut optional_list = vec![];
280            let mut compound_list = vec![];
281
282            for child in children {
283                match child {
284                    PatternAst::Literal { op, .. } | PatternAst::Group { op, .. } => match op {
285                        UnaryOp::Must => must_list.push(child),
286                        UnaryOp::Optional => optional_list.push(child),
287                        UnaryOp::Negative => must_not_list.push(child),
288                    },
289                    PatternAst::Binary { .. } => {
290                        compound_list.push(child);
291                    }
292                }
293            }
294
295            // Eliminate optional list if there is MUST.
296            if !must_list.is_empty() {
297                optional_list.clear();
298            }
299
300            let children_this_level = optional_list.into_iter().chain(compound_list).collect();
301            let new_node = if !must_list.is_empty() || !must_not_list.is_empty() {
302                let new_children = must_list
303                    .into_iter()
304                    .chain(must_not_list)
305                    .chain(Some(PatternAst::Binary {
306                        op: BinaryOp::Or,
307                        children: children_this_level,
308                    }))
309                    .collect();
310                PatternAst::Binary {
311                    op: BinaryOp::And,
312                    children: new_children,
313                }
314            } else {
315                PatternAst::Binary {
316                    op: BinaryOp::Or,
317                    children: children_this_level,
318                }
319            };
320
321            return Ok(Transformed::yes(new_node));
322        }
323
324        Ok(Transformed::no(PatternAst::Binary {
325            op: parent_op,
326            children,
327        }))
328    }
329
330    /// Eliminate single child [`PatternAst::Binary`] node. If a binary node has only one child, it can be
331    /// replaced by its only child.
332    ///
333    /// This function prefers to be applied in a top-down manner. But it's not required.
334    fn eliminate_single_child_fn(self) -> DfResult<Transformed<Self>> {
335        let PatternAst::Binary { op, mut children } = self else {
336            return Ok(Transformed::no(self));
337        };
338
339        // remove empty grand children
340        children.retain(|child| match child {
341            PatternAst::Binary {
342                children: grand_children,
343                ..
344            } => !grand_children.is_empty(),
345            PatternAst::Literal { .. } | PatternAst::Group { .. } => true,
346        });
347
348        if children.len() == 1 {
349            Ok(Transformed::yes(children.into_iter().next().unwrap()))
350        } else {
351            Ok(Transformed::no(PatternAst::Binary { op, children }))
352        }
353    }
354}
355
356impl TreeNode for PatternAst {
357    fn apply_children<'n, F: FnMut(&'n Self) -> DfResult<TreeNodeRecursion>>(
358        &'n self,
359        mut f: F,
360    ) -> DfResult<TreeNodeRecursion> {
361        match self {
362            PatternAst::Literal { .. } => Ok(TreeNodeRecursion::Continue),
363            PatternAst::Binary { op: _, children } => {
364                for child in children {
365                    if TreeNodeRecursion::Stop == f(child)? {
366                        return Ok(TreeNodeRecursion::Stop);
367                    }
368                }
369                Ok(TreeNodeRecursion::Continue)
370            }
371            PatternAst::Group { op: _, child } => f(child),
372        }
373    }
374
375    fn map_children<F: FnMut(Self) -> DfResult<Transformed<Self>>>(
376        self,
377        mut f: F,
378    ) -> DfResult<Transformed<Self>> {
379        match self {
380            PatternAst::Literal { .. } => Ok(Transformed::no(self)),
381            PatternAst::Binary { op, children } => children
382                .into_iter()
383                .map_until_stop_and_collect(&mut f)?
384                .map_data(|new_children| {
385                    Ok(PatternAst::Binary {
386                        op,
387                        children: new_children,
388                    })
389                }),
390            PatternAst::Group { op, child } => f(*child)?.map_data(|new_child| {
391                Ok(PatternAst::Group {
392                    op,
393                    child: Box::new(new_child),
394                })
395            }),
396        }
397    }
398}
399
400#[derive(Default)]
401struct ParserContext {
402    stack: Vec<PatternAst>,
403}
404
405impl ParserContext {
406    pub fn parse_pattern(mut self, pattern: &str) -> Result<PatternAst> {
407        let tokenizer = Tokenizer::default();
408        let raw_tokens = tokenizer.tokenize(pattern)?;
409        let raw_tokens = Self::accomplish_optional_unary_op(raw_tokens)?;
410        let mut tokens = Self::to_rpn(raw_tokens)?;
411
412        while !tokens.is_empty() {
413            self.parse_one_impl(&mut tokens)?;
414        }
415
416        ensure!(
417            !self.stack.is_empty(),
418            InvalidFuncArgsSnafu {
419                err_msg: "Empty pattern",
420            }
421        );
422
423        // conjoin them together
424        if self.stack.len() == 1 {
425            Ok(self.stack.pop().unwrap())
426        } else {
427            Ok(PatternAst::Binary {
428                op: BinaryOp::Or,
429                children: self.stack,
430            })
431        }
432    }
433
434    /// Add [`Token::Optional`] for all bare [`Token::Phase`] and [`Token::Or`]
435    /// for all adjacent [`Token::Phase`]s.
436    ///
437    /// This function also does some checks by the way. Like if two unary ops are
438    /// adjacent.
439    fn accomplish_optional_unary_op(raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
440        let mut is_prev_unary_op = false;
441        // The first one doesn't need binary op
442        let mut is_binary_op_before = true;
443        let mut is_unary_op_before = false;
444        let mut new_tokens = Vec::with_capacity(raw_tokens.len());
445        for token in raw_tokens {
446            // fill `Token::Or`
447            if !is_binary_op_before
448                && matches!(
449                    token,
450                    Token::Phase(_)
451                        | Token::OpenParen
452                        | Token::Must
453                        | Token::Optional
454                        | Token::Negative
455                )
456            {
457                is_binary_op_before = true;
458                new_tokens.push(Token::Or);
459            }
460            if matches!(
461                token,
462                Token::OpenParen // treat open paren as begin of new group
463                | Token::And | Token::Or
464            ) {
465                is_binary_op_before = true;
466            } else if matches!(token, Token::Phase(_) | Token::CloseParen) {
467                // need binary op next time
468                is_binary_op_before = false;
469            }
470
471            // fill `Token::Optional`
472            if !is_prev_unary_op && matches!(token, Token::Phase(_) | Token::OpenParen) {
473                new_tokens.push(Token::Optional);
474            } else {
475                is_prev_unary_op = matches!(token, Token::Must | Token::Negative);
476            }
477
478            // check if unary ops are adjacent by the way
479            if matches!(token, Token::Must | Token::Optional | Token::Negative) {
480                if is_unary_op_before {
481                    return InvalidFuncArgsSnafu {
482                        err_msg: "Invalid pattern, unary operators should not be adjacent",
483                    }
484                    .fail();
485                }
486                is_unary_op_before = true;
487            } else {
488                is_unary_op_before = false;
489            }
490
491            new_tokens.push(token);
492        }
493
494        Ok(new_tokens)
495    }
496
497    /// Convert infix token stream to RPN
498    fn to_rpn(mut raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
499        let mut operator_stack = vec![];
500        let mut result = vec![];
501        raw_tokens.reverse();
502
503        while let Some(token) = raw_tokens.pop() {
504            match token {
505                Token::Phase(_) => result.push(token),
506                Token::Must | Token::Negative | Token::Optional => {
507                    operator_stack.push(token);
508                }
509                Token::OpenParen => operator_stack.push(token),
510                Token::And | Token::Or => {
511                    // - Or has lower priority than And
512                    // - Binary op have lower priority than unary op
513                    while let Some(stack_top) = operator_stack.last()
514                        && ((*stack_top == Token::And && token == Token::Or)
515                            || matches!(
516                                *stack_top,
517                                Token::Must | Token::Optional | Token::Negative
518                            ))
519                    {
520                        result.push(operator_stack.pop().unwrap());
521                    }
522                    operator_stack.push(token);
523                }
524                Token::CloseParen => {
525                    let mut is_open_paren_found = false;
526                    while let Some(op) = operator_stack.pop() {
527                        if op == Token::OpenParen {
528                            is_open_paren_found = true;
529                            break;
530                        }
531                        result.push(op);
532                    }
533                    if !is_open_paren_found {
534                        return InvalidFuncArgsSnafu {
535                            err_msg: "Unmatched close parentheses",
536                        }
537                        .fail();
538                    }
539                }
540            }
541        }
542
543        while let Some(operator) = operator_stack.pop() {
544            if operator == Token::OpenParen {
545                return InvalidFuncArgsSnafu {
546                    err_msg: "Unmatched parentheses",
547                }
548                .fail();
549            }
550            result.push(operator);
551        }
552
553        Ok(result)
554    }
555
556    fn parse_one_impl(&mut self, tokens: &mut Vec<Token>) -> Result<()> {
557        if let Some(token) = tokens.pop() {
558            match token {
559                Token::Must => {
560                    if self.stack.is_empty() {
561                        self.parse_one_impl(tokens)?;
562                    }
563                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
564                        err_msg: "Invalid pattern, \"+\" operator should have one operand",
565                    })?;
566                    match phase_or_group {
567                        PatternAst::Literal { op: _, pattern } => {
568                            self.stack.push(PatternAst::Literal {
569                                op: UnaryOp::Must,
570                                pattern,
571                            });
572                        }
573                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
574                            self.stack.push(PatternAst::Group {
575                                op: UnaryOp::Must,
576                                child: Box::new(phase_or_group),
577                            })
578                        }
579                    }
580                    return Ok(());
581                }
582                Token::Negative => {
583                    if self.stack.is_empty() {
584                        self.parse_one_impl(tokens)?;
585                    }
586                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
587                        err_msg: "Invalid pattern, \"-\" operator should have one operand",
588                    })?;
589                    match phase_or_group {
590                        PatternAst::Literal { op: _, pattern } => {
591                            self.stack.push(PatternAst::Literal {
592                                op: UnaryOp::Negative,
593                                pattern,
594                            });
595                        }
596                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
597                            self.stack.push(PatternAst::Group {
598                                op: UnaryOp::Negative,
599                                child: Box::new(phase_or_group),
600                            })
601                        }
602                    }
603                    return Ok(());
604                }
605                Token::Optional => {
606                    if self.stack.is_empty() {
607                        self.parse_one_impl(tokens)?;
608                    }
609                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
610                        err_msg:
611                            "Invalid pattern, OPTIONAL(space) operator should have one operand",
612                    })?;
613                    match phase_or_group {
614                        PatternAst::Literal { op: _, pattern } => {
615                            self.stack.push(PatternAst::Literal {
616                                op: UnaryOp::Optional,
617                                pattern,
618                            });
619                        }
620                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
621                            self.stack.push(PatternAst::Group {
622                                op: UnaryOp::Optional,
623                                child: Box::new(phase_or_group),
624                            })
625                        }
626                    }
627                    return Ok(());
628                }
629                Token::Phase(pattern) => {
630                    self.stack.push(PatternAst::Literal {
631                        // Op here is a placeholder
632                        op: UnaryOp::Optional,
633                        pattern,
634                    })
635                }
636                Token::And => {
637                    if self.stack.is_empty() {
638                        self.parse_one_impl(tokens)?;
639                    };
640                    let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
641                        err_msg: "Invalid pattern, \"AND\" operator should have two operands",
642                    })?;
643                    if self.stack.is_empty() {
644                        self.parse_one_impl(tokens)?
645                    };
646                    let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
647                        err_msg: "Invalid pattern, \"AND\" operator should have two operands",
648                    })?;
649                    self.stack.push(PatternAst::Binary {
650                        op: BinaryOp::And,
651                        children: vec![lhs, rhs],
652                    });
653                    return Ok(());
654                }
655                Token::Or => {
656                    if self.stack.is_empty() {
657                        self.parse_one_impl(tokens)?
658                    };
659                    let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
660                        err_msg: "Invalid pattern, \"OR\" operator should have two operands",
661                    })?;
662                    if self.stack.is_empty() {
663                        self.parse_one_impl(tokens)?
664                    };
665                    let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
666                        err_msg: "Invalid pattern, \"OR\" operator should have two operands",
667                    })?;
668                    self.stack.push(PatternAst::Binary {
669                        op: BinaryOp::Or,
670                        children: vec![lhs, rhs],
671                    });
672                    return Ok(());
673                }
674                Token::OpenParen | Token::CloseParen => {
675                    return InvalidFuncArgsSnafu {
676                        err_msg: "Unexpected parentheses",
677                    }
678                    .fail();
679                }
680            }
681        }
682
683        Ok(())
684    }
685}
686
/// Lexical unit of a `matches` pattern.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Token {
    /// `+`
    Must,
    /// `-`
    Negative,
    /// `AND` keyword (case-insensitive)
    And,
    /// `OR` keyword (case-insensitive)
    Or,
    /// `(`
    OpenParen,
    /// `)`
    CloseParen,
    /// Any other phase (word or quoted text)
    Phase(String),

    /// Internal placeholder that never comes from user input. It gives the
    /// implicit "optional" marker the same unary-operator form as `Must` and
    /// `Negative`. Users express it with a bare phase or group, i.e. nothing
    /// or whitespace.
    Optional,
}
710
/// Single-use tokenizer for `matches` patterns.
#[derive(Default)]
struct Tokenizer {
    /// Index, in chars rather than bytes, of the character being examined.
    cursor: usize,
}
715
716impl Tokenizer {
717    pub fn tokenize(mut self, pattern: &str) -> Result<Vec<Token>> {
718        let mut tokens = vec![];
719        let char_len = pattern.chars().count();
720        while self.cursor < char_len {
721            // TODO: collect pattern into Vec<char> if this tokenizer is bottleneck in the future
722            let c = pattern.chars().nth(self.cursor).unwrap();
723            match c {
724                '+' => tokens.push(Token::Must),
725                '-' => tokens.push(Token::Negative),
726                '(' => tokens.push(Token::OpenParen),
727                ')' => tokens.push(Token::CloseParen),
728                ' ' => {
729                    if let Some(last_token) = tokens.last() {
730                        match last_token {
731                            Token::Must | Token::Negative => {
732                                return InvalidFuncArgsSnafu {
733                                    err_msg: format!("Unexpected space after {:?}", last_token),
734                                }
735                                .fail();
736                            }
737                            _ => {}
738                        }
739                    }
740                }
741                '\"' => {
742                    self.step_next();
743                    let phase = self.consume_next_phase(true, pattern)?;
744                    tokens.push(Token::Phase(phase));
745                    // consume a writespace (or EOF) after quotes
746                    if let Some(ending_separator) = self.consume_next(pattern)
747                        && ending_separator != ' '
748                    {
749                        return InvalidFuncArgsSnafu {
750                            err_msg: "Expect a space after quotes ('\"')",
751                        }
752                        .fail();
753                    }
754                }
755                _ => {
756                    let phase = self.consume_next_phase(false, pattern)?;
757                    match phase.to_uppercase().as_str() {
758                        "AND" => tokens.push(Token::And),
759                        "OR" => tokens.push(Token::Or),
760                        _ => tokens.push(Token::Phase(phase)),
761                    }
762                }
763            }
764            self.cursor += 1;
765        }
766        Ok(tokens)
767    }
768
769    fn consume_next(&mut self, pattern: &str) -> Option<char> {
770        self.cursor += 1;
771        pattern.chars().nth(self.cursor)
772    }
773
774    fn step_next(&mut self) {
775        self.cursor += 1;
776    }
777
778    fn rewind_one(&mut self) {
779        self.cursor -= 1;
780    }
781
782    /// Current `cursor` points to the first character of the phase.
783    /// If the phase is enclosed by double quotes, consume the start quote before calling this.
784    fn consume_next_phase(&mut self, is_quoted: bool, pattern: &str) -> Result<String> {
785        let mut phase = String::new();
786        let mut is_quote_present = false;
787
788        let char_len = pattern.chars().count();
789        while self.cursor < char_len {
790            let mut c = pattern.chars().nth(self.cursor).unwrap();
791
792            match c {
793                '\"' => {
794                    is_quote_present = true;
795                    break;
796                }
797                ' ' => {
798                    if !is_quoted {
799                        break;
800                    }
801                }
802                '(' | ')' | '+' | '-' => {
803                    if !is_quoted {
804                        self.rewind_one();
805                        break;
806                    }
807                }
808                '\\' => {
809                    let Some(next) = self.consume_next(pattern) else {
810                        return InvalidFuncArgsSnafu {
811                            err_msg: "Unexpected end of pattern, expected a character after escape ('\\')",
812                        }.fail();
813                    };
814                    // it doesn't check whether the escaped character is valid or not
815                    c = next;
816                }
817                _ => {}
818            }
819
820            phase.push(c);
821            self.cursor += 1;
822        }
823
824        if is_quoted ^ is_quote_present {
825            return InvalidFuncArgsSnafu {
826                err_msg: "Unclosed quotes ('\"')",
827            }
828            .fail();
829        }
830
831        Ok(phase)
832    }
833}
834
// Unit tests for the pattern tokenizer, the parser / AST transformer, and end-to-end
// evaluation of the `matches` function.
#[cfg(test)]
mod test {
    use datafusion::arrow::array::StringArray;
    use datafusion_common::ScalarValue;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Patterns that must tokenize successfully, paired with the exact token stream.
    #[test]
    fn valid_matches_tokenizer() {
        use Token::*;
        let cases = [
            (
                "a +b -c",
                vec![
                    Phase("a".to_string()),
                    Must,
                    Phase("b".to_string()),
                    Negative,
                    Phase("c".to_string()),
                ],
            ),
            (
                "+a(b-c)",
                vec![
                    Must,
                    Phase("a".to_string()),
                    OpenParen,
                    Phase("b".to_string()),
                    Negative,
                    Phase("c".to_string()),
                    CloseParen,
                ],
            ),
            (
                r#"Barack Obama"#,
                vec![Phase("Barack".to_string()), Phase("Obama".to_string())],
            ),
            (
                r#"+apple +fruit"#,
                vec![
                    Must,
                    Phase("apple".to_string()),
                    Must,
                    Phase("fruit".to_string()),
                ],
            ),
            (
                r#""He said \"hello\"""#,
                vec![Phase("He said \"hello\"".to_string())],
            ),
            (
                r#"a AND b OR c"#,
                vec![
                    Phase("a".to_string()),
                    And,
                    Phase("b".to_string()),
                    Or,
                    Phase("c".to_string()),
                ],
            ),
            // Non-ASCII phrases: the tokenizer cursor counts characters, not bytes.
            (
                r#"中文 测试"#,
                vec![Phase("中文".to_string()), Phase("测试".to_string())],
            ),
            (
                r#"中文 AND 测试"#,
                vec![Phase("中文".to_string()), And, Phase("测试".to_string())],
            ),
            (
                r#"中文 +测试"#,
                vec![Phase("中文".to_string()), Must, Phase("测试".to_string())],
            ),
            (
                r#"中文 -测试"#,
                vec![
                    Phase("中文".to_string()),
                    Negative,
                    Phase("测试".to_string()),
                ],
            ),
        ];

        for (query, expected) in cases {
            let tokenizer = Tokenizer::default();
            let tokens = tokenizer.tokenize(query).unwrap();
            assert_eq!(expected, tokens, "{query}");
        }
    }

    /// Patterns the tokenizer must reject; each error message must contain the given text.
    #[test]
    fn invalid_matches_tokenizer() {
        let cases = [
            (r#""He said "hello""#, "Expect a space after quotes"),
            (r#""He said hello"#, "Unclosed quotes"),
            (r#"a + b - c"#, "Unexpected space after"),
            (r#"ab "c"def"#, "Expect a space after quotes"),
        ];

        for (query, expected) in cases {
            let tokenizer = Tokenizer::default();
            let result = tokenizer.tokenize(query);
            assert!(result.is_err(), "{query}");
            let actual_error = result.unwrap_err().to_string();
            assert!(actual_error.contains(expected), "{query}, {actual_error}");
        }
    }

    /// Pins the AST produced by `parse_pattern` followed by `transform_ast`, including
    /// the order in which the transformer emits children.
    #[test]
    fn valid_ast_transformer() {
        let cases = [
            (
                "a AND b OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "a".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "b".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                "a -b",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Negative,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                    ],
                },
            ),
            (
                "a +b",
                PatternAst::Literal {
                    op: UnaryOp::Must,
                    pattern: "b".to_string(),
                },
            ),
            (
                "a b c d",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "d".to_string(),
                        },
                    ],
                },
            ),
            (
                "a b c AND d",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "c".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "d".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                r#"中文 测试"#,
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "测试".to_string(),
                        },
                    ],
                },
            ),
            (
                r#"中文 AND 测试"#,
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "测试".to_string(),
                        },
                    ],
                },
            ),
            (
                r#"中文 +测试"#,
                PatternAst::Literal {
                    op: UnaryOp::Must,
                    pattern: "测试".to_string(),
                },
            ),
            (
                r#"中文 -测试"#,
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Negative,
                            pattern: "测试".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                    ],
                },
            ),
        ];

        for (query, expected) in cases {
            let parser = ParserContext { stack: vec![] };
            let ast = parser.parse_pattern(query).unwrap();
            let ast = ast.transform_ast().unwrap();
            assert_eq!(expected, ast, "{query}");
        }
    }

    /// Patterns that must fail in either `parse_pattern` or `transform_ast`; each error
    /// message must contain the given text.
    #[test]
    fn invalid_ast() {
        let cases = [
            (r#"a b (c"#, "Unmatched parentheses"),
            (r#"a b) c"#, "Unmatched close parentheses"),
            (r#"a +-b"#, "unary operators should not be adjacent"),
        ];

        for (query, expected) in cases {
            let result: Result<()> = try {
                let parser = ParserContext { stack: vec![] };
                let ast = parser.parse_pattern(query)?;
                let _ast = ast.transform_ast()?;
            };

            assert!(result.is_err(), "{query}");
            let actual_error = result.unwrap_err().to_string();
            assert!(actual_error.contains(expected), "{query}, {actual_error}");
        }
    }

    /// Pins the raw AST from `parse_pattern` (no `transform_ast`), covering operator
    /// precedence, groups, unary operators and quoted keywords.
    #[test]
    fn valid_matches_parser() {
        let cases = [
            (
                "a AND b OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "a".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "b".to_string(),
                                },
                            ],
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            (
                "(a AND b) OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::And,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "a".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "b".to_string(),
                                    },
                                ],
                            }),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            (
                "a AND (b OR c)",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::Or,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "b".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "c".to_string(),
                                    },
                                ],
                            }),
                        },
                    ],
                },
            ),
            (
                "a +b -c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::Or,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Must,
                                    pattern: "b".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Negative,
                                    pattern: "c".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                "(+a +b) c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::Or,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Must,
                                        pattern: "a".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Must,
                                        pattern: "b".to_string(),
                                    },
                                ],
                            }),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            // Quoted keywords are literals; an unquoted keyword is matched case-insensitively.
            (
                "\"AND\" AnD \"OR\"",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "AND".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "OR".to_string(),
                        },
                    ],
                },
            ),
        ];

        for (query, expected) in cases {
            let parser = ParserContext { stack: vec![] };
            let ast = parser.parse_pattern(query).unwrap();
            assert_eq!(expected, ast, "{query}");
        }
    }

    /// Evaluates `matches` end to end against a fixed column. Row 0 is the full sentence;
    /// each later row blanks out one or two of its words. Each pattern is checked against
    /// the expected per-row booleans.
    #[test]
    fn evaluate_matches() {
        let input_data = vec![
            "The quick brown fox jumps over the lazy dog",
            "The             fox jumps over the lazy dog",
            "The quick brown     jumps over the lazy dog",
            "The quick brown fox       over the lazy dog",
            "The quick brown fox jumps      the lazy dog",
            "The quick brown fox jumps over          dog",
            "The quick brown fox jumps over the      dog",
        ];
        let col: ArrayRef = Arc::new(StringArray::from(input_data));
        let cases = [
            // basic cases
            ("quick", vec![true, false, true, true, true, true, true]),
            (
                "\"quick brown\"",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "\"fox jumps\"",
                vec![true, true, false, false, true, true, true],
            ),
            (
                "fox OR lazy",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "fox AND lazy",
                vec![true, true, false, true, true, false, false],
            ),
            (
                "-over -lazy",
                vec![false, false, false, false, false, false, false],
            ),
            (
                "-over AND -lazy",
                vec![false, false, false, false, false, false, false],
            ),
            // priority between AND & OR
            (
                "fox AND jumps OR over",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "fox OR brown AND quick",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "(fox OR brown) AND quick",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "brown AND quick OR fox",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "brown AND (quick OR fox)",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "brown AND quick AND fox  OR  jumps AND over AND lazy",
                vec![true, true, true, true, true, true, true],
            ),
            // optional & must conversion
            (
                "quick brown fox +jumps",
                vec![true, true, true, false, true, true, true],
            ),
            (
                "fox +jumps -over",
                vec![false, false, false, false, true, false, false],
            ),
            (
                "fox AND +jumps AND -over",
                vec![false, false, false, false, true, false, false],
            ),
            // weird parentheses cases
            (
                "(+fox +jumps) over",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "+(fox jumps) AND over",
                vec![true, true, true, true, false, true, true],
            ),
            (
                "over -(fox jumps)",
                vec![false, false, false, false, false, false, false],
            ),
            (
                "over -(fox AND jumps)",
                vec![false, false, true, true, false, false, false],
            ),
            (
                "over AND -(-(fox OR jumps))",
                vec![true, true, true, true, false, true, true],
            ),
        ];

        let f = MatchesFunction::default();
        for (pattern, expected) in cases {
            let args = ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(col.clone()),
                    ColumnarValue::Scalar(ScalarValue::Utf8View(Some(pattern.to_string()))),
                ],
                arg_fields: vec![],
                number_rows: col.len(),
                // `matches` returns a boolean per row, so the declared return field must be
                // Boolean, not the input column's string type.
                return_field: Arc::new(Field::new("x", DataType::Boolean, true)),
                config_options: Arc::new(ConfigOptions::new()),
            };
            let actual = f
                .invoke_with_args(args)
                .and_then(|x| x.to_array(col.len()))
                .unwrap();
            let expected: ArrayRef = Arc::new(BooleanArray::from(expected));
            assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}");
        }
    }
}