common_function/scalars/ip/
cidr.rs1use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_common::types;
20use datafusion_expr::{Coercion, Signature, TypeSignature, TypeSignatureClass, Volatility};
21use datatypes::arrow::datatypes::DataType;
22use datatypes::prelude::Value;
23use datatypes::scalars::ScalarVectorBuilder;
24use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
25use derive_more::Display;
26use snafu::ensure;
27
28use crate::function::{Function, FunctionContext};
29
30#[derive(Clone, Debug, Default, Display)]
40#[display("{}", self.name())]
41pub struct Ipv4ToCidr;
42
43impl Function for Ipv4ToCidr {
44 fn name(&self) -> &str {
45 "ipv4_to_cidr"
46 }
47
48 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
49 Ok(DataType::Utf8)
50 }
51
52 fn signature(&self) -> Signature {
53 Signature::one_of(
54 vec![
55 TypeSignature::String(1),
56 TypeSignature::Coercible(vec![
57 Coercion::new_exact(TypeSignatureClass::Native(types::logical_string())),
58 Coercion::new_exact(TypeSignatureClass::Integer),
59 ]),
60 ],
61 Volatility::Immutable,
62 )
63 }
64
65 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
66 ensure!(
67 columns.len() == 1 || columns.len() == 2,
68 InvalidFuncArgsSnafu {
69 err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len())
70 }
71 );
72
73 let ip_vec = &columns[0];
74 let mut results = StringVectorBuilder::with_capacity(ip_vec.len());
75
76 let has_subnet_arg = columns.len() == 2;
77 let subnet_vec = if has_subnet_arg {
78 ensure!(
79 columns[1].len() == ip_vec.len(),
80 InvalidFuncArgsSnafu {
81 err_msg:
82 "Subnet mask must have the same number of elements as the IP addresses"
83 .to_string()
84 }
85 );
86 Some(&columns[1])
87 } else {
88 None
89 };
90
91 for i in 0..ip_vec.len() {
92 let ip_str = ip_vec.get(i);
93 let subnet = subnet_vec.map(|v| v.get(i));
94
95 let cidr = match (ip_str, subnet) {
96 (Value::String(s), Some(Value::UInt8(mask))) => {
97 let ip_str = s.as_utf8().trim();
98 if ip_str.is_empty() {
99 return InvalidFuncArgsSnafu {
100 err_msg: "Empty IPv4 address".to_string(),
101 }
102 .fail();
103 }
104
105 let ip_addr = complete_and_parse_ipv4(ip_str)?;
106 let mask_bits = u32::MAX.wrapping_shl(32 - mask as u32);
108 let masked_ip = Ipv4Addr::from(u32::from(ip_addr) & mask_bits);
109
110 Some(format!("{}/{}", masked_ip, mask))
111 }
112 (Value::String(s), None) => {
113 let ip_str = s.as_utf8().trim();
114 if ip_str.is_empty() {
115 return InvalidFuncArgsSnafu {
116 err_msg: "Empty IPv4 address".to_string(),
117 }
118 .fail();
119 }
120
121 let ip_addr = complete_and_parse_ipv4(ip_str)?;
122
123 let ip_bits = u32::from(ip_addr);
125 let dots = ip_str.chars().filter(|&c| c == '.').count();
126
127 let subnet_mask = match dots {
128 0 => 8, 1 => 16, 2 => 24, _ => {
132 let trailing_zeros = ip_bits.trailing_zeros();
134 if trailing_zeros % 8 == 0 {
136 32 - trailing_zeros.min(32) as u8
137 } else {
138 32 - (trailing_zeros as u8 / 8) * 8
139 }
140 }
141 };
142
143 let mask_bits = u32::MAX.wrapping_shl(32 - subnet_mask as u32);
145 let masked_ip = Ipv4Addr::from(ip_bits & mask_bits);
146
147 Some(format!("{}/{}", masked_ip, subnet_mask))
148 }
149 _ => None,
150 };
151
152 results.push(cidr.as_deref());
153 }
154
155 Ok(results.to_vector())
156 }
157}
158
159#[derive(Clone, Debug, Default, Display)]
169#[display("{}", self.name())]
170pub struct Ipv6ToCidr;
171
172impl Function for Ipv6ToCidr {
173 fn name(&self) -> &str {
174 "ipv6_to_cidr"
175 }
176
177 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
178 Ok(DataType::Utf8)
179 }
180
181 fn signature(&self) -> Signature {
182 Signature::one_of(
183 vec![
184 TypeSignature::String(1),
185 TypeSignature::Exact(vec![DataType::Utf8, DataType::UInt8]),
186 ],
187 Volatility::Immutable,
188 )
189 }
190
191 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
192 ensure!(
193 columns.len() == 1 || columns.len() == 2,
194 InvalidFuncArgsSnafu {
195 err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len())
196 }
197 );
198
199 let ip_vec = &columns[0];
200 let size = ip_vec.len();
201 let mut results = StringVectorBuilder::with_capacity(size);
202
203 let has_subnet_arg = columns.len() == 2;
204 let subnet_vec = if has_subnet_arg {
205 Some(&columns[1])
206 } else {
207 None
208 };
209
210 for i in 0..size {
211 let ip_str = ip_vec.get(i);
212 let subnet = subnet_vec.map(|v| v.get(i));
213
214 let cidr = match (ip_str, subnet) {
215 (Value::String(s), Some(Value::UInt8(mask))) => {
216 let ip_str = s.as_utf8().trim();
217 if ip_str.is_empty() {
218 return InvalidFuncArgsSnafu {
219 err_msg: "Empty IPv6 address".to_string(),
220 }
221 .fail();
222 }
223
224 let ip_addr = complete_and_parse_ipv6(ip_str)?;
225
226 let masked_ip = mask_ipv6(&ip_addr, mask);
228
229 Some(format!("{}/{}", masked_ip, mask))
230 }
231 (Value::String(s), None) => {
232 let ip_str = s.as_utf8().trim();
233 if ip_str.is_empty() {
234 return InvalidFuncArgsSnafu {
235 err_msg: "Empty IPv6 address".to_string(),
236 }
237 .fail();
238 }
239
240 let ip_addr = complete_and_parse_ipv6(ip_str)?;
241
242 let subnet_mask = auto_detect_ipv6_subnet(&ip_addr);
244
245 let masked_ip = mask_ipv6(&ip_addr, subnet_mask);
247
248 Some(format!("{}/{}", masked_ip, subnet_mask))
249 }
250 _ => None,
251 };
252
253 results.push(cidr.as_deref());
254 }
255
256 Ok(results.to_vector())
257 }
258}
259
260fn complete_and_parse_ipv4(ip_str: &str) -> Result<Ipv4Addr> {
263 if let Ok(addr) = Ipv4Addr::from_str(ip_str) {
265 return Ok(addr);
266 }
267
268 let dots = ip_str.chars().filter(|&c| c == '.').count();
270
271 let completed = match dots {
273 0 => format!("{}.0.0.0", ip_str),
274 1 => format!("{}.0.0", ip_str),
275 2 => format!("{}.0", ip_str),
276 _ => ip_str.to_string(),
277 };
278
279 Ipv4Addr::from_str(&completed).map_err(|_| {
280 InvalidFuncArgsSnafu {
281 err_msg: format!("Invalid IPv4 address: {}", ip_str),
282 }
283 .build()
284 })
285}
286
287fn complete_and_parse_ipv6(ip_str: &str) -> Result<Ipv6Addr> {
288 if let Ok(addr) = Ipv6Addr::from_str(ip_str) {
290 return Ok(addr);
291 }
292
293 let completed = if ip_str.ends_with(':') {
296 format!("{}:", ip_str)
297 } else if !ip_str.contains("::") {
298 format!("{}::", ip_str)
299 } else {
300 ip_str.to_string()
301 };
302
303 Ipv6Addr::from_str(&completed).map_err(|_| {
304 InvalidFuncArgsSnafu {
305 err_msg: format!("Invalid IPv6 address: {}", ip_str),
306 }
307 .build()
308 })
309}
310
/// Returns `addr` with everything after the leading `subnet` bits cleared.
///
/// A `subnet` of 0 yields the unspecified address `::`; a `subnet` of 128 or more
/// returns the address unchanged.
fn mask_ipv6(addr: &Ipv6Addr, subnet: u8) -> Ipv6Addr {
    // Number of low-order (host) bits to clear; saturates at 0 for prefixes >= 128.
    let host_bits = 128u32.saturating_sub(u32::from(subnet));
    // A shift by the full width (prefix 0) is not representable, so `checked_shl`
    // returns `None` and the mask becomes all zeros.
    let netmask = u128::MAX.checked_shl(host_bits).unwrap_or(0);

    Ipv6Addr::from(u128::from(*addr) & netmask)
}
333
/// Picks a prefix length for an IPv6 address when the caller did not supply one.
///
/// Rules, applied to the address's canonical text form first:
/// - text starting with `2001:db8:` -> 32
/// - exactly `::1` -> 128
/// - text starting with `fe80::` -> 16
///
/// Otherwise the prefix ends at the last non-zero 16-bit segment: after its high byte
/// when the low byte is zero, after the whole segment when it is not. An all-zero
/// address gets 128, and a result below 16 is replaced by 64.
fn auto_detect_ipv6_subnet(addr: &Ipv6Addr) -> u8 {
    let text = addr.to_string();
    match text.as_str() {
        "::1" => return 128,
        t if t.starts_with("2001:db8:") => return 32,
        t if t.starts_with("fe80::") => return 16,
        _ => {}
    }

    let segments = addr.segments();
    let detected = match segments.iter().rposition(|&segment| segment != 0) {
        // No non-zero segment at all.
        None => 128,
        // The low byte of the last non-zero segment is empty: stop after its high byte.
        Some(last) if segments[last] & 0xFF == 0 => last * 16 + 8,
        Some(last) => (last + 1) * 16,
    };

    // Only the "/8" outcome is below 16; it is widened to 64.
    let detected = if detected < 16 { 64 } else { detected };

    // `detected` is at most 128, so it fits in a u8.
    detected as u8
}
375
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{StringVector, UInt8Vector};

    use super::*;

    /// Prefix length is inferred from the shape of the input when no mask is given.
    #[test]
    fn test_ipv4_to_cidr_auto() {
        let func = Ipv4ToCidr;
        let ctx = FunctionContext::default();

        let values = vec!["192.168.1.0", "10.0.0.0", "172.16", "192"];
        let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;

        let result = func.eval(&ctx, &[input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24");
        assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/8");
        assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/16");
        assert_eq!(result.get_data(3).unwrap(), "192.0.0.0/8");
    }

    /// An explicit mask zeroes the host bits and is echoed as the prefix length.
    #[test]
    fn test_ipv4_to_cidr_with_subnet() {
        let func = Ipv4ToCidr;
        let ctx = FunctionContext::default();

        let ip_values = vec!["192.168.1.1", "10.0.0.1", "172.16.5.5"];
        let subnet_values = vec![24u8, 16u8, 12u8];
        let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
        let subnet_input = Arc::new(UInt8Vector::from_vec(subnet_values)) as VectorRef;

        let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24");
        assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/16");
        assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/12");
    }

    /// Well-known IPv6 prefixes get their dedicated prefix lengths.
    #[test]
    fn test_ipv6_to_cidr_auto() {
        let func = Ipv6ToCidr;
        let ctx = FunctionContext::default();

        let values = vec!["2001:db8::", "2001:db8", "fe80::1", "::1"];
        let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;

        let result = func.eval(&ctx, &[input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "2001:db8::/32");
        assert_eq!(result.get_data(1).unwrap(), "2001:db8::/32");
        assert_eq!(result.get_data(2).unwrap(), "fe80::/16");
        assert_eq!(result.get_data(3).unwrap(), "::1/128");
    }

    /// An explicit mask zeroes the host bits and is echoed as the prefix length.
    #[test]
    fn test_ipv6_to_cidr_with_subnet() {
        let func = Ipv6ToCidr;
        let ctx = FunctionContext::default();

        let ip_values = vec!["2001:db8::", "fe80::1", "2001:db8:1234::"];
        let subnet_values = vec![48u8, 10u8, 56u8];
        let ip_input = Arc::new(StringVector::from_slice(&ip_values)) as VectorRef;
        let subnet_input = Arc::new(UInt8Vector::from_vec(subnet_values)) as VectorRef;

        let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "2001:db8::/48");
        assert_eq!(result.get_data(1).unwrap(), "fe80::/10");
        assert_eq!(result.get_data(2).unwrap(), "2001:db8:1234::/56");
    }

    /// Empty and malformed addresses are rejected by both functions.
    #[test]
    fn test_invalid_inputs() {
        let ipv4_func = Ipv4ToCidr;
        let ipv6_func = Ipv6ToCidr;
        let ctx = FunctionContext::default();

        let empty_values = vec![""];
        let empty_input = Arc::new(StringVector::from_slice(&empty_values)) as VectorRef;

        let ipv4_result = ipv4_func.eval(&ctx, std::slice::from_ref(&empty_input));
        let ipv6_result = ipv6_func.eval(&ctx, std::slice::from_ref(&empty_input));

        assert!(ipv4_result.is_err());
        assert!(ipv6_result.is_err());

        // `eval` stops at the first failing row, so each invalid value is checked in
        // its own single-row vector; otherwise only the first one is ever exercised.
        let invalid_values = vec!["not an ip", "192.168.1.256", "zzzz::ffff"];
        for value in &invalid_values {
            let single = vec![*value];
            let invalid_input = Arc::new(StringVector::from_slice(&single)) as VectorRef;

            let ipv4_result = ipv4_func.eval(&ctx, std::slice::from_ref(&invalid_input));
            assert!(ipv4_result.is_err(), "ipv4_to_cidr accepted {:?}", value);

            let ipv6_result = ipv6_func.eval(&ctx, std::slice::from_ref(&invalid_input));
            assert!(ipv6_result.is_err(), "ipv6_to_cidr accepted {:?}", value);
        }
    }
}