common_function/scalars/string/
regexp_extract.rs1use std::fmt;
17use std::sync::Arc;
18
19use datafusion_common::DataFusionError;
20use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
21use datafusion_common::arrow::compute::cast;
22use datafusion_common::arrow::datatypes::DataType;
23use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
24use regex::{Regex, RegexBuilder};
25
26use crate::function::Function;
27use crate::function_registry::FunctionRegistry;
28
29const NAME: &str = "regexp_extract";
30
31const MAX_REGEX_SIZE: usize = 1024 * 1024; const MAX_DFA_SIZE: usize = 2 * 1024 * 1024; const MAX_TOTAL_RESULT_SIZE: usize = 64 * 1024 * 1024; const MAX_SINGLE_MATCH: usize = 1024 * 1024; const MAX_PATTERN_LEN: usize = 10_000; #[derive(Debug)]
43pub struct RegexpExtractFunction {
44 signature: Signature,
45}
46
47impl RegexpExtractFunction {
48 pub fn register(registry: &FunctionRegistry) {
49 registry.register_scalar(RegexpExtractFunction::default());
50 }
51}
52
53impl Default for RegexpExtractFunction {
54 fn default() -> Self {
55 Self {
56 signature: Signature::one_of(
57 vec![
58 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]),
59 TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]),
60 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]),
61 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]),
62 TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]),
63 TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
64 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]),
65 TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]),
66 TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
67 ],
68 Volatility::Immutable,
69 ),
70 }
71 }
72}
73
74impl fmt::Display for RegexpExtractFunction {
75 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
76 write!(f, "{}", NAME.to_ascii_uppercase())
77 }
78}
79
80impl Function for RegexpExtractFunction {
81 fn name(&self) -> &str {
82 NAME
83 }
84
85 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
87 Ok(DataType::LargeUtf8)
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn invoke_with_args(
95 &self,
96 args: ScalarFunctionArgs,
97 ) -> datafusion_common::Result<ColumnarValue> {
98 if args.args.len() != 2 {
99 return Err(DataFusionError::Execution(
100 "REGEXP_EXTRACT requires exactly two arguments (text, pattern)".to_string(),
101 ));
102 }
103
104 let pattern_is_scalar = matches!(args.args[1], ColumnarValue::Scalar(_));
106
107 let arrays = ColumnarValue::values_to_arrays(&args.args)?;
108 let text_array = &arrays[0];
109 let pattern_array = &arrays[1];
110
111 let text_large = cast(text_array.as_ref(), &DataType::LargeUtf8).map_err(|e| {
113 DataFusionError::Execution(format!("REGEXP_EXTRACT: text cast failed: {e}"))
114 })?;
115 let pattern_large = cast(pattern_array.as_ref(), &DataType::LargeUtf8).map_err(|e| {
116 DataFusionError::Execution(format!("REGEXP_EXTRACT: pattern cast failed: {e}"))
117 })?;
118
119 let text = text_large.as_string::<i64>();
120 let pattern = pattern_large.as_string::<i64>();
121 let len = text.len();
122
123 let mut estimated_total = 0usize;
125 for i in 0..len {
126 if !text.is_null(i) {
127 estimated_total = estimated_total.saturating_add(text.value_length(i) as usize);
128 if estimated_total > MAX_TOTAL_RESULT_SIZE {
129 return Err(DataFusionError::ResourcesExhausted(format!(
130 "REGEXP_EXTRACT total output exceeds {} bytes",
131 MAX_TOTAL_RESULT_SIZE
132 )));
133 }
134 }
135 }
136 let mut builder = LargeStringBuilder::with_capacity(len, estimated_total);
137
138 let compiled_scalar: Option<Regex> = if pattern_is_scalar && len > 0 && !pattern.is_null(0)
140 {
141 Some(compile_regex_checked(pattern.value(0))?)
142 } else {
143 None
144 };
145
146 for i in 0..len {
147 if text.is_null(i) || pattern.is_null(i) {
148 builder.append_null();
149 continue;
150 }
151
152 let s = text.value(i);
153 let pat = pattern.value(i);
154
155 let re = if let Some(ref compiled) = compiled_scalar {
157 compiled
158 } else {
159 &compile_regex_checked(pat)?
162 };
163
164 if let Some(m) = re.find(s) {
166 let m_str = m.as_str();
167 if m_str.len() > MAX_SINGLE_MATCH {
168 return Err(DataFusionError::Execution(
169 "REGEXP_EXTRACT match exceeds per-row limit (1MB)".to_string(),
170 ));
171 }
172 builder.append_value(m_str);
173 } else {
174 builder.append_null();
175 }
176 }
177
178 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
179 }
180}
181
182fn compile_regex_checked(pattern: &str) -> datafusion_common::Result<Regex> {
184 if pattern.len() > MAX_PATTERN_LEN {
185 return Err(DataFusionError::Execution(format!(
186 "REGEXP_EXTRACT pattern too long (> {} chars)",
187 MAX_PATTERN_LEN
188 )));
189 }
190 RegexBuilder::new(pattern)
191 .size_limit(MAX_REGEX_SIZE)
192 .dfa_size_limit(MAX_DFA_SIZE)
193 .build()
194 .map_err(|e| {
195 DataFusionError::Execution(format!("REGEXP_EXTRACT invalid pattern '{}': {e}", pattern))
196 })
197}
198
#[cfg(test)]
mod tests {
    use datafusion_common::ScalarValue;
    use datafusion_common::arrow::array::StringArray;
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Invokes REGEXP_EXTRACT on `(text, pattern)` for `number_rows` rows.
    fn invoke(
        text: ColumnarValue,
        pattern: ColumnarValue,
        number_rows: usize,
    ) -> datafusion_common::Result<ColumnarValue> {
        let args = ScalarFunctionArgs {
            // Built before `args` so the values are not yet moved.
            arg_fields: vec![
                Arc::new(Field::new("arg_0", text.data_type(), true)),
                Arc::new(Field::new("arg_1", pattern.data_type(), true)),
            ],
            args: vec![text, pattern],
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        };
        RegexpExtractFunction::default().invoke_with_args(args)
    }

    /// Wraps string slices in a `Utf8` array argument.
    fn utf8_array(values: Vec<Option<&str>>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringArray::from(values)))
    }

    /// Unpacks a `LargeUtf8` array result into owned values.
    fn collect_strings(result: ColumnarValue) -> Vec<Option<String>> {
        let ColumnarValue::Array(array) = result else {
            panic!("Expected array result");
        };
        array
            .as_string::<i64>()
            .iter()
            .map(|v| v.map(String::from))
            .collect()
    }

    /// Runs REGEXP_EXTRACT over two string columns of equal length.
    fn extract_columns(text: Vec<Option<&str>>, pattern: Vec<Option<&str>>) -> Vec<Option<String>> {
        let rows = text.len();
        let result = invoke(utf8_array(text), utf8_array(pattern), rows).unwrap();
        collect_strings(result)
    }

    /// Converts expected values into the owned form returned by `collect_strings`.
    fn strings(values: &[Option<&str>]) -> Vec<Option<String>> {
        values.iter().map(|v| v.map(String::from)).collect()
    }

    #[test]
    fn test_regexp_extract_function_basic() {
        let out = extract_columns(
            vec![Some("version 1.2.3"), Some("no match here")],
            vec![Some("\\d+\\.\\d+\\.\\d+"), Some("\\d+")],
        );
        assert_eq!(out, strings(&[Some("1.2.3"), None]));
    }

    #[test]
    fn test_regexp_extract_phone_number() {
        let out = extract_columns(
            vec![Some("Phone: 123-456-7890"), Some("No phone")],
            vec![Some("\\d{3}-\\d{3}-\\d{4}"), Some("\\d{3}-\\d{3}-\\d{4}")],
        );
        assert_eq!(out, strings(&[Some("123-456-7890"), None]));
    }

    #[test]
    fn test_regexp_extract_email() {
        let out = extract_columns(
            vec![Some("Email: user@domain.com"), Some("Invalid email")],
            vec![
                Some("[a-zA-Z0-9]+@[a-zA-Z0-9]+\\.[a-zA-Z]+"),
                Some("[a-zA-Z0-9]+@[a-zA-Z0-9]+\\.[a-zA-Z]+"),
            ],
        );
        assert_eq!(out, strings(&[Some("user@domain.com"), None]));
    }

    #[test]
    fn test_regexp_extract_with_nulls() {
        let out = extract_columns(vec![Some("test 123"), None], vec![Some("\\d+"), Some("\\d+")]);
        assert_eq!(out, strings(&[Some("123"), None]));
    }

    #[test]
    fn test_regexp_extract_scalar_pattern() {
        let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some("\\d+".to_string())));
        let text = utf8_array(vec![Some("a1b22"), Some("none"), None]);
        let out = collect_strings(invoke(text, pattern, 3).unwrap());
        assert_eq!(out, strings(&[Some("1"), None, None]));
    }

    #[test]
    fn test_regexp_extract_null_scalar_pattern() {
        let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(None));
        let text = utf8_array(vec![Some("a1"), Some("b2")]);
        let out = collect_strings(invoke(text, pattern, 2).unwrap());
        assert_eq!(out, strings(&[None, None]));
    }

    #[test]
    fn test_regexp_extract_alternating_patterns() {
        let out = extract_columns(
            vec![Some("a1"), Some("b2"), Some("c3")],
            vec![Some("\\d"), Some("[a-z]"), Some("\\d")],
        );
        assert_eq!(out, strings(&[Some("1"), Some("b"), Some("3")]));
    }

    #[test]
    fn test_regexp_extract_invalid_pattern() {
        let err = invoke(utf8_array(vec![Some("x")]), utf8_array(vec![Some("(")]), 1)
            .expect_err("unbalanced group must be rejected");
        assert!(matches!(err, DataFusionError::Execution(_)));
    }

    #[test]
    fn test_regexp_extract_pattern_too_long() {
        let long = "a".repeat(MAX_PATTERN_LEN + 1);
        let err = invoke(utf8_array(vec![Some("aaa")]), utf8_array(vec![Some(&long)]), 1)
            .expect_err("over-long pattern must be rejected");
        assert!(matches!(err, DataFusionError::Execution(_)));
    }

    #[test]
    fn test_regexp_extract_return_type() {
        let function = RegexpExtractFunction::default();
        let ty = function
            .return_type(&[DataType::Utf8, DataType::Utf8View])
            .unwrap();
        assert_eq!(ty, DataType::LargeUtf8);
    }
}