// common_function/scalars/matches.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

use common_query::error::{
    GeneralDataFusionSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result,
};
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
use datafusion::common::{DFSchema, DataFusionError, Result as DfResult};
use datafusion::execution::SessionStateBuilder;
use datafusion::logical_expr::{self, Expr, Volatility};
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
use datatypes::arrow::array::RecordBatch;
use datatypes::arrow::datatypes::{DataType, Field};
use datatypes::prelude::VectorRef;
use datatypes::vectors::BooleanVector;
use snafu::{ensure, OptionExt, ResultExt};
use store_api::storage::ConcreteDataType;

use crate::function::{Function, FunctionContext};
use crate::function_registry::FunctionRegistry;

/// Scalar function implementing full-text search over a string column.
///
/// Usage: matches(`<col>`, `<pattern>`) -> boolean
#[derive(Debug, Default, Clone)]
pub(crate) struct MatchesFunction;

43impl MatchesFunction {
44    pub fn register(registry: &FunctionRegistry) {
45        registry.register(Arc::new(MatchesFunction));
46    }
47}
48
49impl fmt::Display for MatchesFunction {
50    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
51        write!(f, "MATCHES")
52    }
53}
54
55impl Function for MatchesFunction {
56    fn name(&self) -> &str {
57        "matches"
58    }
59
60    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
61        Ok(ConcreteDataType::boolean_datatype())
62    }
63
64    fn signature(&self) -> common_query::prelude::Signature {
65        common_query::prelude::Signature::exact(
66            vec![
67                ConcreteDataType::string_datatype(),
68                ConcreteDataType::string_datatype(),
69            ],
70            Volatility::Immutable,
71        )
72    }
73
74    // TODO: read case-sensitive config
75    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
76        ensure!(
77            columns.len() == 2,
78            InvalidFuncArgsSnafu {
79                err_msg: format!(
80                    "The length of the args is not correct, expect exactly 2, have: {}",
81                    columns.len()
82                ),
83            }
84        );
85
86        let data_column = &columns[0];
87        if data_column.is_empty() {
88            return Ok(Arc::new(BooleanVector::from(Vec::<bool>::with_capacity(0))));
89        }
90
91        let pattern_vector = &columns[1]
92            .cast(&ConcreteDataType::string_datatype())
93            .context(InvalidInputTypeSnafu {
94                err_msg: "cannot cast `pattern` to string",
95            })?;
96        // Safety: both length and type are checked before
97        let pattern = pattern_vector.get(0).as_string().unwrap();
98        self.eval(data_column, pattern)
99    }
100}
101
102impl MatchesFunction {
103    fn eval(&self, data: &VectorRef, pattern: String) -> Result<VectorRef> {
104        let col_name = "data";
105        let parser_context = ParserContext::default();
106        let raw_ast = parser_context.parse_pattern(&pattern)?;
107        let ast = raw_ast.transform_ast()?;
108
109        let like_expr = ast.into_like_expr(col_name);
110
111        let input_schema = Self::input_schema();
112        let session_state = SessionStateBuilder::new().with_default_features().build();
113        let planner = DefaultPhysicalPlanner::default();
114        let physical_expr = planner
115            .create_physical_expr(&like_expr, &input_schema, &session_state)
116            .context(GeneralDataFusionSnafu)?;
117
118        let data_array = data.to_arrow_array();
119        let arrow_schema = Arc::new(input_schema.as_arrow().clone());
120        let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
121
122        let num_rows = input_record_batch.num_rows();
123        let result = physical_expr
124            .evaluate(&input_record_batch)
125            .context(GeneralDataFusionSnafu)?;
126        let result_array = result
127            .into_array(num_rows)
128            .context(GeneralDataFusionSnafu)?;
129        let result_vector =
130            BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu {
131                data_type: DataType::Boolean,
132            })?;
133
134        Ok(Arc::new(result_vector))
135    }
136
137    fn input_schema() -> DFSchema {
138        DFSchema::from_unqualified_fields(
139            [Arc::new(Field::new("data", DataType::Utf8, true))].into(),
140            HashMap::new(),
141        )
142        .unwrap()
143    }
144}
145
146#[derive(Debug, Clone, PartialEq, Eq)]
147enum PatternAst {
148    // Distinguish this with `Group` for simplicity
149    /// A leaf node that matches a column with `pattern`
150    Literal { op: UnaryOp, pattern: String },
151    /// Flattened binary chains
152    Binary {
153        op: BinaryOp,
154        children: Vec<PatternAst>,
155    },
156    /// A sub-tree enclosed by parenthesis
157    Group { op: UnaryOp, child: Box<PatternAst> },
158}
159
/// Unary operator attached to a literal or a group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UnaryOp {
    /// `+`: the operand is required.
    Must,
    /// No explicit operator: the operand is optional.
    Optional,
    /// `-`: the operand must not match.
    Negative,
}

/// Binary operator joining the children of a `PatternAst::Binary` node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BinaryOp {
    /// Explicit `AND`.
    And,
    /// Explicit `OR`, also inserted implicitly between adjacent operands.
    Or,
}

173impl PatternAst {
174    fn into_like_expr(self, column: &str) -> Expr {
175        match self {
176            PatternAst::Literal { op, pattern } => {
177                let expr = Self::convert_literal(column, &pattern);
178                match op {
179                    UnaryOp::Must => expr,
180                    UnaryOp::Optional => expr,
181                    UnaryOp::Negative => logical_expr::not(expr),
182                }
183            }
184            PatternAst::Binary { op, children } => {
185                if children.is_empty() {
186                    return logical_expr::lit(true);
187                }
188                let exprs = children
189                    .into_iter()
190                    .map(|child| child.into_like_expr(column));
191                // safety: children is not empty
192                match op {
193                    BinaryOp::And => exprs.reduce(Expr::and).unwrap(),
194                    BinaryOp::Or => exprs.reduce(Expr::or).unwrap(),
195                }
196            }
197            PatternAst::Group { op, child } => {
198                let child = child.into_like_expr(column);
199                match op {
200                    UnaryOp::Must => child,
201                    UnaryOp::Optional => child,
202                    UnaryOp::Negative => logical_expr::not(child),
203                }
204            }
205        }
206    }
207
208    fn convert_literal(column: &str, pattern: &str) -> Expr {
209        logical_expr::col(column).like(logical_expr::lit(format!(
210            "%{}%",
211            crate::utils::escape_like_pattern(pattern)
212        )))
213    }
214
215    /// Transform this AST with preset rules to make it correct.
216    fn transform_ast(self) -> Result<Self> {
217        self.transform_up(Self::collapse_binary_branch_fn)
218            .context(GeneralDataFusionSnafu)
219            .map(|data| data.data)?
220            .transform_up(Self::eliminate_optional_fn)
221            .context(GeneralDataFusionSnafu)
222            .map(|data| data.data)?
223            .transform_down(Self::eliminate_single_child_fn)
224            .context(GeneralDataFusionSnafu)
225            .map(|data| data.data)
226    }
227
228    /// Collapse binary branch with the same operator. I.e., this transformer
229    /// changes the binary-tree AST into a multiple branching AST.
230    ///
231    /// This function is expected to be called in a bottom-up manner as
232    /// it won't recursion.
233    fn collapse_binary_branch_fn(self) -> DfResult<Transformed<Self>> {
234        let PatternAst::Binary {
235            op: parent_op,
236            children,
237        } = self
238        else {
239            return Ok(Transformed::no(self));
240        };
241
242        let mut collapsed = vec![];
243        let mut remains = vec![];
244
245        for child in children {
246            match child {
247                PatternAst::Literal { .. } | PatternAst::Group { .. } => {
248                    collapsed.push(child);
249                }
250                PatternAst::Binary { op, children } => {
251                    // no need to recursion because this function is expected to be called
252                    // in a bottom-up manner
253                    if op == parent_op {
254                        collapsed.extend(children);
255                    } else {
256                        remains.push(PatternAst::Binary { op, children });
257                    }
258                }
259            }
260        }
261
262        if collapsed.is_empty() {
263            Ok(Transformed::no(PatternAst::Binary {
264                op: parent_op,
265                children: remains,
266            }))
267        } else {
268            collapsed.extend(remains);
269            Ok(Transformed::yes(PatternAst::Binary {
270                op: parent_op,
271                children: collapsed,
272            }))
273        }
274    }
275
276    /// Eliminate optional pattern. An optional pattern can always be
277    /// omitted or transformed into a must pattern follows the following rules:
278    /// - If there is only one pattern and it's optional, change it to must
279    /// - If there is any must pattern, remove all other optional patterns
280    fn eliminate_optional_fn(self) -> DfResult<Transformed<Self>> {
281        let PatternAst::Binary {
282            op: parent_op,
283            children,
284        } = self
285        else {
286            return Ok(Transformed::no(self));
287        };
288
289        if parent_op == BinaryOp::Or {
290            let mut must_list = vec![];
291            let mut must_not_list = vec![];
292            let mut optional_list = vec![];
293            let mut compound_list = vec![];
294
295            for child in children {
296                match child {
297                    PatternAst::Literal { op, .. } | PatternAst::Group { op, .. } => match op {
298                        UnaryOp::Must => must_list.push(child),
299                        UnaryOp::Optional => optional_list.push(child),
300                        UnaryOp::Negative => must_not_list.push(child),
301                    },
302                    PatternAst::Binary { .. } => {
303                        compound_list.push(child);
304                    }
305                }
306            }
307
308            // Eliminate optional list if there is MUST.
309            if !must_list.is_empty() {
310                optional_list.clear();
311            }
312
313            let children_this_level = optional_list.into_iter().chain(compound_list).collect();
314            let new_node = if !must_list.is_empty() || !must_not_list.is_empty() {
315                let new_children = must_list
316                    .into_iter()
317                    .chain(must_not_list)
318                    .chain(Some(PatternAst::Binary {
319                        op: BinaryOp::Or,
320                        children: children_this_level,
321                    }))
322                    .collect();
323                PatternAst::Binary {
324                    op: BinaryOp::And,
325                    children: new_children,
326                }
327            } else {
328                PatternAst::Binary {
329                    op: BinaryOp::Or,
330                    children: children_this_level,
331                }
332            };
333
334            return Ok(Transformed::yes(new_node));
335        }
336
337        Ok(Transformed::no(PatternAst::Binary {
338            op: parent_op,
339            children,
340        }))
341    }
342
343    /// Eliminate single child [`PatternAst::Binary`] node. If a binary node has only one child, it can be
344    /// replaced by its only child.
345    ///
346    /// This function prefers to be applied in a top-down manner. But it's not required.
347    fn eliminate_single_child_fn(self) -> DfResult<Transformed<Self>> {
348        let PatternAst::Binary { op, mut children } = self else {
349            return Ok(Transformed::no(self));
350        };
351
352        // remove empty grand children
353        children.retain(|child| match child {
354            PatternAst::Binary {
355                children: grand_children,
356                ..
357            } => !grand_children.is_empty(),
358            PatternAst::Literal { .. } | PatternAst::Group { .. } => true,
359        });
360
361        if children.len() == 1 {
362            Ok(Transformed::yes(children.into_iter().next().unwrap()))
363        } else {
364            Ok(Transformed::no(PatternAst::Binary { op, children }))
365        }
366    }
367}
368
369impl TreeNode for PatternAst {
370    fn apply_children<'n, F: FnMut(&'n Self) -> DfResult<TreeNodeRecursion>>(
371        &'n self,
372        mut f: F,
373    ) -> DfResult<TreeNodeRecursion> {
374        match self {
375            PatternAst::Literal { .. } => Ok(TreeNodeRecursion::Continue),
376            PatternAst::Binary { op: _, children } => {
377                for child in children {
378                    if TreeNodeRecursion::Stop == f(child)? {
379                        return Ok(TreeNodeRecursion::Stop);
380                    }
381                }
382                Ok(TreeNodeRecursion::Continue)
383            }
384            PatternAst::Group { op: _, child } => f(child),
385        }
386    }
387
388    fn map_children<F: FnMut(Self) -> DfResult<Transformed<Self>>>(
389        self,
390        mut f: F,
391    ) -> DfResult<Transformed<Self>> {
392        match self {
393            PatternAst::Literal { .. } => Ok(Transformed::no(self)),
394            PatternAst::Binary { op, children } => children
395                .into_iter()
396                .map_until_stop_and_collect(&mut f)?
397                .map_data(|new_children| {
398                    Ok(PatternAst::Binary {
399                        op,
400                        children: new_children,
401                    })
402                }),
403            PatternAst::Group { op, child } => f(*child)?.map_data(|new_child| {
404                Ok(PatternAst::Group {
405                    op,
406                    child: Box::new(new_child),
407                })
408            }),
409        }
410    }
411}
412
413#[derive(Default)]
414struct ParserContext {
415    stack: Vec<PatternAst>,
416}
417
418impl ParserContext {
419    pub fn parse_pattern(mut self, pattern: &str) -> Result<PatternAst> {
420        let tokenizer = Tokenizer::default();
421        let raw_tokens = tokenizer.tokenize(pattern)?;
422        let raw_tokens = Self::accomplish_optional_unary_op(raw_tokens)?;
423        let mut tokens = Self::to_rpn(raw_tokens)?;
424
425        while !tokens.is_empty() {
426            self.parse_one_impl(&mut tokens)?;
427        }
428
429        ensure!(
430            !self.stack.is_empty(),
431            InvalidFuncArgsSnafu {
432                err_msg: "Empty pattern",
433            }
434        );
435
436        // conjoin them together
437        if self.stack.len() == 1 {
438            Ok(self.stack.pop().unwrap())
439        } else {
440            Ok(PatternAst::Binary {
441                op: BinaryOp::Or,
442                children: self.stack,
443            })
444        }
445    }
446
447    /// Add [`Token::Optional`] for all bare [`Token::Phase`] and [`Token::Or`]
448    /// for all adjacent [`Token::Phase`]s.
449    ///
450    /// This function also does some checks by the way. Like if two unary ops are
451    /// adjacent.
452    fn accomplish_optional_unary_op(raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
453        let mut is_prev_unary_op = false;
454        // The first one doesn't need binary op
455        let mut is_binary_op_before = true;
456        let mut is_unary_op_before = false;
457        let mut new_tokens = Vec::with_capacity(raw_tokens.len());
458        for token in raw_tokens {
459            // fill `Token::Or`
460            if !is_binary_op_before
461                && matches!(
462                    token,
463                    Token::Phase(_)
464                        | Token::OpenParen
465                        | Token::Must
466                        | Token::Optional
467                        | Token::Negative
468                )
469            {
470                is_binary_op_before = true;
471                new_tokens.push(Token::Or);
472            }
473            if matches!(
474                token,
475                Token::OpenParen // treat open paren as begin of new group
476                | Token::And | Token::Or
477            ) {
478                is_binary_op_before = true;
479            } else if matches!(token, Token::Phase(_) | Token::CloseParen) {
480                // need binary op next time
481                is_binary_op_before = false;
482            }
483
484            // fill `Token::Optional`
485            if !is_prev_unary_op && matches!(token, Token::Phase(_) | Token::OpenParen) {
486                new_tokens.push(Token::Optional);
487            } else {
488                is_prev_unary_op = matches!(token, Token::Must | Token::Negative);
489            }
490
491            // check if unary ops are adjacent by the way
492            if matches!(token, Token::Must | Token::Optional | Token::Negative) {
493                if is_unary_op_before {
494                    return InvalidFuncArgsSnafu {
495                        err_msg: "Invalid pattern, unary operators should not be adjacent",
496                    }
497                    .fail();
498                }
499                is_unary_op_before = true;
500            } else {
501                is_unary_op_before = false;
502            }
503
504            new_tokens.push(token);
505        }
506
507        Ok(new_tokens)
508    }
509
510    /// Convert infix token stream to RPN
511    fn to_rpn(mut raw_tokens: Vec<Token>) -> Result<Vec<Token>> {
512        let mut operator_stack = vec![];
513        let mut result = vec![];
514        raw_tokens.reverse();
515
516        while let Some(token) = raw_tokens.pop() {
517            match token {
518                Token::Phase(_) => result.push(token),
519                Token::Must | Token::Negative | Token::Optional => {
520                    operator_stack.push(token);
521                }
522                Token::OpenParen => operator_stack.push(token),
523                Token::And | Token::Or => {
524                    // - Or has lower priority than And
525                    // - Binary op have lower priority than unary op
526                    while let Some(stack_top) = operator_stack.last()
527                        && ((*stack_top == Token::And && token == Token::Or)
528                            || matches!(
529                                *stack_top,
530                                Token::Must | Token::Optional | Token::Negative
531                            ))
532                    {
533                        result.push(operator_stack.pop().unwrap());
534                    }
535                    operator_stack.push(token);
536                }
537                Token::CloseParen => {
538                    let mut is_open_paren_found = false;
539                    while let Some(op) = operator_stack.pop() {
540                        if op == Token::OpenParen {
541                            is_open_paren_found = true;
542                            break;
543                        }
544                        result.push(op);
545                    }
546                    if !is_open_paren_found {
547                        return InvalidFuncArgsSnafu {
548                            err_msg: "Unmatched close parentheses",
549                        }
550                        .fail();
551                    }
552                }
553            }
554        }
555
556        while let Some(operator) = operator_stack.pop() {
557            if operator == Token::OpenParen {
558                return InvalidFuncArgsSnafu {
559                    err_msg: "Unmatched parentheses",
560                }
561                .fail();
562            }
563            result.push(operator);
564        }
565
566        Ok(result)
567    }
568
569    fn parse_one_impl(&mut self, tokens: &mut Vec<Token>) -> Result<()> {
570        if let Some(token) = tokens.pop() {
571            match token {
572                Token::Must => {
573                    if self.stack.is_empty() {
574                        self.parse_one_impl(tokens)?;
575                    }
576                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
577                        err_msg: "Invalid pattern, \"+\" operator should have one operand",
578                    })?;
579                    match phase_or_group {
580                        PatternAst::Literal { op: _, pattern } => {
581                            self.stack.push(PatternAst::Literal {
582                                op: UnaryOp::Must,
583                                pattern,
584                            });
585                        }
586                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
587                            self.stack.push(PatternAst::Group {
588                                op: UnaryOp::Must,
589                                child: Box::new(phase_or_group),
590                            })
591                        }
592                    }
593                    return Ok(());
594                }
595                Token::Negative => {
596                    if self.stack.is_empty() {
597                        self.parse_one_impl(tokens)?;
598                    }
599                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
600                        err_msg: "Invalid pattern, \"-\" operator should have one operand",
601                    })?;
602                    match phase_or_group {
603                        PatternAst::Literal { op: _, pattern } => {
604                            self.stack.push(PatternAst::Literal {
605                                op: UnaryOp::Negative,
606                                pattern,
607                            });
608                        }
609                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
610                            self.stack.push(PatternAst::Group {
611                                op: UnaryOp::Negative,
612                                child: Box::new(phase_or_group),
613                            })
614                        }
615                    }
616                    return Ok(());
617                }
618                Token::Optional => {
619                    if self.stack.is_empty() {
620                        self.parse_one_impl(tokens)?;
621                    }
622                    let phase_or_group = self.stack.pop().context(InvalidFuncArgsSnafu {
623                        err_msg:
624                            "Invalid pattern, OPTIONAL(space) operator should have one operand",
625                    })?;
626                    match phase_or_group {
627                        PatternAst::Literal { op: _, pattern } => {
628                            self.stack.push(PatternAst::Literal {
629                                op: UnaryOp::Optional,
630                                pattern,
631                            });
632                        }
633                        PatternAst::Binary { .. } | PatternAst::Group { .. } => {
634                            self.stack.push(PatternAst::Group {
635                                op: UnaryOp::Optional,
636                                child: Box::new(phase_or_group),
637                            })
638                        }
639                    }
640                    return Ok(());
641                }
642                Token::Phase(pattern) => {
643                    self.stack.push(PatternAst::Literal {
644                        // Op here is a placeholder
645                        op: UnaryOp::Optional,
646                        pattern,
647                    })
648                }
649                Token::And => {
650                    if self.stack.is_empty() {
651                        self.parse_one_impl(tokens)?;
652                    };
653                    let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
654                        err_msg: "Invalid pattern, \"AND\" operator should have two operands",
655                    })?;
656                    if self.stack.is_empty() {
657                        self.parse_one_impl(tokens)?
658                    };
659                    let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
660                        err_msg: "Invalid pattern, \"AND\" operator should have two operands",
661                    })?;
662                    self.stack.push(PatternAst::Binary {
663                        op: BinaryOp::And,
664                        children: vec![lhs, rhs],
665                    });
666                    return Ok(());
667                }
668                Token::Or => {
669                    if self.stack.is_empty() {
670                        self.parse_one_impl(tokens)?
671                    };
672                    let rhs = self.stack.pop().context(InvalidFuncArgsSnafu {
673                        err_msg: "Invalid pattern, \"OR\" operator should have two operands",
674                    })?;
675                    if self.stack.is_empty() {
676                        self.parse_one_impl(tokens)?
677                    };
678                    let lhs = self.stack.pop().context(InvalidFuncArgsSnafu {
679                        err_msg: "Invalid pattern, \"OR\" operator should have two operands",
680                    })?;
681                    self.stack.push(PatternAst::Binary {
682                        op: BinaryOp::Or,
683                        children: vec![lhs, rhs],
684                    });
685                    return Ok(());
686                }
687                Token::OpenParen | Token::CloseParen => {
688                    return InvalidFuncArgsSnafu {
689                        err_msg: "Unexpected parentheses",
690                    }
691                    .fail();
692                }
693            }
694        }
695
696        Ok(())
697    }
698}
699
/// Lexical token of a `matches` pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Token {
    /// `+`: the following phrase or group is required.
    Must,
    /// `-`: the following phrase or group must not match.
    Negative,
    /// `AND` keyword (matched case-insensitively).
    And,
    /// `OR` keyword (matched case-insensitively).
    Or,
    /// `(`
    OpenParen,
    /// `)`
    CloseParen,
    /// Any other phrase: a bare word or a double-quoted string.
    Phase(String),

    /// Never produced by the tokenizer; an internal placeholder inserted by the
    /// parser so that Optional joins Must and Negative as a unary operator. In a
    /// user-provided pattern, "optional" is written as a bare phrase or group
    /// (i.e. nothing or whitespace in front of it).
    Optional,
}

/// Splits a `matches` pattern into `Token`s.
#[derive(Default)]
struct Tokenizer {
    /// Character (not byte) index of the current position in the pattern.
    cursor: usize,
}

729impl Tokenizer {
730    pub fn tokenize(mut self, pattern: &str) -> Result<Vec<Token>> {
731        let mut tokens = vec![];
732        let char_len = pattern.chars().count();
733        while self.cursor < char_len {
734            // TODO: collect pattern into Vec<char> if this tokenizer is bottleneck in the future
735            let c = pattern.chars().nth(self.cursor).unwrap();
736            match c {
737                '+' => tokens.push(Token::Must),
738                '-' => tokens.push(Token::Negative),
739                '(' => tokens.push(Token::OpenParen),
740                ')' => tokens.push(Token::CloseParen),
741                ' ' => {
742                    if let Some(last_token) = tokens.last() {
743                        match last_token {
744                            Token::Must | Token::Negative => {
745                                return InvalidFuncArgsSnafu {
746                                    err_msg: format!("Unexpected space after {:?}", last_token),
747                                }
748                                .fail();
749                            }
750                            _ => {}
751                        }
752                    }
753                }
754                '\"' => {
755                    self.step_next();
756                    let phase = self.consume_next_phase(true, pattern)?;
757                    tokens.push(Token::Phase(phase));
758                    // consume a writespace (or EOF) after quotes
759                    if let Some(ending_separator) = self.consume_next(pattern) {
760                        if ending_separator != ' ' {
761                            return InvalidFuncArgsSnafu {
762                                err_msg: "Expect a space after quotes ('\"')",
763                            }
764                            .fail();
765                        }
766                    }
767                }
768                _ => {
769                    let phase = self.consume_next_phase(false, pattern)?;
770                    match phase.to_uppercase().as_str() {
771                        "AND" => tokens.push(Token::And),
772                        "OR" => tokens.push(Token::Or),
773                        _ => tokens.push(Token::Phase(phase)),
774                    }
775                }
776            }
777            self.cursor += 1;
778        }
779        Ok(tokens)
780    }
781
782    fn consume_next(&mut self, pattern: &str) -> Option<char> {
783        self.cursor += 1;
784        let c = pattern.chars().nth(self.cursor);
785        c
786    }
787
788    fn step_next(&mut self) {
789        self.cursor += 1;
790    }
791
792    fn rewind_one(&mut self) {
793        self.cursor -= 1;
794    }
795
796    /// Current `cursor` points to the first character of the phase.
797    /// If the phase is enclosed by double quotes, consume the start quote before calling this.
798    fn consume_next_phase(&mut self, is_quoted: bool, pattern: &str) -> Result<String> {
799        let mut phase = String::new();
800        let mut is_quote_present = false;
801
802        let char_len = pattern.chars().count();
803        while self.cursor < char_len {
804            let mut c = pattern.chars().nth(self.cursor).unwrap();
805
806            match c {
807                '\"' => {
808                    is_quote_present = true;
809                    break;
810                }
811                ' ' => {
812                    if !is_quoted {
813                        break;
814                    }
815                }
816                '(' | ')' | '+' | '-' => {
817                    if !is_quoted {
818                        self.rewind_one();
819                        break;
820                    }
821                }
822                '\\' => {
823                    let Some(next) = self.consume_next(pattern) else {
824                        return InvalidFuncArgsSnafu {
825                            err_msg: "Unexpected end of pattern, expected a character after escape ('\\')",
826                        }.fail();
827                    };
828                    // it doesn't check whether the escaped character is valid or not
829                    c = next;
830                }
831                _ => {}
832            }
833
834            phase.push(c);
835            self.cursor += 1;
836        }
837
838        if is_quoted ^ is_quote_present {
839            return InvalidFuncArgsSnafu {
840                err_msg: "Unclosed quotes ('\"')",
841            }
842            .fail();
843        }
844
845        Ok(phase)
846    }
847}
848
#[cfg(test)]
mod test {
    use datatypes::vectors::StringVector;

    use super::*;

    /// Tokenizer output for well-formed patterns, including escapes, quoting and
    /// multi-byte (CJK) input.
    #[test]
    fn valid_matches_tokenizer() {
        use Token::*;
        let cases = [
            (
                "a +b -c",
                vec![
                    Phase("a".to_string()),
                    Must,
                    Phase("b".to_string()),
                    Negative,
                    Phase("c".to_string()),
                ],
            ),
            (
                "+a(b-c)",
                vec![
                    Must,
                    Phase("a".to_string()),
                    OpenParen,
                    Phase("b".to_string()),
                    Negative,
                    Phase("c".to_string()),
                    CloseParen,
                ],
            ),
            (
                r#"Barack Obama"#,
                vec![Phase("Barack".to_string()), Phase("Obama".to_string())],
            ),
            (
                r#"+apple +fruit"#,
                vec![
                    Must,
                    Phase("apple".to_string()),
                    Must,
                    Phase("fruit".to_string()),
                ],
            ),
            (
                r#""He said \"hello\"""#,
                vec![Phase("He said \"hello\"".to_string())],
            ),
            (
                r#"a AND b OR c"#,
                vec![
                    Phase("a".to_string()),
                    And,
                    Phase("b".to_string()),
                    Or,
                    Phase("c".to_string()),
                ],
            ),
            (
                r#"中文 测试"#,
                vec![Phase("中文".to_string()), Phase("测试".to_string())],
            ),
            (
                r#"中文 AND 测试"#,
                vec![Phase("中文".to_string()), And, Phase("测试".to_string())],
            ),
            (
                r#"中文 +测试"#,
                vec![Phase("中文".to_string()), Must, Phase("测试".to_string())],
            ),
            (
                r#"中文 -测试"#,
                vec![
                    Phase("中文".to_string()),
                    Negative,
                    Phase("测试".to_string()),
                ],
            ),
            // Keywords are matched case-insensitively.
            (
                r#"a and b"#,
                vec![Phase("a".to_string()), And, Phase("b".to_string())],
            ),
            // An escaped operator character stays part of the phrase.
            (r#"a\(b"#, vec![Phase("a(b".to_string())]),
            // Inside quotes, spaces and operators are ordinary characters.
            (r#""a (b) -c""#, vec![Phase("a (b) -c".to_string())]),
        ];

        for (query, expected) in cases {
            let tokenizer = Tokenizer::default();
            let tokens = tokenizer.tokenize(query).unwrap();
            assert_eq!(expected, tokens, "{query}");
        }
    }

    /// Malformed patterns must be rejected with a descriptive error.
    #[test]
    fn invalid_matches_tokenizer() {
        let cases = [
            (r#""He said "hello""#, "Expect a space after quotes"),
            (r#""He said hello"#, "Unclosed quotes"),
            (r#"a + b - c"#, "Unexpected space after"),
            (r#"ab "c"def"#, "Expect a space after quotes"),
            // A trailing escape has no character to escape.
            (r#"abc\"#, "Unexpected end of pattern"),
        ];

        for (query, expected) in cases {
            let tokenizer = Tokenizer::default();
            let result = tokenizer.tokenize(query);
            assert!(result.is_err(), "{query}");
            let actual_error = result.unwrap_err().to_string();
            assert!(actual_error.contains(expected), "{query}, {actual_error}");
        }
    }

    #[test]
    fn valid_ast_transformer() {
        let cases = [
            (
                "a AND b OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "a".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "b".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                "a -b",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Negative,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                    ],
                },
            ),
            (
                "a +b",
                PatternAst::Literal {
                    op: UnaryOp::Must,
                    pattern: "b".to_string(),
                },
            ),
            (
                "a b c d",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "d".to_string(),
                        },
                    ],
                },
            ),
            (
                "a b c AND d",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "b".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "c".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "d".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                r#"中文 测试"#,
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "测试".to_string(),
                        },
                    ],
                },
            ),
            (
                r#"中文 AND 测试"#,
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "测试".to_string(),
                        },
                    ],
                },
            ),
            (
                r#"中文 +测试"#,
                PatternAst::Literal {
                    op: UnaryOp::Must,
                    pattern: "测试".to_string(),
                },
            ),
            (
                r#"中文 -测试"#,
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Negative,
                            pattern: "测试".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "中文".to_string(),
                        },
                    ],
                },
            ),
        ];

        for (query, expected) in cases {
            let parser = ParserContext { stack: vec![] };
            let ast = parser.parse_pattern(query).unwrap();
            let ast = ast.transform_ast().unwrap();
            assert_eq!(expected, ast, "{query}");
        }
    }

    #[test]
    fn invalid_ast() {
        let cases = [
            (r#"a b (c"#, "Unmatched parentheses"),
            (r#"a b) c"#, "Unmatched close parentheses"),
            (r#"a +-b"#, "unary operators should not be adjacent"),
        ];

        for (query, expected) in cases {
            let result: Result<()> = try {
                let parser = ParserContext { stack: vec![] };
                let ast = parser.parse_pattern(query)?;
                let _ast = ast.transform_ast()?;
            };

            assert!(result.is_err(), "{query}");
            let actual_error = result.unwrap_err().to_string();
            assert!(actual_error.contains(expected), "{query}, {actual_error}");
        }
    }

    #[test]
    fn valid_matches_parser() {
        let cases = [
            (
                "a AND b OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Binary {
                            op: BinaryOp::And,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "a".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Optional,
                                    pattern: "b".to_string(),
                                },
                            ],
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            (
                "(a AND b) OR c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::And,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "a".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "b".to_string(),
                                    },
                                ],
                            }),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            (
                "a AND (b OR c)",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::Or,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "b".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Optional,
                                        pattern: "c".to_string(),
                                    },
                                ],
                            }),
                        },
                    ],
                },
            ),
            (
                "a +b -c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "a".to_string(),
                        },
                        PatternAst::Binary {
                            op: BinaryOp::Or,
                            children: vec![
                                PatternAst::Literal {
                                    op: UnaryOp::Must,
                                    pattern: "b".to_string(),
                                },
                                PatternAst::Literal {
                                    op: UnaryOp::Negative,
                                    pattern: "c".to_string(),
                                },
                            ],
                        },
                    ],
                },
            ),
            (
                "(+a +b) c",
                PatternAst::Binary {
                    op: BinaryOp::Or,
                    children: vec![
                        PatternAst::Group {
                            op: UnaryOp::Optional,
                            child: Box::new(PatternAst::Binary {
                                op: BinaryOp::Or,
                                children: vec![
                                    PatternAst::Literal {
                                        op: UnaryOp::Must,
                                        pattern: "a".to_string(),
                                    },
                                    PatternAst::Literal {
                                        op: UnaryOp::Must,
                                        pattern: "b".to_string(),
                                    },
                                ],
                            }),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "c".to_string(),
                        },
                    ],
                },
            ),
            (
                "\"AND\" AnD \"OR\"",
                PatternAst::Binary {
                    op: BinaryOp::And,
                    children: vec![
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "AND".to_string(),
                        },
                        PatternAst::Literal {
                            op: UnaryOp::Optional,
                            pattern: "OR".to_string(),
                        },
                    ],
                },
            ),
        ];

        for (query, expected) in cases {
            let parser = ParserContext { stack: vec![] };
            let ast = parser.parse_pattern(query).unwrap();
            assert_eq!(expected, ast, "{query}");
        }
    }

    #[test]
    fn evaluate_matches() {
        let input_data = vec![
            "The quick brown fox jumps over the lazy dog",
            "The             fox jumps over the lazy dog",
            "The quick brown     jumps over the lazy dog",
            "The quick brown fox       over the lazy dog",
            "The quick brown fox jumps      the lazy dog",
            "The quick brown fox jumps over          dog",
            "The quick brown fox jumps over the      dog",
        ];
        let input_vector: VectorRef = Arc::new(StringVector::from(input_data));
        let cases = [
            // basic cases
            ("quick", vec![true, false, true, true, true, true, true]),
            (
                "\"quick brown\"",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "\"fox jumps\"",
                vec![true, true, false, false, true, true, true],
            ),
            (
                "fox OR lazy",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "fox AND lazy",
                vec![true, true, false, true, true, false, false],
            ),
            (
                "-over -lazy",
                vec![false, false, false, false, false, false, false],
            ),
            (
                "-over AND -lazy",
                vec![false, false, false, false, false, false, false],
            ),
            // priority between AND & OR
            (
                "fox AND jumps OR over",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "fox OR brown AND quick",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "(fox OR brown) AND quick",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "brown AND quick OR fox",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "brown AND (quick OR fox)",
                vec![true, false, true, true, true, true, true],
            ),
            (
                "brown AND quick AND fox  OR  jumps AND over AND lazy",
                vec![true, true, true, true, true, true, true],
            ),
            // optional & must conversion
            (
                "quick brown fox +jumps",
                vec![true, true, true, false, true, true, true],
            ),
            (
                "fox +jumps -over",
                vec![false, false, false, false, true, false, false],
            ),
            (
                "fox AND +jumps AND -over",
                vec![false, false, false, false, true, false, false],
            ),
            // weird parentheses cases
            (
                "(+fox +jumps) over",
                vec![true, true, true, true, true, true, true],
            ),
            (
                "+(fox jumps) AND over",
                vec![true, true, true, true, false, true, true],
            ),
            (
                "over -(fox jumps)",
                vec![false, false, false, false, false, false, false],
            ),
            (
                "over -(fox AND jumps)",
                vec![false, false, true, true, false, false, false],
            ),
            (
                "over AND -(-(fox OR jumps))",
                vec![true, true, true, true, false, true, true],
            ),
        ];

        let f = MatchesFunction;
        for (pattern, expected) in cases {
            let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap();
            let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _;
            assert_eq!(expected, actual, "{pattern}");
        }
    }
}