common_function/scalars/ip/
range.rs1use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion_expr::{Signature, Volatility};
21use datatypes::prelude::Value;
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef};
24use derive_more::Display;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28
29#[derive(Clone, Debug, Default, Display)]
39#[display("{}", self.name())]
40pub struct Ipv4InRange;
41
42impl Function for Ipv4InRange {
43 fn name(&self) -> &str {
44 "ipv4_in_range"
45 }
46
47 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
48 Ok(DataType::Boolean)
49 }
50
51 fn signature(&self) -> Signature {
52 Signature::string(2, Volatility::Immutable)
53 }
54
55 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
56 ensure!(
57 columns.len() == 2,
58 InvalidFuncArgsSnafu {
59 err_msg: format!("Expected 2 arguments, got {}", columns.len())
60 }
61 );
62
63 let ip_vec = &columns[0];
64 let range_vec = &columns[1];
65 let size = ip_vec.len();
66
67 ensure!(
68 range_vec.len() == size,
69 InvalidFuncArgsSnafu {
70 err_msg: "IP addresses and CIDR ranges must have the same number of rows"
71 .to_string()
72 }
73 );
74
75 let mut results = BooleanVectorBuilder::with_capacity(size);
76
77 for i in 0..size {
78 let ip = ip_vec.get(i);
79 let range = range_vec.get(i);
80
81 let in_range = match (ip, range) {
82 (Value::String(ip_str), Value::String(range_str)) => {
83 let ip_str = ip_str.as_utf8().trim();
84 let range_str = range_str.as_utf8().trim();
85
86 if ip_str.is_empty() || range_str.is_empty() {
87 return InvalidFuncArgsSnafu {
88 err_msg: "IP address and CIDR range cannot be empty".to_string(),
89 }
90 .fail();
91 }
92
93 let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
95 InvalidFuncArgsSnafu {
96 err_msg: format!("Invalid IPv4 address: {}", ip_str),
97 }
98 .build()
99 })?;
100
101 let (cidr_ip, cidr_prefix) = parse_ipv4_cidr(range_str)?;
103
104 is_ipv4_in_range(&ip_addr, &cidr_ip, cidr_prefix)
106 }
107 _ => None,
108 };
109
110 results.push(in_range);
111 }
112
113 Ok(results.to_vector())
114 }
115}
116
117#[derive(Clone, Debug, Default, Display)]
128#[display("{}", self.name())]
129pub struct Ipv6InRange;
130
131impl Function for Ipv6InRange {
132 fn name(&self) -> &str {
133 "ipv6_in_range"
134 }
135
136 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
137 Ok(DataType::Boolean)
138 }
139
140 fn signature(&self) -> Signature {
141 Signature::string(2, Volatility::Immutable)
142 }
143
144 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
145 ensure!(
146 columns.len() == 2,
147 InvalidFuncArgsSnafu {
148 err_msg: format!("Expected 2 arguments, got {}", columns.len())
149 }
150 );
151
152 let ip_vec = &columns[0];
153 let range_vec = &columns[1];
154 let size = ip_vec.len();
155
156 ensure!(
157 range_vec.len() == size,
158 InvalidFuncArgsSnafu {
159 err_msg: "IP addresses and CIDR ranges must have the same number of rows"
160 .to_string()
161 }
162 );
163
164 let mut results = BooleanVectorBuilder::with_capacity(size);
165
166 for i in 0..size {
167 let ip = ip_vec.get(i);
168 let range = range_vec.get(i);
169
170 let in_range = match (ip, range) {
171 (Value::String(ip_str), Value::String(range_str)) => {
172 let ip_str = ip_str.as_utf8().trim();
173 let range_str = range_str.as_utf8().trim();
174
175 if ip_str.is_empty() || range_str.is_empty() {
176 return InvalidFuncArgsSnafu {
177 err_msg: "IP address and CIDR range cannot be empty".to_string(),
178 }
179 .fail();
180 }
181
182 let ip_addr = Ipv6Addr::from_str(ip_str).map_err(|_| {
184 InvalidFuncArgsSnafu {
185 err_msg: format!("Invalid IPv6 address: {}", ip_str),
186 }
187 .build()
188 })?;
189
190 let (cidr_ip, cidr_prefix) = parse_ipv6_cidr(range_str)?;
192
193 is_ipv6_in_range(&ip_addr, &cidr_ip, cidr_prefix)
195 }
196 _ => None,
197 };
198
199 results.push(in_range);
200 }
201
202 Ok(results.to_vector())
203 }
204}
205
206fn parse_ipv4_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
209 let parts: Vec<&str> = cidr.split('/').collect();
211 ensure!(
212 parts.len() == 2,
213 InvalidFuncArgsSnafu {
214 err_msg: format!("Invalid CIDR notation: {}", cidr),
215 }
216 );
217
218 let ip = Ipv4Addr::from_str(parts[0]).map_err(|_| {
220 InvalidFuncArgsSnafu {
221 err_msg: format!("Invalid IPv4 address in CIDR: {}", parts[0]),
222 }
223 .build()
224 })?;
225
226 let prefix = parts[1].parse::<u8>().map_err(|_| {
228 InvalidFuncArgsSnafu {
229 err_msg: format!("Invalid prefix length: {}", parts[1]),
230 }
231 .build()
232 })?;
233
234 ensure!(
235 prefix <= 32,
236 InvalidFuncArgsSnafu {
237 err_msg: format!("IPv4 prefix length must be <= 32, got {}", prefix),
238 }
239 );
240
241 Ok((ip, prefix))
242}
243
244fn parse_ipv6_cidr(cidr: &str) -> Result<(Ipv6Addr, u8)> {
245 let parts: Vec<&str> = cidr.split('/').collect();
247 ensure!(
248 parts.len() == 2,
249 InvalidFuncArgsSnafu {
250 err_msg: format!("Invalid CIDR notation: {}", cidr),
251 }
252 );
253
254 let ip = Ipv6Addr::from_str(parts[0]).map_err(|_| {
256 InvalidFuncArgsSnafu {
257 err_msg: format!("Invalid IPv6 address in CIDR: {}", parts[0]),
258 }
259 .build()
260 })?;
261
262 let prefix = parts[1].parse::<u8>().map_err(|_| {
264 InvalidFuncArgsSnafu {
265 err_msg: format!("Invalid prefix length: {}", parts[1]),
266 }
267 .build()
268 })?;
269
270 ensure!(
271 prefix <= 128,
272 InvalidFuncArgsSnafu {
273 err_msg: format!("IPv6 prefix length must be <= 128, got {}", prefix),
274 }
275 );
276
277 Ok((ip, prefix))
278}
279
/// Tests whether `ip` lies inside the network `cidr_base/prefix_len`.
///
/// Only the leading `prefix_len` bits of the two addresses are compared. Host bits set in
/// `cidr_base` are therefore ignored, so `192.168.1.7/24` behaves like `192.168.1.0/24`.
///
/// Returns `None` when `prefix_len` exceeds 32, which is not a valid IPv4 prefix. Otherwise it
/// returns `Some(true)` or `Some(false)`.
fn is_ipv4_in_range(ip: &Ipv4Addr, cidr_base: &Ipv4Addr, prefix_len: u8) -> Option<bool> {
    // Guard before the subtraction below: `32 - prefix_len` would underflow for prefixes > 32.
    if prefix_len > 32 {
        return None;
    }

    // A shift by 32 (prefix 0) is out of range for u32; `checked_shl` reports that as `None`,
    // which maps to the all-zero mask that matches every address.
    let mask = u32::MAX
        .checked_shl(32 - u32::from(prefix_len))
        .unwrap_or(0);

    Some((u32::from(*ip) & mask) == (u32::from(*cidr_base) & mask))
}
298
/// Tests whether `ip` lies inside the network `cidr_base/prefix_len`.
///
/// Both addresses are viewed as 128-bit big-endian integers, and only their leading
/// `prefix_len` bits are compared. Host bits set in `cidr_base` are therefore ignored.
///
/// Returns `None` when `prefix_len` exceeds 128, which is not a valid IPv6 prefix. Otherwise it
/// returns `Some(true)` or `Some(false)`. Previously a prefix > 128 either indexed past the
/// 16 address octets (panic) or was silently treated as /128.
fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Option<bool> {
    if prefix_len > 128 {
        return None;
    }

    // A shift by 128 (prefix 0) is out of range for u128; `checked_shl` reports that as
    // `None`, which maps to the all-zero mask that matches every address.
    let mask = u128::MAX
        .checked_shl(128 - u32::from(prefix_len))
        .unwrap_or(0);

    Some((u128::from(*ip) & mask) == (u128::from(*cidr_base) & mask))
}
327
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{BooleanVector, StringVector};

    use super::*;

    /// Evaluates `func` over a column of IP strings paired with a column of CIDR strings.
    fn eval_pairs<F: Function>(func: &F, ips: &[&str], cidrs: &[&str]) -> Result<VectorRef> {
        let ip_input = Arc::new(StringVector::from_slice(ips)) as VectorRef;
        let cidr_input = Arc::new(StringVector::from_slice(cidrs)) as VectorRef;
        func.eval(&FunctionContext::default(), &[ip_input, cidr_input])
    }

    /// Reads a boolean result column back as `Option<bool>` per row (`None` = null).
    fn as_bools(vector: &VectorRef) -> Vec<Option<bool>> {
        let booleans = vector
            .as_any()
            .downcast_ref::<BooleanVector>()
            .unwrap();
        (0..vector.len()).map(|i| booleans.get_data(i)).collect()
    }

    #[test]
    fn test_ipv4_in_range() {
        let result = eval_pairs(
            &Ipv4InRange,
            &["192.168.1.5", "192.168.2.1", "10.0.0.1", "10.1.0.1", "172.16.0.1"],
            &[
                "192.168.1.0/24",
                "192.168.1.0/24",
                "10.0.0.0/8",
                "10.0.0.0/8",
                "172.16.0.0/16",
            ],
        )
        .unwrap();

        assert_eq!(
            as_bools(&result),
            vec![Some(true), Some(false), Some(true), Some(true), Some(true)]
        );
    }

    #[test]
    fn test_ipv6_in_range() {
        let result = eval_pairs(
            &Ipv6InRange,
            &["2001:db8::1", "2001:db8:1::", "2001:db9::1", "::1", "fe80::1"],
            &[
                "2001:db8::/32",
                "2001:db8::/32",
                "2001:db8::/32",
                "::1/128",
                "fe80::/16",
            ],
        )
        .unwrap();

        assert_eq!(
            as_bools(&result),
            vec![Some(true), Some(true), Some(false), Some(true), Some(true)]
        );
    }

    #[test]
    fn test_invalid_inputs() {
        // Each case runs as its own single-row batch. In a multi-row batch only the first
        // failing row is ever reached, so later malformed values would go unchecked.
        let ipv4_cases = [
            ("not-an-ip", "192.168.1.0/24"),
            ("192.168.1.300", "192.168.1.0/24"),
            ("192.168.1.1", "192.168.1.0"),
            ("192.168.1.1", "192.168.1.0/33"),
            ("192.168.1.1", "192.168.1.0/abc"),
            ("192.168.1.1", "192.168.1.0/24/8"),
            ("192.168.1.1", "2001:db8::/32"),
            ("", "192.168.1.0/24"),
            ("192.168.1.1", "   "),
        ];
        for (ip, cidr) in ipv4_cases {
            assert!(
                eval_pairs(&Ipv4InRange, &[ip], &[cidr]).is_err(),
                "expected ipv4_in_range({ip:?}, {cidr:?}) to fail"
            );
        }

        let ipv6_cases = [
            ("not-an-ip", "2001:db8::/32"),
            ("192.168.1.1", "2001:db8::/32"),
            ("2001:db8::1", "2001:db8::"),
            ("2001:db8::1", "2001:db8::/129"),
            ("2001:db8::1", "192.168.1.0/24"),
            ("", "2001:db8::/32"),
        ];
        for (ip, cidr) in ipv6_cases {
            assert!(
                eval_pairs(&Ipv6InRange, &[ip], &[cidr]).is_err(),
                "expected ipv6_in_range({ip:?}, {cidr:?}) to fail"
            );
        }
    }

    #[test]
    fn test_shape_errors() {
        // Mismatched row counts between the two columns.
        assert!(eval_pairs(&Ipv4InRange, &["10.0.0.1", "10.0.0.2"], &["10.0.0.0/8"]).is_err());
        assert!(eval_pairs(&Ipv6InRange, &["::1", "::2"], &["::/0"]).is_err());

        // Wrong number of arguments.
        let only_ips = Arc::new(StringVector::from_slice(&["10.0.0.1"])) as VectorRef;
        assert!(Ipv4InRange
            .eval(&FunctionContext::default(), &[only_ips])
            .is_err());
    }

    #[test]
    fn test_null_inputs_yield_null() {
        let ips = Arc::new(StringVector::from(vec![
            Some("10.0.0.1"),
            None,
            Some("10.0.0.1"),
        ])) as VectorRef;
        let cidrs = Arc::new(StringVector::from(vec![
            Some("10.0.0.0/8"),
            Some("10.0.0.0/8"),
            None,
        ])) as VectorRef;

        let result = Ipv4InRange
            .eval(&FunctionContext::default(), &[ips, cidrs])
            .unwrap();

        assert_eq!(as_bools(&result), vec![Some(true), None, None]);
    }

    #[test]
    fn test_edge_cases() {
        let ipv4 = eval_pairs(
            &Ipv4InRange,
            &["8.8.8.8", "192.168.1.1", "192.168.1.1", " 192.168.1.9 "],
            &["0.0.0.0/0", "192.168.1.1/32", "192.168.1.0/32", "192.168.1.7/24"],
        )
        .unwrap();
        assert_eq!(
            as_bools(&ipv4),
            vec![Some(true), Some(true), Some(false), Some(true)]
        );

        let ipv6 = eval_pairs(
            &Ipv6InRange,
            &["abcd::1", "2001:db8:7fff::", "2001:db8:8000::"],
            &["::/0", "2001:db8::/33", "2001:db8::/33"],
        )
        .unwrap();
        assert_eq!(as_bools(&ipv6), vec![Some(true), Some(true), Some(false)]);
    }
}