common_function/scalars/ip/
cidr.rs1use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17use std::sync::Arc;
18
19use common_query::error::{InvalidFuncArgsSnafu, Result};
20use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder};
21use datafusion_common::arrow::compute;
22use datafusion_common::arrow::datatypes::{DataType, UInt8Type};
23use datafusion_common::{DataFusionError, types};
24use datafusion_expr::{
25 Coercion, ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, TypeSignatureClass,
26 Volatility,
27};
28use derive_more::Display;
29use snafu::ensure;
30
31use crate::function::Function;
32
33#[derive(Clone, Debug, Display)]
43#[display("{}", self.name())]
44pub(crate) struct Ipv4ToCidr {
45 signature: Signature,
46}
47
48impl Default for Ipv4ToCidr {
49 fn default() -> Self {
50 Self {
51 signature: Signature::one_of(
52 vec![
53 TypeSignature::String(1),
54 TypeSignature::Coercible(vec![
55 Coercion::new_exact(TypeSignatureClass::Native(types::logical_string())),
56 Coercion::new_exact(TypeSignatureClass::Integer),
57 ]),
58 ],
59 Volatility::Immutable,
60 ),
61 }
62 }
63}
64
65impl Function for Ipv4ToCidr {
66 fn name(&self) -> &str {
67 "ipv4_to_cidr"
68 }
69
70 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
71 Ok(DataType::Utf8View)
72 }
73
74 fn signature(&self) -> &Signature {
75 &self.signature
76 }
77
78 fn invoke_with_args(
79 &self,
80 args: ScalarFunctionArgs,
81 ) -> datafusion_common::Result<ColumnarValue> {
82 if args.args.len() != 1 && args.args.len() != 2 {
83 return Err(DataFusionError::Execution(format!(
84 "expecting 1 or 2 arguments, got {}",
85 args.args.len()
86 )));
87 }
88 let columns = ColumnarValue::values_to_arrays(&args.args)?;
89
90 let ip_vec = &columns[0];
91 let mut builder = StringViewBuilder::with_capacity(ip_vec.len());
92 let arg0 = compute::cast(ip_vec, &DataType::Utf8View)?;
93 let ip_vec = arg0.as_string_view();
94
95 let maybe_arg1 = if columns.len() > 1 {
96 Some(compute::cast(&columns[1], &DataType::UInt8)?)
97 } else {
98 None
99 };
100 let subnets = if let Some(arg1) = maybe_arg1.as_ref() {
101 ensure!(
102 columns[1].len() == ip_vec.len(),
103 InvalidFuncArgsSnafu {
104 err_msg:
105 "Subnet mask must have the same number of elements as the IP addresses"
106 .to_string()
107 }
108 );
109 Some(arg1.as_primitive::<UInt8Type>())
110 } else {
111 None
112 };
113
114 for i in 0..ip_vec.len() {
115 let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
116 let subnet = subnets.and_then(|v| v.is_valid(i).then(|| v.value(i)));
117
118 let cidr = match (ip_str, subnet) {
119 (Some(ip_str), Some(mask)) => {
120 if ip_str.is_empty() {
121 return Err(DataFusionError::Execution("empty IPv4 address".to_string()));
122 }
123
124 let ip_addr = complete_and_parse_ipv4(ip_str)?;
125 let mask_bits = u32::MAX.wrapping_shl(32 - mask as u32);
127 let masked_ip = Ipv4Addr::from(u32::from(ip_addr) & mask_bits);
128
129 Some(format!("{}/{}", masked_ip, mask))
130 }
131 (Some(ip_str), None) => {
132 if ip_str.is_empty() {
133 return Err(DataFusionError::Execution("empty IPv4 address".to_string()));
134 }
135
136 let ip_addr = complete_and_parse_ipv4(ip_str)?;
137
138 let ip_bits = u32::from(ip_addr);
140 let dots = ip_str.chars().filter(|&c| c == '.').count();
141
142 let subnet_mask = match dots {
143 0 => 8, 1 => 16, 2 => 24, _ => {
147 let trailing_zeros = ip_bits.trailing_zeros();
149 if trailing_zeros % 8 == 0 {
151 32 - trailing_zeros.min(32) as u8
152 } else {
153 32 - (trailing_zeros as u8 / 8) * 8
154 }
155 }
156 };
157
158 let mask_bits = u32::MAX.wrapping_shl(32 - subnet_mask as u32);
160 let masked_ip = Ipv4Addr::from(ip_bits & mask_bits);
161
162 Some(format!("{}/{}", masked_ip, subnet_mask))
163 }
164 _ => None,
165 };
166
167 builder.append_option(cidr.as_deref());
168 }
169
170 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
171 }
172}
173
174#[derive(Clone, Debug, Display)]
184#[display("{}", self.name())]
185pub(crate) struct Ipv6ToCidr {
186 signature: Signature,
187}
188
189impl Default for Ipv6ToCidr {
190 fn default() -> Self {
191 Self {
192 signature: Signature::one_of(
193 vec![
194 TypeSignature::String(1),
195 TypeSignature::Exact(vec![DataType::Utf8, DataType::UInt8]),
196 ],
197 Volatility::Immutable,
198 ),
199 }
200 }
201}
202
203impl Function for Ipv6ToCidr {
204 fn name(&self) -> &str {
205 "ipv6_to_cidr"
206 }
207
208 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
209 Ok(DataType::Utf8View)
210 }
211
212 fn signature(&self) -> &Signature {
213 &self.signature
214 }
215
216 fn invoke_with_args(
217 &self,
218 args: ScalarFunctionArgs,
219 ) -> datafusion_common::Result<ColumnarValue> {
220 if args.args.len() != 1 && args.args.len() != 2 {
221 return Err(DataFusionError::Execution(format!(
222 "expecting 1 or 2 arguments, got {}",
223 args.args.len()
224 )));
225 }
226 let columns = ColumnarValue::values_to_arrays(&args.args)?;
227
228 let ip_vec = &columns[0];
229 let size = ip_vec.len();
230 let mut builder = StringViewBuilder::with_capacity(size);
231 let arg0 = compute::cast(ip_vec, &DataType::Utf8View)?;
232 let ip_vec = arg0.as_string_view();
233
234 let maybe_arg1 = if columns.len() > 1 {
235 Some(compute::cast(&columns[1], &DataType::UInt8)?)
236 } else {
237 None
238 };
239 let subnets = maybe_arg1
240 .as_ref()
241 .map(|arg1| arg1.as_primitive::<UInt8Type>());
242
243 for i in 0..size {
244 let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
245 let subnet = subnets.and_then(|v| v.is_valid(i).then(|| v.value(i)));
246
247 let cidr = match (ip_str, subnet) {
248 (Some(ip_str), Some(mask)) => {
249 if ip_str.is_empty() {
250 return Err(DataFusionError::Execution("empty IPv6 address".to_string()));
251 }
252
253 let ip_addr = complete_and_parse_ipv6(ip_str)?;
254
255 let masked_ip = mask_ipv6(&ip_addr, mask);
257
258 Some(format!("{}/{}", masked_ip, mask))
259 }
260 (Some(ip_str), None) => {
261 if ip_str.is_empty() {
262 return Err(DataFusionError::Execution("empty IPv6 address".to_string()));
263 }
264
265 let ip_addr = complete_and_parse_ipv6(ip_str)?;
266
267 let subnet_mask = auto_detect_ipv6_subnet(&ip_addr);
269
270 let masked_ip = mask_ipv6(&ip_addr, subnet_mask);
272
273 Some(format!("{}/{}", masked_ip, subnet_mask))
274 }
275 _ => None,
276 };
277
278 builder.append_option(cidr.as_deref());
279 }
280
281 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
282 }
283}
284
285fn complete_and_parse_ipv4(ip_str: &str) -> Result<Ipv4Addr> {
288 if let Ok(addr) = Ipv4Addr::from_str(ip_str) {
290 return Ok(addr);
291 }
292
293 let dots = ip_str.chars().filter(|&c| c == '.').count();
295
296 let completed = match dots {
298 0 => format!("{}.0.0.0", ip_str),
299 1 => format!("{}.0.0", ip_str),
300 2 => format!("{}.0", ip_str),
301 _ => ip_str.to_string(),
302 };
303
304 Ipv4Addr::from_str(&completed).map_err(|_| {
305 InvalidFuncArgsSnafu {
306 err_msg: format!("Invalid IPv4 address: {}", ip_str),
307 }
308 .build()
309 })
310}
311
312fn complete_and_parse_ipv6(ip_str: &str) -> Result<Ipv6Addr> {
313 if let Ok(addr) = Ipv6Addr::from_str(ip_str) {
315 return Ok(addr);
316 }
317
318 let completed = if ip_str.ends_with(':') {
321 format!("{}:", ip_str)
322 } else if !ip_str.contains("::") {
323 format!("{}::", ip_str)
324 } else {
325 ip_str.to_string()
326 };
327
328 Ipv6Addr::from_str(&completed).map_err(|_| {
329 InvalidFuncArgsSnafu {
330 err_msg: format!("Invalid IPv6 address: {}", ip_str),
331 }
332 .build()
333 })
334}
335
/// Returns `addr` with everything after the first `subnet` bits cleared.
///
/// A `subnet` of 0 yields the all-zero address; a `subnet` of 128 or more
/// returns the address unchanged.
fn mask_ipv6(addr: &Ipv6Addr, subnet: u8) -> Ipv6Addr {
    // Number of low-order (host) bits to clear, capped at the address width.
    let host_bits = 128 - u32::from(subnet.min(128));

    // Shifting by the full width is not representable, so `checked_shl`
    // returns `None` for a zero-length prefix, which means "keep no bits".
    let keep = u128::MAX.checked_shl(host_bits).unwrap_or(0);

    Ipv6Addr::from(u128::from(*addr) & keep)
}
358
/// Guesses a prefix length for an IPv6 address given without one.
///
/// Rules, checked in order:
/// 1. Addresses whose first two groups are `2001:db8` get `/32`.
/// 2. The loopback address `::1` gets `/128`.
/// 3. Addresses whose textual form starts with `fe80::` get `/16`.
/// 4. Otherwise the prefix ends at the last non-zero 16-bit group: at the
///    group's end, or in the middle of it when the group's low byte is zero.
///    An all-zero address gets `/128`, and a result shorter than 16 bits is
///    replaced with `/64`.
fn auto_detect_ipv6_subnet(addr: &Ipv6Addr) -> u8 {
    let groups = addr.segments();

    // Same set of addresses as those whose text starts with "2001:db8:".
    if groups[0] == 0x2001 && groups[1] == 0x0db8 {
        return 32;
    }

    // `::1` is the only loopback address.
    if addr.is_loopback() {
        return 128;
    }

    // Deliberately textual: whether the output starts with "fe80::" depends on
    // where the zero-compression lands, not just on the first group.
    if addr.to_string().starts_with("fe80::") {
        return 16;
    }

    let detected: usize = match groups.iter().rposition(|&group| group != 0) {
        None => 128,
        Some(last) if groups[last] & 0xFF == 0 => last * 16 + 8,
        Some(last) => (last + 1) * 16,
    };

    // Anything shorter than one full group falls back to /64.
    let detected = if detected < 16 { 64 } else { detected };

    detected as u8
}
400
#[cfg(test)]
mod tests {
    use arrow_schema::Field;
    use datafusion_common::arrow::array::{StringViewArray, UInt8Array};

    use super::*;

    /// Wraps string values in a `Utf8View` array column.
    fn string_column(values: &[&str]) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(values)))
    }

    /// Wraps prefix lengths in a `UInt8` array column.
    fn prefix_column(values: Vec<u8>) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(UInt8Array::from(values)))
    }

    /// Builds the argument bundle shared by every test case.
    fn call_args(args: Vec<ColumnarValue>, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
            config_options: Arc::new(Default::default()),
        }
    }

    /// Invokes `func`, expecting success, and compares every output row.
    fn assert_cidrs(func: &dyn Function, args: Vec<ColumnarValue>, expected: &[&str]) {
        let rows = expected.len();
        let output = func.invoke_with_args(call_args(args, rows)).unwrap();
        let output = output.to_array(rows).unwrap();
        let output = output.as_string_view();

        for (row, want) in expected.iter().enumerate() {
            assert_eq!(output.value(row), *want);
        }
    }

    #[test]
    fn test_ipv4_to_cidr_auto() {
        let func = Ipv4ToCidr::default();

        assert_cidrs(
            &func,
            vec![string_column(&["192.168.1.0", "10.0.0.0", "172.16", "192"])],
            &[
                "192.168.1.0/24",
                "10.0.0.0/8",
                "172.16.0.0/16",
                "192.0.0.0/8",
            ],
        );
    }

    #[test]
    fn test_ipv4_to_cidr_with_subnet() {
        let func = Ipv4ToCidr::default();

        assert_cidrs(
            &func,
            vec![
                string_column(&["192.168.1.1", "10.0.0.1", "172.16.5.5"]),
                prefix_column(vec![24u8, 16u8, 12u8]),
            ],
            &["192.168.1.0/24", "10.0.0.0/16", "172.16.0.0/12"],
        );
    }

    #[test]
    fn test_ipv6_to_cidr_auto() {
        let func = Ipv6ToCidr::default();

        assert_cidrs(
            &func,
            vec![string_column(&["2001:db8::", "2001:db8", "fe80::1", "::1"])],
            &["2001:db8::/32", "2001:db8::/32", "fe80::/16", "::1/128"],
        );
    }

    #[test]
    fn test_ipv6_to_cidr_with_subnet() {
        let func = Ipv6ToCidr::default();

        assert_cidrs(
            &func,
            vec![
                string_column(&["2001:db8::", "fe80::1", "2001:db8:1234::"]),
                prefix_column(vec![48u8, 10u8, 56u8]),
            ],
            &["2001:db8::/48", "fe80::/10", "2001:db8:1234::/56"],
        );
    }

    #[test]
    fn test_invalid_inputs() {
        let ipv4_func = Ipv4ToCidr::default();
        let ipv6_func = Ipv6ToCidr::default();

        // An empty address is rejected by both functions.
        let empty = call_args(vec![string_column(&[""])], 1);
        let ipv4_result = ipv4_func.invoke_with_args(empty.clone());
        let ipv6_result = ipv6_func.invoke_with_args(empty);

        assert!(ipv4_result.is_err());
        assert!(ipv6_result.is_err());

        // Unparseable addresses are rejected by the IPv4 function.
        let invalid = call_args(
            vec![string_column(&["not an ip", "192.168.1.256", "zzzz::ffff"])],
            3,
        );
        let ipv4_result = ipv4_func.invoke_with_args(invalid);

        assert!(ipv4_result.is_err());
    }
}