common_function/scalars/ip/
range.rs1use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::{Signature, TypeSignature};
20use datafusion::logical_expr::Volatility;
21use datatypes::prelude::{ConcreteDataType, Value};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef};
24use derive_more::Display;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28
29#[derive(Clone, Debug, Default, Display)]
39#[display("{}", self.name())]
40pub struct Ipv4InRange;
41
42impl Function for Ipv4InRange {
43 fn name(&self) -> &str {
44 "ipv4_in_range"
45 }
46
47 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
48 Ok(ConcreteDataType::boolean_datatype())
49 }
50
51 fn signature(&self) -> Signature {
52 Signature::new(
53 TypeSignature::Exact(vec![
54 ConcreteDataType::string_datatype(),
55 ConcreteDataType::string_datatype(),
56 ]),
57 Volatility::Immutable,
58 )
59 }
60
61 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
62 ensure!(
63 columns.len() == 2,
64 InvalidFuncArgsSnafu {
65 err_msg: format!("Expected 2 arguments, got {}", columns.len())
66 }
67 );
68
69 let ip_vec = &columns[0];
70 let range_vec = &columns[1];
71 let size = ip_vec.len();
72
73 ensure!(
74 range_vec.len() == size,
75 InvalidFuncArgsSnafu {
76 err_msg: "IP addresses and CIDR ranges must have the same number of rows"
77 .to_string()
78 }
79 );
80
81 let mut results = BooleanVectorBuilder::with_capacity(size);
82
83 for i in 0..size {
84 let ip = ip_vec.get(i);
85 let range = range_vec.get(i);
86
87 let in_range = match (ip, range) {
88 (Value::String(ip_str), Value::String(range_str)) => {
89 let ip_str = ip_str.as_utf8().trim();
90 let range_str = range_str.as_utf8().trim();
91
92 if ip_str.is_empty() || range_str.is_empty() {
93 return InvalidFuncArgsSnafu {
94 err_msg: "IP address and CIDR range cannot be empty".to_string(),
95 }
96 .fail();
97 }
98
99 let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
101 InvalidFuncArgsSnafu {
102 err_msg: format!("Invalid IPv4 address: {}", ip_str),
103 }
104 .build()
105 })?;
106
107 let (cidr_ip, cidr_prefix) = parse_ipv4_cidr(range_str)?;
109
110 is_ipv4_in_range(&ip_addr, &cidr_ip, cidr_prefix)
112 }
113 _ => None,
114 };
115
116 results.push(in_range);
117 }
118
119 Ok(results.to_vector())
120 }
121}
122
123#[derive(Clone, Debug, Default, Display)]
134#[display("{}", self.name())]
135pub struct Ipv6InRange;
136
137impl Function for Ipv6InRange {
138 fn name(&self) -> &str {
139 "ipv6_in_range"
140 }
141
142 fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
143 Ok(ConcreteDataType::boolean_datatype())
144 }
145
146 fn signature(&self) -> Signature {
147 Signature::new(
148 TypeSignature::Exact(vec![
149 ConcreteDataType::string_datatype(),
150 ConcreteDataType::string_datatype(),
151 ]),
152 Volatility::Immutable,
153 )
154 }
155
156 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
157 ensure!(
158 columns.len() == 2,
159 InvalidFuncArgsSnafu {
160 err_msg: format!("Expected 2 arguments, got {}", columns.len())
161 }
162 );
163
164 let ip_vec = &columns[0];
165 let range_vec = &columns[1];
166 let size = ip_vec.len();
167
168 ensure!(
169 range_vec.len() == size,
170 InvalidFuncArgsSnafu {
171 err_msg: "IP addresses and CIDR ranges must have the same number of rows"
172 .to_string()
173 }
174 );
175
176 let mut results = BooleanVectorBuilder::with_capacity(size);
177
178 for i in 0..size {
179 let ip = ip_vec.get(i);
180 let range = range_vec.get(i);
181
182 let in_range = match (ip, range) {
183 (Value::String(ip_str), Value::String(range_str)) => {
184 let ip_str = ip_str.as_utf8().trim();
185 let range_str = range_str.as_utf8().trim();
186
187 if ip_str.is_empty() || range_str.is_empty() {
188 return InvalidFuncArgsSnafu {
189 err_msg: "IP address and CIDR range cannot be empty".to_string(),
190 }
191 .fail();
192 }
193
194 let ip_addr = Ipv6Addr::from_str(ip_str).map_err(|_| {
196 InvalidFuncArgsSnafu {
197 err_msg: format!("Invalid IPv6 address: {}", ip_str),
198 }
199 .build()
200 })?;
201
202 let (cidr_ip, cidr_prefix) = parse_ipv6_cidr(range_str)?;
204
205 is_ipv6_in_range(&ip_addr, &cidr_ip, cidr_prefix)
207 }
208 _ => None,
209 };
210
211 results.push(in_range);
212 }
213
214 Ok(results.to_vector())
215 }
216}
217
218fn parse_ipv4_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
221 let parts: Vec<&str> = cidr.split('/').collect();
223 ensure!(
224 parts.len() == 2,
225 InvalidFuncArgsSnafu {
226 err_msg: format!("Invalid CIDR notation: {}", cidr),
227 }
228 );
229
230 let ip = Ipv4Addr::from_str(parts[0]).map_err(|_| {
232 InvalidFuncArgsSnafu {
233 err_msg: format!("Invalid IPv4 address in CIDR: {}", parts[0]),
234 }
235 .build()
236 })?;
237
238 let prefix = parts[1].parse::<u8>().map_err(|_| {
240 InvalidFuncArgsSnafu {
241 err_msg: format!("Invalid prefix length: {}", parts[1]),
242 }
243 .build()
244 })?;
245
246 ensure!(
247 prefix <= 32,
248 InvalidFuncArgsSnafu {
249 err_msg: format!("IPv4 prefix length must be <= 32, got {}", prefix),
250 }
251 );
252
253 Ok((ip, prefix))
254}
255
256fn parse_ipv6_cidr(cidr: &str) -> Result<(Ipv6Addr, u8)> {
257 let parts: Vec<&str> = cidr.split('/').collect();
259 ensure!(
260 parts.len() == 2,
261 InvalidFuncArgsSnafu {
262 err_msg: format!("Invalid CIDR notation: {}", cidr),
263 }
264 );
265
266 let ip = Ipv6Addr::from_str(parts[0]).map_err(|_| {
268 InvalidFuncArgsSnafu {
269 err_msg: format!("Invalid IPv6 address in CIDR: {}", parts[0]),
270 }
271 .build()
272 })?;
273
274 let prefix = parts[1].parse::<u8>().map_err(|_| {
276 InvalidFuncArgsSnafu {
277 err_msg: format!("Invalid prefix length: {}", parts[1]),
278 }
279 .build()
280 })?;
281
282 ensure!(
283 prefix <= 128,
284 InvalidFuncArgsSnafu {
285 err_msg: format!("IPv6 prefix length must be <= 128, got {}", prefix),
286 }
287 );
288
289 Ok((ip, prefix))
290}
291
/// Reports whether `ip` lies inside the IPv4 network `cidr_base/prefix_len`.
///
/// Both addresses are masked down to their leading `prefix_len` bits before being
/// compared, so any host bits set in `cidr_base` are ignored.
///
/// Returns `None` when `prefix_len` is greater than 32, since no valid mask exists for
/// such a prefix; otherwise returns `Some(true)` or `Some(false)`.
fn is_ipv4_in_range(ip: &Ipv4Addr, cidr_base: &Ipv4Addr, prefix_len: u8) -> Option<bool> {
    // `checked_sub` rejects prefixes above 32 instead of underflowing.
    let host_bits = 32u32.checked_sub(u32::from(prefix_len))?;

    // A /0 prefix needs a shift by 32, which `checked_shl` reports as `None`: that is the
    // all-zero mask (every address matches).
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);

    let ip_network = u32::from(*ip) & mask;
    let cidr_network = u32::from(*cidr_base) & mask;

    Some(ip_network == cidr_network)
}
310
/// Reports whether `ip` lies inside the IPv6 network `cidr_base/prefix_len`.
///
/// Both addresses are converted to 128-bit integers and masked down to their leading
/// `prefix_len` bits before being compared, so any host bits set in `cidr_base` are
/// ignored.
///
/// Returns `None` when `prefix_len` is greater than 128, since no valid mask exists for
/// such a prefix; otherwise returns `Some(true)` or `Some(false)`.
fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Option<bool> {
    // `checked_sub` rejects prefixes above 128 instead of indexing past the address.
    let host_bits = 128u32.checked_sub(u32::from(prefix_len))?;

    // A /0 prefix needs a shift by 128, which `checked_shl` reports as `None`: that is the
    // all-zero mask (every address matches).
    let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);

    let ip_network = u128::from(*ip) & mask;
    let cidr_network = u128::from(*cidr_base) & mask;

    Some(ip_network == cidr_network)
}
339
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{BooleanVector, StringVector};

    use super::*;

    /// Wraps string literals in a `StringVector` column.
    fn string_column(values: &[&str]) -> VectorRef {
        Arc::new(StringVector::from_slice(values)) as VectorRef
    }

    /// Evaluates `func` over the given address/range columns and checks every row of the
    /// boolean result against `expected`.
    fn assert_rows(func: &impl Function, ips: &[&str], cidrs: &[&str], expected: &[bool]) {
        let ctx = FunctionContext::default();
        let output = func
            .eval(&ctx, &[string_column(ips), string_column(cidrs)])
            .unwrap();
        let output = output.as_any().downcast_ref::<BooleanVector>().unwrap();

        for (row, want) in expected.iter().enumerate() {
            assert_eq!(
                output.get_data(row),
                Some(*want),
                "row {}: {} in {}",
                row,
                ips[row],
                cidrs[row]
            );
        }
    }

    /// Evaluates `func` on a single address/range pair and checks that it is rejected.
    ///
    /// One pair per call: evaluation stops at the first failing row, so batching several
    /// invalid rows would only ever exercise the first one.
    fn assert_rejected(func: &impl Function, ip: &str, cidr: &str) {
        let ctx = FunctionContext::default();
        let result = func.eval(&ctx, &[string_column(&[ip]), string_column(&[cidr])]);
        assert!(result.is_err(), "expected an error for {} in {}", ip, cidr);
    }

    #[test]
    fn test_ipv4_in_range() {
        assert_rows(
            &Ipv4InRange,
            &[
                "192.168.1.5",
                "192.168.2.1",
                "10.0.0.1",
                "10.1.0.1",
                "172.16.0.1",
            ],
            &[
                "192.168.1.0/24",
                "192.168.1.0/24",
                "10.0.0.0/8",
                "10.0.0.0/8",
                "172.16.0.0/16",
            ],
            &[true, false, true, true, true],
        );
    }

    #[test]
    fn test_ipv6_in_range() {
        assert_rows(
            &Ipv6InRange,
            &[
                "2001:db8::1",
                "2001:db8:1::",
                "2001:db9::1",
                "::1",
                "fe80::1",
            ],
            &[
                "2001:db8::/32",
                "2001:db8::/32",
                "2001:db8::/32",
                "::1/128",
                "fe80::/16",
            ],
            &[true, true, false, true, true],
        );
    }

    #[test]
    fn test_invalid_inputs() {
        let ipv4_func = Ipv4InRange;
        let ipv6_func = Ipv6InRange;

        // Malformed IPv4 addresses.
        assert_rejected(&ipv4_func, "not-an-ip", "192.168.1.0/24");
        assert_rejected(&ipv4_func, "192.168.1.300", "192.168.1.0/24");

        // Malformed IPv4 ranges: missing prefix, prefix above 32.
        assert_rejected(&ipv4_func, "192.168.1.1", "192.168.1.0");
        assert_rejected(&ipv4_func, "192.168.1.1", "192.168.1.0/33");

        // An IPv4 address handed to the IPv6 function.
        assert_rejected(&ipv6_func, "192.168.1.1", "192.168.1.0");

        // Malformed IPv6 ranges, paired with a valid IPv6 address so that the range check
        // itself is what fails: missing prefix, prefix above 128.
        assert_rejected(&ipv6_func, "2001:db8::1", "2001:db8::");
        assert_rejected(&ipv6_func, "2001:db8::1", "2001:db8::/129");
    }

    #[test]
    fn test_edge_cases() {
        // /0 matches everything; /32 matches only the exact address.
        assert_rows(
            &Ipv4InRange,
            &["8.8.8.8", "192.168.1.1", "192.168.1.1"],
            &["0.0.0.0/0", "192.168.1.1/32", "192.168.1.0/32"],
            &[true, true, false],
        );
    }
}