common_function/scalars/
matches.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt;
17use std::sync::Arc;
18
19use common_query::error::{IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result};
20use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
21use datafusion::common::{DFSchema, Result as DfResult};
22use datafusion::execution::SessionStateBuilder;
23use datafusion::logical_expr::{self, Expr, Volatility};
24use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
25use datafusion_expr::Signature;
26use datatypes::arrow::array::RecordBatch;
27use datatypes::arrow::datatypes::{DataType, Field};
28use datatypes::prelude::VectorRef;
29use datatypes::vectors::BooleanVector;
30use snafu::{OptionExt, ResultExt, ensure};
31use store_api::storage::ConcreteDataType;
32
33use crate::function::{Function, FunctionContext};
34use crate::function_registry::FunctionRegistry;
35
/// Full-text search scalar function: `matches(<col>, <pattern>) -> boolean`.
///
/// For every row of `<col>` it reports whether the row satisfies the search
/// expression in `<pattern>`. Patterns are built from bare terms, `"quoted
/// phrases"`, `+` / `-` prefixes, the `AND` / `OR` keywords and parentheses.
#[derive(Clone, Debug, Default)]
pub struct MatchesFunction;
41
42impl MatchesFunction {
43    pub fn register(registry: &FunctionRegistry) {
44        registry.register_scalar(MatchesFunction);
45    }
46}
47
48impl fmt::Display for MatchesFunction {
49    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
50        write!(f, "MATCHES")
51    }
52}
53
54impl Function for MatchesFunction {
55    fn name(&self) -> &str {
56        "matches"
57    }
58
59    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
60        Ok(DataType::Boolean)
61    }
62
63    fn signature(&self) -> Signature {
64        Signature::string(2, Volatility::Immutable)
65    }
66
67    // TODO: read case-sensitive config
68    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
69        ensure!(
70            columns.len() == 2,
71            InvalidFuncArgsSnafu {
72                err_msg: format!(
73                    "The length of the args is not correct, expect exactly 2, have: {}",
74                    columns.len()
75                ),
76            }
77        );
78
79        let data_column = &columns[0];
80        if data_column.is_empty() {
81            return Ok(Arc::new(BooleanVector::from(Vec::<bool>::with_capacity(0))));
82        }
83
84        let pattern_vector = &columns[1]
85            .cast(&ConcreteDataType::string_datatype())
86            .context(InvalidInputTypeSnafu {
87                err_msg: "cannot cast `pattern` to string",
88            })?;
89        // Safety: both length and type are checked before
90        let pattern = pattern_vector.get(0).as_string().unwrap();
91        self.eval(data_column, pattern)
92    }
93}
94
95impl MatchesFunction {
96    fn eval(&self, data: &VectorRef, pattern: String) -> Result<VectorRef> {
97        let col_name = "data";
98        let parser_context = ParserContext::default();
99        let raw_ast = parser_context.parse_pattern(&pattern)?;
100        let ast = raw_ast.transform_ast()?;
101
102        let like_expr = ast.into_like_expr(col_name);
103
104        let input_schema = Self::input_schema();
105        let session_state = SessionStateBuilder::new().with_default_features().build();
106        let planner = DefaultPhysicalPlanner::default();
107        let physical_expr =
108            planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
109
110        let data_array = data.to_arrow_array();
111        let arrow_schema = Arc::new(input_schema.as_arrow().clone());
112        let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
113
114        let num_rows = input_record_batch.num_rows();
115        let result = physical_expr.evaluate(&input_record_batch)?;
116        let result_array = result.into_array(num_rows)?;
117        let result_vector =
118            BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu {
119                data_type: DataType::Boolean,
120            })?;
121
122        Ok(Arc::new(result_vector))
123    }
124
125    fn input_schema() -> DFSchema {
126        DFSchema::from_unqualified_fields(
127            [Arc::new(Field::new("data", DataType::Utf8, true))].into(),
128            HashMap::new(),
129        )
130        .unwrap()
131    }
132}
133
134#[derive(Debug, Clone, PartialEq, Eq)]
135enum PatternAst {
136    // Distinguish this with `Group` for simplicity
137    /// A leaf node that matches a column with `pattern`
138    Literal { op: UnaryOp, pattern: String },
139    /// Flattened binary chains
140    Binary {
141        op: BinaryOp,
142        children: Vec<PatternAst>,
143    },
144    /// A sub-tree enclosed by parenthesis
145    Group { op: UnaryOp, child: Box<PatternAst> },
146}
147
/// Prefix operator attached to a literal or a group.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum UnaryOp {
    /// Written as `+`: the operand is required.
    Must,
    /// Written with no prefix: the operand is optional.
    Optional,
    /// Written as `-`: the operand must not match.
    Negative,
}
154
/// Infix operator joining sibling patterns.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum BinaryOp {
    /// `AND`: every child has to match.
    And,
    /// `OR`: at least one child has to match. Also implied between adjacent terms.
    Or,
}
160
161impl PatternAst {
162    fn into_like_expr(self, column: &str) -> Expr {
163        match self {
164            PatternAst::Literal { op, pattern } => {
165                let expr = Self::convert_literal(column, &pattern);
166                match op {
167                    UnaryOp::Must => expr,
168                    UnaryOp::Optional => expr,
169                    UnaryOp::Negative => logical_expr::not(expr),
170                }
171            }
172            PatternAst::Binary { op, children } => {
173                if children.is_empty() {
174                    return logical_expr::lit(true);
175                }
176                let exprs = children
177                    .into_iter()
178                    .map(|child| child.into_like_expr(column));
179                // safety: children is not empty
180                match op {
181                    BinaryOp::And => exprs.reduce(Expr::and).unwrap(),
182                    BinaryOp::Or => exprs.reduce(Expr::or).unwrap(),
183                }
184            }
185            PatternAst::Group { op, child } => {
186                let child = child.into_like_expr(column);
187                match op {
188                    UnaryOp::Must => child,
189                    UnaryOp::Optional => child,
190                    UnaryOp::Negative => logical_expr::not(child),
191                }
192            }
193        }
194    }
195
196    fn convert_literal(column: &str, pattern: &str) -> Expr {
197        logical_expr::col(column).like(logical_expr::lit(format!(
198            "%{}%",
199            crate::utils::escape_like_pattern(pattern)
200        )))
201    }
202
203    /// Transform this AST with preset rules to make it correct.
204    fn transform_ast(self) -> Result<Self> {
205        self.transform_up(Self::collapse_binary_branch_fn)
206            .map(|data| data.data)?
207            .transform_up(Self::eliminate_optional_fn)
208            .map(|data| data.data)?
209            .transform_down(Self::eliminate_single_child_fn)
210            .map(|data| data.data)
211            .map_err(Into::into)
212    }
213
214    /// Collapse binary branch with the same operator. I.e., this transformer
215    /// changes the binary-tree AST into a multiple branching AST.
216    ///
217    /// This function is expected to be called in a bottom-up manner as
218    /// it won't recursion.
219    fn collapse_binary_branch_fn(self) -> DfResult<Transformed<Self>> {
220        let PatternAst::Binary {
221            op: parent_op,
222            children,
223        } = self
224        else {
225            return Ok(Transformed::no(self));
226        };
227
228        let mut collapsed = vec![];
229        let mut remains = vec![];
230
231        for child in children {
232            match child {
233                PatternAst::Literal { .. } | PatternAst::Group { .. } => {
234                    collapsed.push(child);
235                }
236                PatternAst::Binary { op, children } => {
237                    // no need to recursion because this function is expected to be called
238                    // in a bottom-up manner
239                    if op == parent_op {
240                        collapsed.extend(children);
241                    } else {
242                        remains.push(PatternAst::Binary { op, children });
243                    }
244                }
245            }
246        }
247
248        if collapsed.is_empty() {
249            Ok(Transformed::no(PatternAst::Binary {
250                op: parent_op,
251                children: remains,
252            }))
253        } else {
254            collapsed.extend(remains);
255            Ok(Transformed::yes(PatternAst::Binary {
256                op: parent_op,
257                children: collapsed,
258            }))
259        }
260    }
261
262    /// Eliminate optional pattern. An optional pattern can always be
263    /// omitted or transformed into a must pattern follows the following rules:
264    /// - If there is only one pattern and it's optional, change it to must
265    /// - If there is any must pattern, remove all other optional patterns
266    fn eliminate_optional_fn(self) -> DfResult<Transformed<Self>> {
267        let PatternAst::Binary {
268            op: parent_op,
269            children,
270        } = self
271        else {
272            return Ok(Transformed::no(self));
273        };
274
275        if parent_op == BinaryOp::Or {
276            let mut must_list = vec![];
277            let mut must_not_list = vec![];
278            let mut optional_list = vec![];
279            let mut compound_list = vec![];
280
281            for child in children {
282                match child {
283                    PatternAst::Literal { op, .. } | PatternAst::Group { op, .. } => match op {
284                        UnaryOp::Must => must_list.push(child),
285                        UnaryOp::Optional => optional_list.push(child),
286                        UnaryOp::Negative => must_not_list.push(child),
287                    },
288                    PatternAst::Binary { .. } => {
289                        compound_list.push(child);
290                    }
291                }
292            }
293
294            // Eliminate optional list if there is MUST.
295            if !must_list.is_empty() {
296                optional_list.clear();
297            }
298
299            let children_this_level = optional_list.into_iter().chain(compound_list).collect();
300            let new_node = if !must_list.is_empty() || !must_not_list.is_empty() {
301                let new_children = must_list
302                    .into_iter()
303                    .chain(must_not_list)
304                    .chain(Some(PatternAst::Binary {
305                        op: BinaryOp::Or,
306                        children: children_this_level,
307                    }))
308                    .collect();
309                PatternAst::Binary {
310                    op: BinaryOp::And,
311                    children: new_children,
312                }
313            } else {
314                PatternAst::Binary {
315                    op: BinaryOp::Or,
316                    children: children_this_level,
317                }
318            };
319
320            return Ok(Transformed::yes(new_node));
321        }
322
323        Ok(Transformed::no(PatternAst::Binary {
324            op: parent_op,
325            children,
326        }))
327    }
328
329    /// Eliminate single child [`PatternAst::Binary`] node. If a binary node has only one child, it can be
330    /// replaced by its only child.
331    ///
332    /// This function prefers to be applied in a top-down manner. But it's not required.
333    fn eliminate_single_child_fn(self) -> DfResult<Transformed<Self>> {
334        let PatternAst::Binary { op, mut children } = self else {
335            return Ok(Transformed::no(self));
336        };
337
338        // remove empty grand children
339        children.retain(|child| match child {
340            PatternAst::Binary {
341                children: grand_children,
342                ..
343            } => !grand_children.is_empty(),
344            PatternAst::Literal { .. } | PatternAst::Group { .. } => true,
345        });
346
347        if children.len() == 1 {
348            Ok(Transformed::yes(children.into_iter().next().unwrap()))
349        } else {
350            Ok(Transformed::no(PatternAst::Binary { op, children }))
351        }
352    }
353}
354
355impl TreeNode for PatternAst {
356    fn apply_children<'n, F: FnMut(&'n Self) -> DfResult<TreeNodeRecursion>>(
357        &'n self,
358        mut f: F,
359    ) -> DfResult<TreeNodeRecursion> {
360        match self {
361            PatternAst::Literal { .. } => Ok(TreeNodeRecursion::Continue),
362            PatternAst::Binary { op: _, children } => {
363                for child in children {
364                    if TreeNodeRecursion::Stop == f(child)? {
365                        return Ok(TreeNodeRecursion::Stop);
366                    }
367                }
368                Ok(TreeNodeRecursion::Continue)
369            }
370            PatternAst::Group { op: _, child } => f(child),
371        }
372    }
373
374    fn map_children<F: FnMut(Self) -> DfResult<Transformed<Self>>>(
375        self,
376        mut f: F,
377    ) -> DfResult<Transformed<Self>> {
378        match self {
379            PatternAst::Literal { .. } => Ok(Transformed::no(self)),
380            PatternAst::Binary { op, children } => children
381                .into_iter()
382                .map_until_stop_and_collect(&mut f)?
383                .map_data(|new_children| {
384                    Ok(PatternAst::Binary {
385                        op,
386                        children: new_children,
387                    })
388                }),
389            PatternAst::Group { op, child } => f(*child)?.map_data(|new_child| {
390                Ok(PatternAst::Group {
391                    op,
392                    child: Box::new(new_child),
393                })
394            }),
395        }
396    }
397}
398
399#[derive(Default)]
400struct ParserContext {
401    stack: Vec<PatternAst>,
402}
403
404impl ParserContext {
405    pub fn parse_pattern(mut self, pattern: &str) -> Result<PatternAst> {
406        let tokenizer = Tokenizer::default();
407        let raw_tokens = tokenizer.tokenize(pattern)?;
408        let raw_tokens = Self::accomplish_optional_unary_op(raw_tokens)?;
409        let mut tokens = Self::to_rpn(raw_tokens)?;
410
411        while !tokens.is_empty() {
412            self.parse_one_impl(&mut tokens)?;
413        }
414
415        ensure!(
416            !self.stack.is_empty(),
417            InvalidFuncArgsSnafu {
418                err_msg: "Empty pattern",
419            }
420        );
421
422        // conjoin them together
423        if self.stack.len() == 1 {
424            Ok(self.stack.pop().unwrap())
425        } else {
426            Ok(PatternAst::Binary {
427                op: BinaryOp::Or,
428                children: self.stack,
429            })
430        }
431    }
432
433    /// Add [`Token::Optional`] for all bare [`Token::Phase`] and [`Token::Or`]
434    /// for all adjacent [`Token::Phase`]s.
435    ///
436    /// This function also does some checks by the way. Like if two unary ops are
437    /// adjacent.
438    fn accomplish_optional_unary_op(raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
439        let mut is_prev_unary_op = false;
440        // The first one doesn't need binary op
441        let mut is_binary_op_before = true;
442        let mut is_unary_op_before = false;
443        let mut new_tokens = Vec::with_capacity(raw_tokens.len());
444        for token in raw_tokens {
445            // fill `Token::Or`
446            if !is_binary_op_before
447                && matches!(
448                    token,
449                    Token::Phase(_)
450                        | Token::OpenParen
451                        | Token::Must
452                        | Token::Optional
453                        | Token::Negative
454                )
455            {
456                is_binary_op_before = true;
457                new_tokens.push(Token::Or);
458            }
459            if matches!(
460                token,
461                Token::OpenParen // treat open paren as begin of new group
462                | Token::And | Token::Or
463            ) {
464                is_binary_op_before = true;
465            } else if matches!(token, Token::Phase(_) | Token::CloseParen) {
466                // need binary op next time
467                is_binary_op_before = false;
468            }
469
470            // fill `Token::Optional`
471            if !is_prev_unary_op && matches!(token, Token::Phase(_) | Token::OpenParen) {
472                new_tokens.push(Token::Optional);
473            } else {
474                is_prev_unary_op = matches!(token, Token::Must | Token::Negative);
475            }
476
477            // check if unary ops are adjacent by the way
478            if matches!(token, Token::Must | Token::Optional | Token::Negative) {
479                if is_unary_op_before {
480                    return InvalidFuncArgsSnafu {
481                        err_msg: "Invalid pattern, unary operators should not be adjacent",
482                    }
483                    .fail();
484                }
485                is_unary_op_before = true;
486            } else {
487                is_unary_op_before = false;
488            }
489
490            new_tokens.push(token);
491        }
492
493        Ok(new_tokens)
494    }
495
496    /// Convert infix token stream to RPN
497    fn to_rpn(mut raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
498        let mut operator_stack = vec![];
499        let mut result = vec![];
500        raw_tokens.reverse();
501
502        while let Some(token) = raw_tokens.pop() {
503            match token {
504                Token::Phase(_) => result.push(token),
505                Token::Must | Token::Negative | Token::Optional => {
506                    operator_stack.push(token);
507                }
508                Token::OpenParen => operator_stack.push(token),
509                Token::And | Token::Or => {
510                    // - Or has lower priority than And
511                    // - Binary op have lower priority than unary op
512                    while let Some(stack_top) = operator_stack.last()
513                        && ((*stack_top == Token::And && token == Token::Or)
514                            || matches!(
515                                *stack_top,
516                                Token::Must | Token::Optional | Token::Negative
517                            ))
518                    {
519                        result.push(operator_stack.pop().unwrap());
520                    }
521                    operator_stack.push(token);
522                }
523                Token::CloseParen => {
524                    let mut is_open_paren_found = false;
525                    while let Some(op) = operator_stack.pop() {
526                        if op == Token::OpenParen {
527                            is_open_paren_found = true;
528                            break;
529                        }
530                        result.push(op);
531                    }
532                    if !is_open_paren_found {
533                        return InvalidFuncArgsSnafu {
534                            err_msg: "Unmatched close parentheses",
535                        }
536                        .fail();
537                    }
538                }
539            }
540        }
541
542        while let Some(operator) = operator_stack.pop() {
543            if operator == Token::OpenParen {
544                return InvalidFuncArgsSnafu {
545                    err_msg: "Unmatched parentheses",
546                }
547                .fail();
548            }
549            result.push(operator);
550        }
551
552        Ok(result)
553    }
554
555    fn parse_one_impl(&mut self, tokens: &mut Vec<Token>) -> Result<()> {
556        if let Some(token) = tokens.pop() {
557            match token {
558                Token::Must => {
559                    if self.stack.is_empty() {
560                        self.parse_one_impl(tokens)?;
561                    }
562                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
563                        err_msg: "Invalid pattern, \"+\" operator should have one operand",
564                    })?;
565                    match phase_or_group {
566                        PatternAst::Literal { op: _, pattern } => {
567                            self.stack.push(PatternAst::Literal {
568                                op: UnaryOp::Must,
569                                pattern,
570                            });
571                        }
572                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
573                            self.stack.push(PatternAst::Group {
574                                op: UnaryOp::Must,
575                                child: Box::new(phase_or_group),
576                            })
577                        }
578                    }
579                    return Ok(());
580                }
581                Token::Negative => {
582                    if self.stack.is_empty() {
583                        self.parse_one_impl(tokens)?;
584                    }
585                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
586                        err_msg: "Invalid pattern, \"-\" operator should have one operand",
587                    })?;
588                    match phase_or_group {
589                        PatternAst::Literal { op: _, pattern } => {
590                            self.stack.push(PatternAst::Literal {
591                                op: UnaryOp::Negative,
592                                pattern,
593                            });
594                        }
595                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
596                            self.stack.push(PatternAst::Group {
597                                op: UnaryOp::Negative,
598                                child: Box::new(phase_or_group),
599                            })
600                        }
601                    }
602                    return Ok(());
603                }
604                Token::Optional => {
605                    if self.stack.is_empty() {
606                        self.parse_one_impl(tokens)?;
607                    }
608                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
609                        err_msg:
610                            "Invalid pattern, OPTIONAL(space) operator should have one operand",
611                    })?;
612                    match phase_or_group {
613                        PatternAst::Literal { op: _, pattern } => {
614                            self.stack.push(PatternAst::Literal {
615                                op: UnaryOp::Optional,
616                                pattern,
617                            });
618                        }
619                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
620                            self.stack.push(PatternAst::Group {
621                                op: UnaryOp::Optional,
622                                child: Box::new(phase_or_group),
623                            })
624                        }
625                    }
626                    return Ok(());
627                }
628                Token::Phase(pattern) => {
629                    self.stack.push(PatternAst::Literal {
630                        // Op here is a placeholder
631                        op: UnaryOp::Optional,
632                        pattern,
633                    })
634                }
635                Token::And => {
636                    if self.stack.is_empty() {
637                        self.parse_one_impl(tokens)?;
638                    };
639                    let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
640                        err_msg: "Invalid pattern, \"AND\" operator should have two operands",
641                    })?;
642                    if self.stack.is_empty() {
643                        self.parse_one_impl(tokens)?
644                    };
645                    let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
646                        err_msg: "Invalid pattern, \"AND\" operator should have two operands",
647                    })?;
648                    self.stack.push(PatternAst::Binary {
649                        op: BinaryOp::And,
650                        children: vec![lhs, rhs],
651                    });
652                    return Ok(());
653                }
654                Token::Or => {
655                    if self.stack.is_empty() {
656                        self.parse_one_impl(tokens)?
657                    };
658                    let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
659                        err_msg: "Invalid pattern, \"OR\" operator should have two operands",
660                    })?;
661                    if self.stack.is_empty() {
662                        self.parse_one_impl(tokens)?
663                    };
664                    let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
665                        err_msg: "Invalid pattern, \"OR\" operator should have two operands",
666                    })?;
667                    self.stack.push(PatternAst::Binary {
668                        op: BinaryOp::Or,
669                        children: vec![lhs, rhs],
670                    });
671                    return Ok(());
672                }
673                Token::OpenParen | Token::CloseParen => {
674                    return InvalidFuncArgsSnafu {
675                        err_msg: "Unexpected parentheses",
676                    }
677                    .fail();
678                }
679            }
680        }
681
682        Ok(())
683    }
684}
685
/// Lexical unit produced by the pattern tokenizer and consumed by the parser.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Token {
    /// `+` prefix: the following term or group is required.
    Must,
    /// `-` prefix: the following term or group must not match.
    Negative,
    /// The `AND` keyword (matched case-insensitively).
    And,
    /// The `OR` keyword (matched case-insensitively).
    Or,
    /// `(`
    OpenParen,
    /// `)`
    CloseParen,
    /// A bare term or a quoted phrase (quotes stripped, escapes resolved).
    Phase(String),

    /// Internal marker that never comes directly from user input.
    ///
    /// It gives bare terms and groups (written without `+` / `-`, i.e. nothing
    /// or whitespace in front) an explicit unary operator. That way they are
    /// handled in the same way as `Must` and `Negative`.
    Optional,
}
709
/// Hand-written lexer for `matches` patterns.
#[derive(Default)]
struct Tokenizer {
    /// Index, counted in `char`s rather than bytes, of the character being examined.
    cursor: usize,
}
714
715impl Tokenizer {
716    pub fn tokenize(mut self, pattern: &str) -> Result<Vec<Token>> {
717        let mut tokens = vec![];
718        let char_len = pattern.chars().count();
719        while self.cursor < char_len {
720            // TODO: collect pattern into Vec<char> if this tokenizer is bottleneck in the future
721            let c = pattern.chars().nth(self.cursor).unwrap();
722            match c {
723                '+' => tokens.push(Token::Must),
724                '-' => tokens.push(Token::Negative),
725                '(' => tokens.push(Token::OpenParen),
726                ')' => tokens.push(Token::CloseParen),
727                ' ' => {
728                    if let Some(last_token) = tokens.last() {
729                        match last_token {
730                            Token::Must | Token::Negative => {
731                                return InvalidFuncArgsSnafu {
732                                    err_msg: format!("Unexpected space after {:?}", last_token),
733                                }
734                                .fail();
735                            }
736                            _ => {}
737                        }
738                    }
739                }
740                '\"' => {
741                    self.step_next();
742                    let phase = self.consume_next_phase(true, pattern)?;
743                    tokens.push(Token::Phase(phase));
744                    // consume a writespace (or EOF) after quotes
745                    if let Some(ending_separator) = self.consume_next(pattern)
746                        && ending_separator != ' '
747                    {
748                        return InvalidFuncArgsSnafu {
749                            err_msg: "Expect a space after quotes ('\"')",
750                        }
751                        .fail();
752                    }
753                }
754                _ => {
755                    let phase = self.consume_next_phase(false, pattern)?;
756                    match phase.to_uppercase().as_str() {
757                        "AND" => tokens.push(Token::And),
758                        "OR" => tokens.push(Token::Or),
759                        _ => tokens.push(Token::Phase(phase)),
760                    }
761                }
762            }
763            self.cursor += 1;
764        }
765        Ok(tokens)
766    }
767
768    fn consume_next(&mut self, pattern: &str) -> Option<char> {
769        self.cursor += 1;
770        pattern.chars().nth(self.cursor)
771    }
772
773    fn step_next(&mut self) {
774        self.cursor += 1;
775    }
776
777    fn rewind_one(&mut self) {
778        self.cursor -= 1;
779    }
780
781    /// Current `cursor` points to the first character of the phase.
782    /// If the phase is enclosed by double quotes, consume the start quote before calling this.
783    fn consume_next_phase(&mut self, is_quoted: bool, pattern: &str) -> Result<String> {
784        let mut phase = String::new();
785        let mut is_quote_present = false;
786
787        let char_len = pattern.chars().count();
788        while self.cursor < char_len {
789            let mut c = pattern.chars().nth(self.cursor).unwrap();
790
791            match c {
792                '\"' => {
793                    is_quote_present = true;
794                    break;
795                }
796                ' ' => {
797                    if !is_quoted {
798                        break;
799                    }
800                }
801                '(' | ')' | '+' | '-' => {
802                    if !is_quoted {
803                        self.rewind_one();
804                        break;
805                    }
806                }
807                '\\' => {
808                    let Some(next) = self.consume_next(pattern) else {
809                        return InvalidFuncArgsSnafu {
810                            err_msg: "Unexpected end of pattern, expected a character after escape ('\\')",
811                        }.fail();
812                    };
813                    // it doesn't check whether the escaped character is valid or not
814                    c = next;
815                }
816                _ => {}
817            }
818
819            phase.push(c);
820            self.cursor += 1;
821        }
822
823        if is_quoted ^ is_quote_present {
824            return InvalidFuncArgsSnafu {
825                err_msg: "Unclosed quotes ('\"')",
826            }
827            .fail();
828        }
829
830        Ok(phase)
831    }
832}
833
#[cfg(test)]
mod test {
    use datatypes::vectors::StringVector;

    use super::*;

    /// Valid patterns tokenize into the expected token sequence, including non-ASCII
    /// phases, quoted phrases with escaped quotes, and escaped operator characters.
    #[test]
    fn valid_matches_tokenizer() {
        use Token::*;
        let cases = [
            (
                "a +b -c",
                vec![
                    Phase("a".to_string()),
                    Must,
                    Phase("b".to_string()),
                    Negative,
                    Phase("c".to_string()),
                ],
            ),
            (
                "+a(b-c)",
                vec![
                    Must,
                    Phase("a".to_string()),
                    OpenParen,
                    Phase("b".to_string()),
                    Negative,
                    Phase("c".to_string()),
                    CloseParen,
                ],
            ),
            (
                r#"Barack Obama"#,
                vec![Phase("Barack".to_string()), Phase("Obama".to_string())],
            ),
            (
                r#"+apple +fruit"#,
                vec![
                    Must,
                    Phase("apple".to_string()),
                    Must,
                    Phase("fruit".to_string()),
                ],
            ),
            (
                r#""He said \"hello\"""#,
                vec![Phase("He said \"hello\"".to_string())],
            ),
            (
                r#"a AND b OR c"#,
                vec![
                    Phase("a".to_string()),
                    And,
                    Phase("b".to_string()),
                    Or,
                    Phase("c".to_string()),
                ],
            ),
            (
                r#"中文 测试"#,
                vec![Phase("中文".to_string()), Phase("测试".to_string())],
            ),
            (
                r#"中文 AND 测试"#,
                vec![Phase("中文".to_string()), And, Phase("测试".to_string())],
            ),
            (
                r#"中文 +测试"#,
                vec![Phase("中文".to_string()), Must, Phase("测试".to_string())],
            ),
            (
                r#"中文 -测试"#,
                vec![
                    Phase("中文".to_string()),
                    Negative,
                    Phase("测试".to_string()),
                ],
            ),
            // an escaped operator is kept literally inside an unquoted phase
            (r#"a\-b"#, vec![Phase("a-b".to_string())]),
        ];

        for (query, expected) in cases {
            let tokenizer = Tokenizer::default();
            let tokens = tokenizer.tokenize(query).unwrap();
            assert_eq!(expected, tokens, "{query}");
        }
    }

    /// Malformed patterns are rejected by the tokenizer with a descriptive error.
    #[test]
    fn invalid_matches_tokenizer() {
        let cases = [
            (r#""He said "hello""#, "Expect a space after quotes"),
            (r#""He said hello"#, "Unclosed quotes"),
            (r#"a + b - c"#, "Unexpected space after"),
            (r#"ab "c"def"#, "Expect a space after quotes"),
            // a trailing backslash has nothing to escape
            (r#"abc\"#, "Unexpected end of pattern"),
        ];

        for (query, expected) in cases {
            let tokenizer = Tokenizer::default();
            let result = tokenizer.tokenize(query);
            assert!(result.is_err(), "{query}");
            let actual_error = result.unwrap_err().to_string();
            assert!(actual_error.contains(expected), "{query}, {actual_error}");
        }
    }

    /// Parsed patterns are rewritten by `transform_ast` into the expected normalized AST
    /// (e.g. optional terms are dropped once a `+` term is present).
    #[test]
    fn valid_ast_transformer() {
        let cases = [
            (
                "a AND b OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "a".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "b".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                "a -b",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Negative,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                    ],
                },
            ),
            (
                "a +b",
                PatternAst::Literal {
                    op: UnaryOp::Must,
                    pattern: "b".to_string(),
                },
            ),
            (
                "a b c d",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "d".to_string(),
                        },
                    ],
                },
            ),
            (
                "a b c AND d",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "c".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "d".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                r#"中文 测试"#,
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "测试".to_string(),
                        },
                    ],
                },
            ),
            (
                r#"中文 AND 测试"#,
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "测试".to_string(),
                        },
                    ],
                },
            ),
            (
                r#"中文 +测试"#,
                PatternAst::Literal {
                    op: UnaryOp::Must,
                    pattern: "测试".to_string(),
                },
            ),
            (
                r#"中文 -测试"#,
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Negative,
                            pattern: "测试".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                    ],
                },
            ),
        ];

        for (query, expected) in cases {
            let parser = ParserContext { stack: vec![] };
            let ast = parser.parse_pattern(query).unwrap();
            let ast = ast.transform_ast().unwrap();
            assert_eq!(expected, ast, "{query}");
        }
    }

    /// Parsing or transforming malformed patterns fails with the expected error.
    #[test]
    fn invalid_ast() {
        let cases = [
            (r#"a b (c"#, "Unmatched parentheses"),
            (r#"a b) c"#, "Unmatched close parentheses"),
            (r#"a +-b"#, "unary operators should not be adjacent"),
        ];

        for (query, expected) in cases {
            let result: Result<()> = try {
                let parser = ParserContext { stack: vec![] };
                let ast = parser.parse_pattern(query)?;
                let _ast = ast.transform_ast()?;
            };

            assert!(result.is_err(), "{query}");
            let actual_error = result.unwrap_err().to_string();
            assert!(actual_error.contains(expected), "{query}, {actual_error}");
        }
    }

    /// `parse_pattern` builds the expected (untransformed) AST, honoring parentheses and
    /// binding AND tighter than OR.
    #[test]
    fn valid_matches_parser() {
        let cases = [
            (
                "a AND b OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "a".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "b".to_string(),
                                },
                            ],
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            (
                "(a AND b) OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::And,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "a".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "b".to_string(),
                                    },
                                ],
                            }),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            (
                "a AND (b OR c)",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::Or,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "b".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "c".to_string(),
                                    },
                                ],
                            }),
                        },
                    ],
                },
            ),
            (
                "a +b -c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::Or,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Must,
                                    pattern: "b".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Negative,
                                    pattern: "c".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                "(+a +b) c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::Or,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Must,
                                        pattern: "a".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Must,
                                        pattern: "b".to_string(),
                                    },
                                ],
                            }),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            (
                "\"AND\" AnD \"OR\"",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "AND".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "OR".to_string(),
                        },
                    ],
                },
            ),
        ];

        for (query, expected) in cases {
            let parser = ParserContext { stack: vec![] };
            let ast = parser.parse_pattern(query).unwrap();
            assert_eq!(expected, ast, "{query}");
        }
    }

    /// End-to-end evaluation of `matches` over a small corpus in which every row after
    /// the first has one or two words blanked out.
    #[test]
    fn evaluate_matches() {
        let input_data = vec![
            "The quick brown fox jumps over the lazy dog",
            "The             fox jumps over the lazy dog",
            "The quick brown     jumps over the lazy dog",
            "The quick brown fox       over the lazy dog",
            "The quick brown fox jumps      the lazy dog",
            "The quick brown fox jumps over          dog",
            "The quick brown fox jumps over the      dog",
        ];
        let input_vector: VectorRef = Arc::new(StringVector::from(input_data));
        let cases = [
            // basic cases
            ("quick", vec![true, false, true, true, true, true, true]),
            (
                "\"quick brown\"",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "\"fox jumps\"",
                vec![true, true, false, false, true, true, true],
            ),
            (
                "fox OR lazy",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "fox AND lazy",
                vec![true, true, false, true, true, false, false],
            ),
            (
                "-over -lazy",
                vec![false, false, false, false, false, false, false],
            ),
            (
                "-over AND -lazy",
                vec![false, false, false, false, false, false, false],
            ),
            // priority between AND & OR
            (
                "fox AND jumps OR over",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "fox OR brown AND quick",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "(fox OR brown) AND quick",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "brown AND quick OR fox",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "brown AND (quick OR fox)",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "brown AND quick AND fox  OR  jumps AND over AND lazy",
                vec![true, true, true, true, true, true, true],
            ),
            // optional & must conversion
            (
                "quick brown fox +jumps",
                vec![true, true, true, false, true, true, true],
            ),
            (
                "fox +jumps -over",
                vec![false, false, false, false, true, false, false],
            ),
            (
                "fox AND +jumps AND -over",
                vec![false, false, false, false, true, false, false],
            ),
            // weird parentheses cases
            (
                "(+fox +jumps) over",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "+(fox jumps) AND over",
                vec![true, true, true, true, false, true, true],
            ),
            (
                "over -(fox jumps)",
                vec![false, false, false, false, false, false, false],
            ),
            (
                "over -(fox AND jumps)",
                vec![false, false, true, true, false, false, false],
            ),
            (
                "over AND -(-(fox OR jumps))",
                vec![true, true, true, true, false, true, true],
            ),
        ];

        let f = MatchesFunction;
        for (pattern, expected) in cases {
            let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap();
            let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _;
            assert_eq!(expected, actual, "{pattern}");
        }
    }
}