common_function/scalars/string/
regexp_extract.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! Implementation of the `REGEXP_EXTRACT` scalar function.
16use std::fmt;
17use std::sync::Arc;
18
19use datafusion_common::DataFusionError;
20use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
21use datafusion_common::arrow::compute::cast;
22use datafusion_common::arrow::datatypes::DataType;
23use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
24use regex::{Regex, RegexBuilder};
25
26use crate::function::Function;
27use crate::function_registry::FunctionRegistry;
28
/// Lower-case name of the function.
const NAME: &str = "regexp_extract";

// Safety limits applied while compiling patterns and producing results.

/// Size limit, in bytes, handed to `RegexBuilder::size_limit` (1 MiB).
const MAX_REGEX_SIZE: usize = 1 << 20;
/// Size limit, in bytes, handed to `RegexBuilder::dfa_size_limit` (2 MiB).
const MAX_DFA_SIZE: usize = 2 << 20;
/// Upper bound, in bytes, on the total result size for one batch (64 MiB).
const MAX_TOTAL_RESULT_SIZE: usize = 64 << 20;
/// Upper bound, in bytes, on a single extracted match (1 MiB).
const MAX_SINGLE_MATCH: usize = 1 << 20;
/// Upper bound on the pattern text length, as measured by `str::len`.
const MAX_PATTERN_LEN: usize = 10_000;
37
38/// REGEXP_EXTRACT function implementation
39/// Extracts the first substring matching the given regular expression pattern.
40/// If no match is found, returns NULL.
41///
42#[derive(Debug)]
43pub struct RegexpExtractFunction {
44    signature: Signature,
45}
46
47impl RegexpExtractFunction {
48    pub fn register(registry: &FunctionRegistry) {
49        registry.register_scalar(RegexpExtractFunction::default());
50    }
51}
52
53impl Default for RegexpExtractFunction {
54    fn default() -> Self {
55        Self {
56            signature: Signature::one_of(
57                vec![
58                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
59                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
60                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
61                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
62                    TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
63                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
64                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
65                    TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
66                    TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
67                ],
68                Volatility::Immutable,
69            ),
70        }
71    }
72}
73
74impl fmt::Display for RegexpExtractFunction {
75    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
76        write!(f, "{}", NAME.to_ascii_uppercase())
77    }
78}
79
80impl Function for RegexpExtractFunction {
81    fn name(&self) -> &str {
82        NAME
83    }
84
85    // Always return LargeUtf8 for simplicity and safety
86    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
87        Ok(DataType::LargeUtf8)
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn invoke_with_args(
95        &self,
96        args: ScalarFunctionArgs,
97    ) -> datafusion_common::Result<ColumnarValue> {
98        if args.args.len() != 2 {
99            return Err(DataFusionError::Execution(
100                "REGEXP_EXTRACT requires exactly two arguments (text, pattern)".to_string(),
101            ));
102        }
103
104        // Keep original ColumnarValue variants for scalar-pattern fast path
105        let pattern_is_scalar = matches!(args.args[1], ColumnarValue::Scalar(_));
106
107        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
108        let text_array = &arrays[0];
109        let pattern_array = &arrays[1];
110
111        // Cast both to LargeUtf8 for uniform access (supports Utf8/Utf8View/Dictionary<String>)
112        let text_large = cast(text_array.as_ref(), &DataType::LargeUtf8).map_err(|e| {
113            DataFusionError::Execution(format!("REGEXP_EXTRACT: text cast failed: {e}"))
114        })?;
115        let pattern_large = cast(pattern_array.as_ref(), &DataType::LargeUtf8).map_err(|e| {
116            DataFusionError::Execution(format!("REGEXP_EXTRACT: pattern cast failed: {e}"))
117        })?;
118
119        let text = text_large.as_string::<i64>();
120        let pattern = pattern_large.as_string::<i64>();
121        let len = text.len();
122
123        // Pre-size result builder with conservative estimate
124        let mut estimated_total = 0usize;
125        for i in 0..len {
126            if !text.is_null(i) {
127                estimated_total = estimated_total.saturating_add(text.value_length(i) as usize);
128                if estimated_total > MAX_TOTAL_RESULT_SIZE {
129                    return Err(DataFusionError::ResourcesExhausted(format!(
130                        "REGEXP_EXTRACT total output exceeds {} bytes",
131                        MAX_TOTAL_RESULT_SIZE
132                    )));
133                }
134            }
135        }
136        let mut builder = LargeStringBuilder::with_capacity(len, estimated_total);
137
138        // Fast path: if pattern is scalar, compile once
139        let compiled_scalar: Option<Regex> = if pattern_is_scalar && len > 0 && !pattern.is_null(0)
140        {
141            Some(compile_regex_checked(pattern.value(0))?)
142        } else {
143            None
144        };
145
146        for i in 0..len {
147            if text.is_null(i) || pattern.is_null(i) {
148                builder.append_null();
149                continue;
150            }
151
152            let s = text.value(i);
153            let pat = pattern.value(i);
154
155            // Compile or reuse regex
156            let re = if let Some(ref compiled) = compiled_scalar {
157                compiled
158            } else {
159                // TODO: For performance-critical applications with repeating patterns,
160                // consider adding a small LRU cache here
161                &compile_regex_checked(pat)?
162            };
163
164            // First match only
165            if let Some(m) = re.find(s) {
166                let m_str = m.as_str();
167                if m_str.len() > MAX_SINGLE_MATCH {
168                    return Err(DataFusionError::Execution(
169                        "REGEXP_EXTRACT match exceeds per-row limit (1MB)".to_string(),
170                    ));
171                }
172                builder.append_value(m_str);
173            } else {
174                builder.append_null();
175            }
176        }
177
178        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
179    }
180}
181
182// Compile a regex with safety checks
183fn compile_regex_checked(pattern: &str) -> datafusion_common::Result<Regex> {
184    if pattern.len() > MAX_PATTERN_LEN {
185        return Err(DataFusionError::Execution(format!(
186            "REGEXP_EXTRACT pattern too long (> {} chars)",
187            MAX_PATTERN_LEN
188        )));
189    }
190    RegexBuilder::new(pattern)
191        .size_limit(MAX_REGEX_SIZE)
192        .dfa_size_limit(MAX_DFA_SIZE)
193        .build()
194        .map_err(|e| {
195            DataFusionError::Execution(format!("REGEXP_EXTRACT invalid pattern '{}': {e}", pattern))
196        })
197}
198
#[cfg(test)]
mod tests {
    use datafusion_common::arrow::array::StringArray;
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Invokes REGEXP_EXTRACT on two two-row `Utf8` columns and returns the raw result.
    ///
    /// `text_nullable` sets the nullability declared for the text argument field.
    fn extract_from_columns(
        text: StringArray,
        text_nullable: bool,
        pattern: StringArray,
    ) -> ColumnarValue {
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(text)),
                ColumnarValue::Array(Arc::new(pattern)),
            ],
            arg_fields: vec![
                Arc::new(Field::new("arg_0", DataType::Utf8, text_nullable)),
                Arc::new(Field::new("arg_1", DataType::Utf8, false)),
            ],
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows: 2,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        };

        RegexpExtractFunction::default()
            .invoke_with_args(args)
            .unwrap()
    }

    /// Asserts that row 0 holds `expected_first` and row 1 is NULL.
    fn assert_match_then_null(result: ColumnarValue, expected_first: &str) {
        let ColumnarValue::Array(array) = result else {
            panic!("Expected array result");
        };
        let extracted = array.as_string::<i64>();
        assert_eq!(extracted.value(0), expected_first);
        assert!(extracted.is_null(1));
    }

    #[test]
    fn test_regexp_extract_function_basic() {
        let result = extract_from_columns(
            StringArray::from(vec!["version 1.2.3", "no match here"]),
            false,
            StringArray::from(vec!["\\d+\\.\\d+\\.\\d+", "\\d+"]),
        );

        // The second row has no match and must be NULL.
        assert_match_then_null(result, "1.2.3");
    }

    #[test]
    fn test_regexp_extract_phone_number() {
        let phone_pattern = "\\d{3}-\\d{3}-\\d{4}";
        let result = extract_from_columns(
            StringArray::from(vec!["Phone: 123-456-7890", "No phone"]),
            false,
            StringArray::from(vec![phone_pattern, phone_pattern]),
        );

        // The second row has no match and must be NULL.
        assert_match_then_null(result, "123-456-7890");
    }

    #[test]
    fn test_regexp_extract_email() {
        let email_pattern = "[a-zA-Z0-9]+@[a-zA-Z0-9]+\\.[a-zA-Z]+";
        let result = extract_from_columns(
            StringArray::from(vec!["Email: user@domain.com", "Invalid email"]),
            false,
            StringArray::from(vec![email_pattern, email_pattern]),
        );

        // The second row has no match and must be NULL.
        assert_match_then_null(result, "user@domain.com");
    }

    #[test]
    fn test_regexp_extract_with_nulls() {
        let result = extract_from_columns(
            StringArray::from(vec![Some("test 123"), None]),
            true,
            StringArray::from(vec![Some("\\d+"), Some("\\d+")]),
        );

        // A NULL text input must produce a NULL result.
        assert_match_then_null(result, "123");
    }
}
339}