common_function/scalars/ip/
range.rs1use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17use std::sync::Arc;
18
19use common_query::error::{InvalidFuncArgsSnafu, Result};
20use datafusion_common::DataFusionError;
21use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
22use datafusion_common::arrow::compute;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
25use derive_more::Display;
26use snafu::ensure;
27
28use crate::function::{Function, extract_args};
29
30#[derive(Clone, Debug, Display)]
40#[display("{}", self.name())]
41pub(crate) struct Ipv4InRange {
42 signature: Signature,
43}
44
45impl Default for Ipv4InRange {
46 fn default() -> Self {
47 Self {
48 signature: Signature::string(2, Volatility::Immutable),
49 }
50 }
51}
52
53impl Function for Ipv4InRange {
54 fn name(&self) -> &str {
55 "ipv4_in_range"
56 }
57
58 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
59 Ok(DataType::Boolean)
60 }
61
62 fn signature(&self) -> &Signature {
63 &self.signature
64 }
65
66 fn invoke_with_args(
67 &self,
68 args: ScalarFunctionArgs,
69 ) -> datafusion_common::Result<ColumnarValue> {
70 let [arg0, arg1] = extract_args(self.name(), &args)?;
71
72 let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
73 let ip_vec = arg0.as_string_view();
74 let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
75 let ranges = arg1.as_string_view();
76
77 let size = ip_vec.len();
78
79 let mut builder = BooleanBuilder::with_capacity(size);
80
81 for i in 0..size {
82 let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i));
83 let range = ranges.is_valid(i).then(|| ranges.value(i));
84
85 let in_range = match (ip, range) {
86 (Some(ip_str), Some(range_str)) => {
87 if ip_str.is_empty() || range_str.is_empty() {
88 return Err(DataFusionError::Execution(
89 "IP address or CIDR range cannot be empty".to_string(),
90 ));
91 }
92
93 let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
95 InvalidFuncArgsSnafu {
96 err_msg: format!("Invalid IPv4 address: {}", ip_str),
97 }
98 .build()
99 })?;
100
101 let (cidr_ip, cidr_prefix) = parse_ipv4_cidr(range_str)?;
103
104 is_ipv4_in_range(&ip_addr, &cidr_ip, cidr_prefix)
106 }
107 _ => None,
108 };
109
110 builder.append_option(in_range);
111 }
112
113 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
114 }
115}
116
117#[derive(Clone, Debug, Display)]
128#[display("{}", self.name())]
129pub(crate) struct Ipv6InRange {
130 signature: Signature,
131}
132
133impl Default for Ipv6InRange {
134 fn default() -> Self {
135 Self {
136 signature: Signature::string(2, Volatility::Immutable),
137 }
138 }
139}
140
141impl Function for Ipv6InRange {
142 fn name(&self) -> &str {
143 "ipv6_in_range"
144 }
145
146 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
147 Ok(DataType::Boolean)
148 }
149
150 fn signature(&self) -> &Signature {
151 &self.signature
152 }
153
154 fn invoke_with_args(
155 &self,
156 args: ScalarFunctionArgs,
157 ) -> datafusion_common::Result<ColumnarValue> {
158 let [arg0, arg1] = extract_args(self.name(), &args)?;
159
160 let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
161 let ip_vec = arg0.as_string_view();
162 let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
163 let ranges = arg1.as_string_view();
164 let size = ip_vec.len();
165 let mut builder = BooleanBuilder::with_capacity(size);
166
167 for i in 0..size {
168 let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i));
169 let range = ranges.is_valid(i).then(|| ranges.value(i));
170
171 let in_range = match (ip, range) {
172 (Some(ip_str), Some(range_str)) => {
173 if ip_str.is_empty() || range_str.is_empty() {
174 return Err(DataFusionError::Execution(
175 "IP address or CIDR range cannot be empty".to_string(),
176 ));
177 }
178
179 let ip_addr = Ipv6Addr::from_str(ip_str).map_err(|_| {
181 InvalidFuncArgsSnafu {
182 err_msg: format!("Invalid IPv6 address: {}", ip_str),
183 }
184 .build()
185 })?;
186
187 let (cidr_ip, cidr_prefix) = parse_ipv6_cidr(range_str)?;
189
190 is_ipv6_in_range(&ip_addr, &cidr_ip, cidr_prefix)
192 }
193 _ => None,
194 };
195
196 builder.append_option(in_range);
197 }
198
199 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
200 }
201}
202
203fn parse_ipv4_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
206 let parts: Vec<&str> = cidr.split('/').collect();
208 ensure!(
209 parts.len() == 2,
210 InvalidFuncArgsSnafu {
211 err_msg: format!("Invalid CIDR notation: {}", cidr),
212 }
213 );
214
215 let ip = Ipv4Addr::from_str(parts[0]).map_err(|_| {
217 InvalidFuncArgsSnafu {
218 err_msg: format!("Invalid IPv4 address in CIDR: {}", parts[0]),
219 }
220 .build()
221 })?;
222
223 let prefix = parts[1].parse::<u8>().map_err(|_| {
225 InvalidFuncArgsSnafu {
226 err_msg: format!("Invalid prefix length: {}", parts[1]),
227 }
228 .build()
229 })?;
230
231 ensure!(
232 prefix <= 32,
233 InvalidFuncArgsSnafu {
234 err_msg: format!("IPv4 prefix length must be <= 32, got {}", prefix),
235 }
236 );
237
238 Ok((ip, prefix))
239}
240
241fn parse_ipv6_cidr(cidr: &str) -> Result<(Ipv6Addr, u8)> {
242 let parts: Vec<&str> = cidr.split('/').collect();
244 ensure!(
245 parts.len() == 2,
246 InvalidFuncArgsSnafu {
247 err_msg: format!("Invalid CIDR notation: {}", cidr),
248 }
249 );
250
251 let ip = Ipv6Addr::from_str(parts[0]).map_err(|_| {
253 InvalidFuncArgsSnafu {
254 err_msg: format!("Invalid IPv6 address in CIDR: {}", parts[0]),
255 }
256 .build()
257 })?;
258
259 let prefix = parts[1].parse::<u8>().map_err(|_| {
261 InvalidFuncArgsSnafu {
262 err_msg: format!("Invalid prefix length: {}", parts[1]),
263 }
264 .build()
265 })?;
266
267 ensure!(
268 prefix <= 128,
269 InvalidFuncArgsSnafu {
270 err_msg: format!("IPv6 prefix length must be <= 128, got {}", prefix),
271 }
272 );
273
274 Ok((ip, prefix))
275}
276
/// Reports whether `ip` belongs to the IPv4 network `cidr_base/prefix_len`.
///
/// Only the leading `prefix_len` bits of the two addresses are compared, so
/// host bits set in `cidr_base` are ignored.
///
/// Returns `None` when `prefix_len` is greater than 32, because no netmask
/// exists for such a prefix. Previously that input underflowed the shift
/// amount.
fn is_ipv4_in_range(ip: &Ipv4Addr, cidr_base: &Ipv4Addr, prefix_len: u8) -> Option<bool> {
    if prefix_len > 32 {
        return None;
    }

    // A /0 prefix needs a shift by the full width, which `checked_shl`
    // rejects; the correct mask in that case is all zeroes.
    let mask = u32::MAX
        .checked_shl(32 - u32::from(prefix_len))
        .unwrap_or(0);

    // The addresses share a network iff they differ in no masked bit.
    let differing_bits = u32::from(*ip) ^ u32::from(*cidr_base);
    Some((differing_bits & mask) == 0)
}
295
/// Reports whether `ip` belongs to the IPv6 network `cidr_base/prefix_len`.
///
/// Only the leading `prefix_len` bits of the two addresses are compared, so
/// host bits set in `cidr_base` are ignored.
///
/// Returns `None` when `prefix_len` is greater than 128, because no netmask
/// exists for such a prefix. Previously a prefix of 136 or more indexed past
/// the end of the 16 address bytes.
fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Option<bool> {
    if prefix_len > 128 {
        return None;
    }

    // A /0 prefix needs a shift by the full width, which `checked_shl`
    // rejects; the correct mask in that case is all zeroes.
    let mask = u128::MAX
        .checked_shl(128 - u32::from(prefix_len))
        .unwrap_or(0);

    // The addresses share a network iff they differ in no masked bit.
    let differing_bits = u128::from(*ip) ^ u128::from(*cidr_base);
    Some((differing_bits & mask) == 0)
}
324
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::StringViewArray;

    use super::*;

    /// Packs two equally long string slices into the argument struct expected
    /// by `invoke_with_args`; `number_rows` is the length of `ips`.
    fn string_args(ips: &[&str], cidrs: &[&str]) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(ips))),
                ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(cidrs))),
            ],
            arg_fields: vec![],
            number_rows: ips.len(),
            return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
            config_options: Arc::new(Default::default()),
        }
    }

    /// Runs `func` over `(ip, cidr, expected)` triples and compares the
    /// boolean value of every output row with `expected`.
    fn assert_membership<F: Function>(func: &F, cases: &[(&str, &str, bool)]) {
        let ips: Vec<&str> = cases.iter().map(|case| case.0).collect();
        let cidrs: Vec<&str> = cases.iter().map(|case| case.1).collect();
        let expected: Vec<bool> = cases.iter().map(|case| case.2).collect();

        let output = func.invoke_with_args(string_args(&ips, &cidrs)).unwrap();
        let output = output.to_array(cases.len()).unwrap();
        let flags = output.as_boolean();
        let actual: Vec<bool> = (0..cases.len()).map(|row| flags.value(row)).collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_ipv4_in_range() {
        assert_membership(
            &Ipv4InRange::default(),
            &[
                ("192.168.1.5", "192.168.1.0/24", true),
                ("192.168.2.1", "192.168.1.0/24", false),
                ("10.0.0.1", "10.0.0.0/8", true),
                ("10.1.0.1", "10.0.0.0/8", true),
                ("172.16.0.1", "172.16.0.0/16", true),
            ],
        );
    }

    #[test]
    fn test_ipv6_in_range() {
        assert_membership(
            &Ipv6InRange::default(),
            &[
                ("2001:db8::1", "2001:db8::/32", true),
                ("2001:db8:1::", "2001:db8::/32", true),
                ("2001:db9::1", "2001:db8::/32", false),
                ("::1", "::1/128", true),
                ("fe80::1", "fe80::/16", true),
            ],
        );
    }

    #[test]
    fn test_invalid_inputs() {
        let ipv4_func = Ipv4InRange::default();
        let ipv6_func = Ipv6InRange::default();

        // Malformed addresses paired with well-formed ranges.
        let bad_ip_args = string_args(
            &["not-an-ip", "192.168.1.300"],
            &["192.168.1.0/24", "192.168.1.0/24"],
        );
        assert!(ipv4_func.invoke_with_args(bad_ip_args).is_err());

        // A range without a prefix and an IPv6 range with prefix 129; the same
        // arguments are fed to both functions and each must reject them.
        let bad_cidr_args = string_args(
            &["192.168.1.1", "2001:db8::1"],
            &["192.168.1.0", "2001:db8::/129"],
        );
        let ipv4_result = ipv4_func.invoke_with_args(bad_cidr_args.clone());
        let ipv6_result = ipv6_func.invoke_with_args(bad_cidr_args);

        assert!(ipv4_result.is_err());
        assert!(ipv6_result.is_err());
    }

    #[test]
    fn test_edge_cases() {
        // /0 matches any address; /32 matches only the exact address.
        assert_membership(
            &Ipv4InRange::default(),
            &[
                ("8.8.8.8", "0.0.0.0/0", true),
                ("192.168.1.1", "192.168.1.1/32", true),
                ("192.168.1.1", "192.168.1.0/32", false),
            ],
        );
    }
}