common_function/scalars/ip/
cidr.rs1use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::{Signature, TypeSignature};
20use datafusion::logical_expr::Volatility;
21use datatypes::prelude::{ConcreteDataType, Value};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{MutableVector, StringVectorBuilder, VectorRef};
24use derive_more::Display;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28
29#[derive(Clone, Debug, Default, Display)]
39#[display("{}", self.name())]
40pub struct Ipv4ToCidr;
41
42impl Function for Ipv4ToCidr {
43 fn name(&self) -> &str {
44 "ipv4_to_cidr"
45 }
46
47 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
48 Ok(ConcreteDataType::string_datatype())
49 }
50
51 fn signature(&self) -> Signature {
52 Signature::one_of(
53 vec![
54 TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
55 TypeSignature::Exact(vec![
56 ConcreteDataType::string_datatype(),
57 ConcreteDataType::uint8_datatype(),
58 ]),
59 ],
60 Volatility::Immutable,
61 )
62 }
63
64 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
65 ensure!(
66 columns.len() == 1 || columns.len() == 2,
67 InvalidFuncArgsSnafu {
68 err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len())
69 }
70 );
71
72 let ip_vec = &columns[0];
73 let mut results = StringVectorBuilder::with_capacity(ip_vec.len());
74
75 let has_subnet_arg = columns.len() == 2;
76 let subnet_vec = if has_subnet_arg {
77 ensure!(
78 columns[1].len() == ip_vec.len(),
79 InvalidFuncArgsSnafu {
80 err_msg:
81 "Subnet mask must have the same number of elements as the IP addresses"
82 .to_string()
83 }
84 );
85 Some(&columns[1])
86 } else {
87 None
88 };
89
90 for i in 0..ip_vec.len() {
91 let ip_str = ip_vec.get(i);
92 let subnet = subnet_vec.map(|v| v.get(i));
93
94 let cidr = match (ip_str, subnet) {
95 (Value::String(s), Some(Value::UInt8(mask))) => {
96 let ip_str = s.as_utf8().trim();
97 if ip_str.is_empty() {
98 return InvalidFuncArgsSnafu {
99 err_msg: "Empty IPv4 address".to_string(),
100 }
101 .fail();
102 }
103
104 let ip_addr = complete_and_parse_ipv4(ip_str)?;
105 let mask_bits = u32::MAX.wrapping_shl(32 - mask as u32);
107 let masked_ip = Ipv4Addr::from(u32::from(ip_addr) & mask_bits);
108
109 Some(format!("{}/{}", masked_ip, mask))
110 }
111 (Value::String(s), None) => {
112 let ip_str = s.as_utf8().trim();
113 if ip_str.is_empty() {
114 return InvalidFuncArgsSnafu {
115 err_msg: "Empty IPv4 address".to_string(),
116 }
117 .fail();
118 }
119
120 let ip_addr = complete_and_parse_ipv4(ip_str)?;
121
122 let ip_bits = u32::from(ip_addr);
124 let dots = ip_str.chars().filter(|&c| c == '.').count();
125
126 let subnet_mask = match dots {
127 0 => 8, 1 => 16, 2 => 24, _ => {
131 let trailing_zeros = ip_bits.trailing_zeros();
133 if trailing_zeros % 8 == 0 {
135 32 - trailing_zeros.min(32) as u8
136 } else {
137 32 - (trailing_zeros as u8 / 8) * 8
138 }
139 }
140 };
141
142 let mask_bits = u32::MAX.wrapping_shl(32 - subnet_mask as u32);
144 let masked_ip = Ipv4Addr::from(ip_bits & mask_bits);
145
146 Some(format!("{}/{}", masked_ip, subnet_mask))
147 }
148 _ => None,
149 };
150
151 results.push(cidr.as_deref());
152 }
153
154 Ok(results.to_vector())
155 }
156}
157
158#[derive(Clone, Debug, Default, Display)]
168#[display("{}", self.name())]
169pub struct Ipv6ToCidr;
170
171impl Function for Ipv6ToCidr {
172 fn name(&self) -> &str {
173 "ipv6_to_cidr"
174 }
175
176 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
177 Ok(ConcreteDataType::string_datatype())
178 }
179
180 fn signature(&self) -> Signature {
181 Signature::one_of(
182 vec![
183 TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
184 TypeSignature::Exact(vec![
185 ConcreteDataType::string_datatype(),
186 ConcreteDataType::uint8_datatype(),
187 ]),
188 ],
189 Volatility::Immutable,
190 )
191 }
192
193 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
194 ensure!(
195 columns.len() == 1 || columns.len() == 2,
196 InvalidFuncArgsSnafu {
197 err_msg: format!("Expected 1 or 2 arguments, got {}", columns.len())
198 }
199 );
200
201 let ip_vec = &columns[0];
202 let size = ip_vec.len();
203 let mut results = StringVectorBuilder::with_capacity(size);
204
205 let has_subnet_arg = columns.len() == 2;
206 let subnet_vec = if has_subnet_arg {
207 Some(&columns[1])
208 } else {
209 None
210 };
211
212 for i in 0..size {
213 let ip_str = ip_vec.get(i);
214 let subnet = subnet_vec.map(|v| v.get(i));
215
216 let cidr = match (ip_str, subnet) {
217 (Value::String(s), Some(Value::UInt8(mask))) => {
218 let ip_str = s.as_utf8().trim();
219 if ip_str.is_empty() {
220 return InvalidFuncArgsSnafu {
221 err_msg: "Empty IPv6 address".to_string(),
222 }
223 .fail();
224 }
225
226 let ip_addr = complete_and_parse_ipv6(ip_str)?;
227
228 let masked_ip = mask_ipv6(&ip_addr, mask);
230
231 Some(format!("{}/{}", masked_ip, mask))
232 }
233 (Value::String(s), None) => {
234 let ip_str = s.as_utf8().trim();
235 if ip_str.is_empty() {
236 return InvalidFuncArgsSnafu {
237 err_msg: "Empty IPv6 address".to_string(),
238 }
239 .fail();
240 }
241
242 let ip_addr = complete_and_parse_ipv6(ip_str)?;
243
244 let subnet_mask = auto_detect_ipv6_subnet(&ip_addr);
246
247 let masked_ip = mask_ipv6(&ip_addr, subnet_mask);
249
250 Some(format!("{}/{}", masked_ip, subnet_mask))
251 }
252 _ => None,
253 };
254
255 results.push(cidr.as_deref());
256 }
257
258 Ok(results.to_vector())
259 }
260}
261
262fn complete_and_parse_ipv4(ip_str: &str) -> Result<Ipv4Addr> {
265 if let Ok(addr) = Ipv4Addr::from_str(ip_str) {
267 return Ok(addr);
268 }
269
270 let dots = ip_str.chars().filter(|&c| c == '.').count();
272
273 let completed = match dots {
275 0 => format!("{}.0.0.0", ip_str),
276 1 => format!("{}.0.0", ip_str),
277 2 => format!("{}.0", ip_str),
278 _ => ip_str.to_string(),
279 };
280
281 Ipv4Addr::from_str(&completed).map_err(|_| {
282 InvalidFuncArgsSnafu {
283 err_msg: format!("Invalid IPv4 address: {}", ip_str),
284 }
285 .build()
286 })
287}
288
289fn complete_and_parse_ipv6(ip_str: &str) -> Result<Ipv6Addr> {
290 if let Ok(addr) = Ipv6Addr::from_str(ip_str) {
292 return Ok(addr);
293 }
294
295 let completed = if ip_str.ends_with(':') {
298 format!("{}:", ip_str)
299 } else if !ip_str.contains("::") {
300 format!("{}::", ip_str)
301 } else {
302 ip_str.to_string()
303 };
304
305 Ipv6Addr::from_str(&completed).map_err(|_| {
306 InvalidFuncArgsSnafu {
307 err_msg: format!("Invalid IPv6 address: {}", ip_str),
308 }
309 .build()
310 })
311}
312
/// Keeps the first `subnet` bits of `addr` and clears all remaining bits.
///
/// A `subnet` of 0 yields `::`; values of 128 or more return the address
/// unchanged.
fn mask_ipv6(addr: &Ipv6Addr, subnet: u8) -> Ipv6Addr {
    // Treat the address as one big-endian 128-bit integer so the network
    // bits are the most significant ones.
    let prefix_len = u32::from(subnet.min(128));
    // `checked_shl(128)` is `None`, which gives the all-zero mask for /0.
    let netmask = u128::MAX.checked_shl(128 - prefix_len).unwrap_or(0);
    Ipv6Addr::from(u128::from(*addr) & netmask)
}
335
/// Heuristically picks a prefix length for an IPv6 address given without one.
///
/// Special cases, matched on the address's canonical text form:
/// * text starting with `2001:db8:` → /32
/// * exactly `::1` → /128
/// * text starting with `fe80::` → /16
///
/// Otherwise the prefix ends right after the last non-zero 16-bit segment, or
/// halfway through it when that segment's low byte is zero. An all-zero
/// address yields /128, and any result below /16 is replaced with /64.
fn auto_detect_ipv6_subnet(addr: &Ipv6Addr) -> u8 {
    let text = addr.to_string();

    // "2001:db8:" also covers "2001:db8::".
    if text.starts_with("2001:db8:") {
        return 32;
    }
    if text == "::1" {
        return 128;
    }
    if text.starts_with("fe80::") {
        return 16;
    }

    let segments = addr.segments();
    let prefix = match segments.iter().rposition(|&seg| seg != 0) {
        None => 128,
        Some(idx) if segments[idx] & 0xFF == 0 => idx * 16 + 8,
        Some(idx) => (idx + 1) * 16,
    };

    if prefix < 16 {
        64
    } else {
        prefix as u8
    }
}
377
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{StringVector, UInt8Vector};

    use super::*;

    /// Builds a string column from `values`.
    fn string_input(values: &[&str]) -> VectorRef {
        Arc::new(StringVector::from_slice(values)) as VectorRef
    }

    #[test]
    fn test_ipv4_to_cidr_auto() {
        let func = Ipv4ToCidr;
        let ctx = FunctionContext::default();

        let input = string_input(&["192.168.1.0", "10.0.0.0", "172.16", "192"]);

        let result = func.eval(&ctx, &[input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24");
        assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/8");
        assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/16");
        assert_eq!(result.get_data(3).unwrap(), "192.0.0.0/8");
    }

    #[test]
    fn test_ipv4_to_cidr_with_subnet() {
        let func = Ipv4ToCidr;
        let ctx = FunctionContext::default();

        let ip_input = string_input(&["192.168.1.1", "10.0.0.1", "172.16.5.5"]);
        let subnet_input = Arc::new(UInt8Vector::from_vec(vec![24u8, 16u8, 12u8])) as VectorRef;

        let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "192.168.1.0/24");
        assert_eq!(result.get_data(1).unwrap(), "10.0.0.0/16");
        assert_eq!(result.get_data(2).unwrap(), "172.16.0.0/12");
    }

    #[test]
    fn test_ipv6_to_cidr_auto() {
        let func = Ipv6ToCidr;
        let ctx = FunctionContext::default();

        let input = string_input(&["2001:db8::", "2001:db8", "fe80::1", "::1"]);

        let result = func.eval(&ctx, &[input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "2001:db8::/32");
        assert_eq!(result.get_data(1).unwrap(), "2001:db8::/32");
        assert_eq!(result.get_data(2).unwrap(), "fe80::/16");
        assert_eq!(result.get_data(3).unwrap(), "::1/128");
    }

    #[test]
    fn test_ipv6_to_cidr_with_subnet() {
        let func = Ipv6ToCidr;
        let ctx = FunctionContext::default();

        let ip_input = string_input(&["2001:db8::", "fe80::1", "2001:db8:1234::"]);
        let subnet_input = Arc::new(UInt8Vector::from_vec(vec![48u8, 10u8, 56u8])) as VectorRef;

        let result = func.eval(&ctx, &[ip_input, subnet_input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "2001:db8::/48");
        assert_eq!(result.get_data(1).unwrap(), "fe80::/10");
        assert_eq!(result.get_data(2).unwrap(), "2001:db8:1234::/56");
    }

    #[test]
    fn test_invalid_inputs() {
        let ipv4_func = Ipv4ToCidr;
        let ipv6_func = Ipv6ToCidr;
        let ctx = FunctionContext::default();

        let empty_input = string_input(&[""]);
        assert!(ipv4_func.eval(&ctx, &[empty_input.clone()]).is_err());
        assert!(ipv6_func.eval(&ctx, &[empty_input]).is_err());

        // Evaluate each invalid value on its own: `eval` stops at the first
        // failing row, so a multi-row input would only ever exercise row 0.
        for value in ["not an ip", "192.168.1.256", "zzzz::ffff"] {
            let input = string_input(&[value]);
            assert!(
                ipv4_func.eval(&ctx, &[input.clone()]).is_err(),
                "ipv4_to_cidr accepted {value:?}"
            );
            assert!(
                ipv6_func.eval(&ctx, &[input]).is_err(),
                "ipv6_to_cidr accepted {value:?}"
            );
        }
    }

    #[test]
    fn test_ipv4_to_cidr_mismatched_lengths() {
        let func = Ipv4ToCidr;
        let ctx = FunctionContext::default();

        let ip_input = string_input(&["192.168.1.1", "10.0.0.1"]);
        let subnet_input = Arc::new(UInt8Vector::from_vec(vec![24u8])) as VectorRef;

        assert!(func.eval(&ctx, &[ip_input, subnet_input]).is_err());
    }
}