common_function/scalars/ip/range.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion::arrow::datatypes::DataType;
20use datafusion_expr::{Signature, Volatility};
21use datatypes::prelude::Value;
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef};
24use derive_more::Display;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28
29/// Function that checks if an IPv4 address is within a specified CIDR range.
30///
31/// Both the IP address and the CIDR range are provided as strings.
32/// Returns boolean result indicating whether the IP is in the range.
33///
34/// Examples:
35/// - ipv4_in_range('192.168.1.5', '192.168.1.0/24') -> true
36/// - ipv4_in_range('192.168.2.1', '192.168.1.0/24') -> false
37/// - ipv4_in_range('10.0.0.1', '10.0.0.0/8') -> true
38#[derive(Clone, Debug, Default, Display)]
39#[display("{}", self.name())]
40pub struct Ipv4InRange;
41
42impl Function for Ipv4InRange {
43    fn name(&self) -> &str {
44        "ipv4_in_range"
45    }
46
47    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
48        Ok(DataType::Boolean)
49    }
50
51    fn signature(&self) -> Signature {
52        Signature::string(2, Volatility::Immutable)
53    }
54
55    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
56        ensure!(
57            columns.len() == 2,
58            InvalidFuncArgsSnafu {
59                err_msg: format!("Expected 2 arguments, got {}", columns.len())
60            }
61        );
62
63        let ip_vec = &columns[0];
64        let range_vec = &columns[1];
65        let size = ip_vec.len();
66
67        ensure!(
68            range_vec.len() == size,
69            InvalidFuncArgsSnafu {
70                err_msg: "IP addresses and CIDR ranges must have the same number of rows"
71                    .to_string()
72            }
73        );
74
75        let mut results = BooleanVectorBuilder::with_capacity(size);
76
77        for i in 0..size {
78            let ip = ip_vec.get(i);
79            let range = range_vec.get(i);
80
81            let in_range = match (ip, range) {
82                (Value::String(ip_str), Value::String(range_str)) => {
83                    let ip_str = ip_str.as_utf8().trim();
84                    let range_str = range_str.as_utf8().trim();
85
86                    if ip_str.is_empty() || range_str.is_empty() {
87                        return InvalidFuncArgsSnafu {
88                            err_msg: "IP address and CIDR range cannot be empty".to_string(),
89                        }
90                        .fail();
91                    }
92
93                    // Parse the IP address
94                    let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
95                        InvalidFuncArgsSnafu {
96                            err_msg: format!("Invalid IPv4 address: {}", ip_str),
97                        }
98                        .build()
99                    })?;
100
101                    // Parse the CIDR range
102                    let (cidr_ip, cidr_prefix) = parse_ipv4_cidr(range_str)?;
103
104                    // Check if the IP is in the CIDR range
105                    is_ipv4_in_range(&ip_addr, &cidr_ip, cidr_prefix)
106                }
107                _ => None,
108            };
109
110            results.push(in_range);
111        }
112
113        Ok(results.to_vector())
114    }
115}
116
117/// Function that checks if an IPv6 address is within a specified CIDR range.
118///
119/// Both the IP address and the CIDR range are provided as strings.
120/// Returns boolean result indicating whether the IP is in the range.
121///
122/// Examples:
123/// - ipv6_in_range('2001:db8::1', '2001:db8::/32') -> true
124/// - ipv6_in_range('2001:db8:1::', '2001:db8::/32') -> true
125/// - ipv6_in_range('2001:db9::1', '2001:db8::/32') -> false
126/// - ipv6_in_range('::1', '::1/128') -> true
127#[derive(Clone, Debug, Default, Display)]
128#[display("{}", self.name())]
129pub struct Ipv6InRange;
130
131impl Function for Ipv6InRange {
132    fn name(&self) -> &str {
133        "ipv6_in_range"
134    }
135
136    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
137        Ok(DataType::Boolean)
138    }
139
140    fn signature(&self) -> Signature {
141        Signature::string(2, Volatility::Immutable)
142    }
143
144    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
145        ensure!(
146            columns.len() == 2,
147            InvalidFuncArgsSnafu {
148                err_msg: format!("Expected 2 arguments, got {}", columns.len())
149            }
150        );
151
152        let ip_vec = &columns[0];
153        let range_vec = &columns[1];
154        let size = ip_vec.len();
155
156        ensure!(
157            range_vec.len() == size,
158            InvalidFuncArgsSnafu {
159                err_msg: "IP addresses and CIDR ranges must have the same number of rows"
160                    .to_string()
161            }
162        );
163
164        let mut results = BooleanVectorBuilder::with_capacity(size);
165
166        for i in 0..size {
167            let ip = ip_vec.get(i);
168            let range = range_vec.get(i);
169
170            let in_range = match (ip, range) {
171                (Value::String(ip_str), Value::String(range_str)) => {
172                    let ip_str = ip_str.as_utf8().trim();
173                    let range_str = range_str.as_utf8().trim();
174
175                    if ip_str.is_empty() || range_str.is_empty() {
176                        return InvalidFuncArgsSnafu {
177                            err_msg: "IP address and CIDR range cannot be empty".to_string(),
178                        }
179                        .fail();
180                    }
181
182                    // Parse the IP address
183                    let ip_addr = Ipv6Addr::from_str(ip_str).map_err(|_| {
184                        InvalidFuncArgsSnafu {
185                            err_msg: format!("Invalid IPv6 address: {}", ip_str),
186                        }
187                        .build()
188                    })?;
189
190                    // Parse the CIDR range
191                    let (cidr_ip, cidr_prefix) = parse_ipv6_cidr(range_str)?;
192
193                    // Check if the IP is in the CIDR range
194                    is_ipv6_in_range(&ip_addr, &cidr_ip, cidr_prefix)
195                }
196                _ => None,
197            };
198
199            results.push(in_range);
200        }
201
202        Ok(results.to_vector())
203    }
204}
205
206// Helper functions
207
208fn parse_ipv4_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
209    // Split the CIDR string into IP and prefix parts
210    let parts: Vec<&str> = cidr.split('/').collect();
211    ensure!(
212        parts.len() == 2,
213        InvalidFuncArgsSnafu {
214            err_msg: format!("Invalid CIDR notation: {}", cidr),
215        }
216    );
217
218    // Parse the IP address part
219    let ip = Ipv4Addr::from_str(parts[0]).map_err(|_| {
220        InvalidFuncArgsSnafu {
221            err_msg: format!("Invalid IPv4 address in CIDR: {}", parts[0]),
222        }
223        .build()
224    })?;
225
226    // Parse the prefix length
227    let prefix = parts[1].parse::<u8>().map_err(|_| {
228        InvalidFuncArgsSnafu {
229            err_msg: format!("Invalid prefix length: {}", parts[1]),
230        }
231        .build()
232    })?;
233
234    ensure!(
235        prefix <= 32,
236        InvalidFuncArgsSnafu {
237            err_msg: format!("IPv4 prefix length must be <= 32, got {}", prefix),
238        }
239    );
240
241    Ok((ip, prefix))
242}
243
244fn parse_ipv6_cidr(cidr: &str) -> Result<(Ipv6Addr, u8)> {
245    // Split the CIDR string into IP and prefix parts
246    let parts: Vec<&str> = cidr.split('/').collect();
247    ensure!(
248        parts.len() == 2,
249        InvalidFuncArgsSnafu {
250            err_msg: format!("Invalid CIDR notation: {}", cidr),
251        }
252    );
253
254    // Parse the IP address part
255    let ip = Ipv6Addr::from_str(parts[0]).map_err(|_| {
256        InvalidFuncArgsSnafu {
257            err_msg: format!("Invalid IPv6 address in CIDR: {}", parts[0]),
258        }
259        .build()
260    })?;
261
262    // Parse the prefix length
263    let prefix = parts[1].parse::<u8>().map_err(|_| {
264        InvalidFuncArgsSnafu {
265            err_msg: format!("Invalid prefix length: {}", parts[1]),
266        }
267        .build()
268    })?;
269
270    ensure!(
271        prefix <= 128,
272        InvalidFuncArgsSnafu {
273            err_msg: format!("IPv6 prefix length must be <= 128, got {}", prefix),
274        }
275    );
276
277    Ok((ip, prefix))
278}
279
/// Returns whether `ip` falls inside the network `cidr_base/prefix_len`.
///
/// Only the leading `prefix_len` bits of both addresses are compared, so host
/// bits set in `cidr_base` (e.g. `192.168.1.7/24`) are ignored. A prefix longer
/// than 32 bits is treated as 32 (exact match) rather than underflowing the
/// shift amount.
///
/// Always returns `Some`; the `Option` lets callers push the result straight
/// into a nullable boolean builder.
fn is_ipv4_in_range(ip: &Ipv4Addr, cidr_base: &Ipv4Addr, prefix_len: u8) -> Option<bool> {
    let prefix_len = u32::from(prefix_len.min(32));
    // A shift by 32 (prefix 0) overflows; `checked_shl` turns it into an
    // all-zero mask, which matches every address.
    let mask = u32::MAX.checked_shl(32 - prefix_len).unwrap_or(0);

    Some((u32::from(*ip) & mask) == (u32::from(*cidr_base) & mask))
}
298
/// Returns whether `ip` falls inside the network `cidr_base/prefix_len`.
///
/// Both addresses are compared as 128-bit big-endian integers under a mask of
/// the leading `prefix_len` bits, so host bits set in `cidr_base` are ignored
/// and non-byte-aligned prefixes (e.g. `/29`) work naturally. A prefix longer
/// than 128 bits is treated as 128 (exact match) instead of indexing past the
/// 16 address octets.
///
/// Always returns `Some`; the `Option` lets callers push the result straight
/// into a nullable boolean builder.
fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Option<bool> {
    let prefix_len = u32::from(prefix_len.min(128));
    // A shift by 128 (prefix 0) overflows; `checked_shl` turns it into an
    // all-zero mask, which matches every address.
    let mask = u128::MAX.checked_shl(128 - prefix_len).unwrap_or(0);

    Some((u128::from(*ip) & mask) == (u128::from(*cidr_base) & mask))
}
327
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{BooleanVector, StringVector};

    use super::*;

    /// Runs `func` over an IP column and a CIDR column built from string slices.
    fn eval_strings<F: Function>(func: &F, ips: &[&str], cidrs: &[&str]) -> Result<VectorRef> {
        let ctx = FunctionContext::default();
        let ip_input = Arc::new(StringVector::from_slice(ips)) as VectorRef;
        let cidr_input = Arc::new(StringVector::from_slice(cidrs)) as VectorRef;
        func.eval(&ctx, &[ip_input, cidr_input])
    }

    /// Like `eval_strings`, but unwraps the result and returns one entry per row.
    fn eval_bools<F: Function>(func: &F, ips: &[&str], cidrs: &[&str]) -> Vec<Option<bool>> {
        let result = eval_strings(func, ips, cidrs).unwrap();
        let result = result.as_any().downcast_ref::<BooleanVector>().unwrap();
        (0..ips.len()).map(|i| result.get_data(i)).collect()
    }

    #[test]
    fn test_ipv4_in_range() {
        let results = eval_bools(
            &Ipv4InRange,
            &["192.168.1.5", "192.168.2.1", "10.0.0.1", "10.1.0.1", "172.16.0.1"],
            &[
                "192.168.1.0/24",
                "192.168.1.0/24",
                "10.0.0.0/8",
                "10.0.0.0/8",
                "172.16.0.0/16",
            ],
        );
        assert_eq!(
            results,
            vec![Some(true), Some(false), Some(true), Some(true), Some(true)]
        );
    }

    #[test]
    fn test_ipv6_in_range() {
        let results = eval_bools(
            &Ipv6InRange,
            &["2001:db8::1", "2001:db8:1::", "2001:db9::1", "::1", "fe80::1"],
            &[
                "2001:db8::/32",
                "2001:db8::/32",
                "2001:db8::/32",
                "::1/128",
                "fe80::/16",
            ],
        );
        assert_eq!(
            results,
            vec![Some(true), Some(true), Some(false), Some(true), Some(true)]
        );
    }

    #[test]
    fn test_argument_validation() {
        let ctx = FunctionContext::default();
        let single = Arc::new(StringVector::from_slice(&["10.0.0.1"])) as VectorRef;

        // Wrong arity.
        assert!(Ipv4InRange.eval(&ctx, &[single.clone()]).is_err());
        assert!(Ipv6InRange.eval(&ctx, &[single]).is_err());

        // Mismatched row counts.
        assert!(eval_strings(&Ipv4InRange, &["10.0.0.1"], &["10.0.0.0/8", "10.0.0.0/8"]).is_err());
        assert!(eval_strings(&Ipv6InRange, &["::1"], &["::1/128", "::1/128"]).is_err());
    }

    #[test]
    fn test_invalid_inputs() {
        // Evaluation stops at the first failing row, so each invalid value is
        // checked on its own to make sure it is actually the one rejected.
        for (ip, cidr) in [
            ("not-an-ip", "192.168.1.0/24"),
            ("192.168.1.300", "192.168.1.0/24"),
            ("192.168.1.1", "192.168.1.0"),
            ("192.168.1.1", "192.168.1.0/33"),
            ("192.168.1.1", "192.168.1.0/abc"),
            ("192.168.1.1", "192.168.1.0/24/8"),
            ("192.168.1.1", "2001:db8::/32"),
            ("", "192.168.1.0/24"),
            ("192.168.1.1", "   "),
        ] {
            assert!(
                eval_strings(&Ipv4InRange, &[ip], &[cidr]).is_err(),
                "expected ipv4_in_range({ip:?}, {cidr:?}) to fail"
            );
        }

        for (ip, cidr) in [
            ("not-an-ip", "2001:db8::/32"),
            ("192.168.1.1", "2001:db8::/32"),
            ("2001:db8::1", "2001:db8::"),
            ("2001:db8::1", "2001:db8::/129"),
            ("2001:db8::1", "192.168.1.0/24"),
            ("", "2001:db8::/32"),
        ] {
            assert!(
                eval_strings(&Ipv6InRange, &[ip], &[cidr]).is_err(),
                "expected ipv6_in_range({ip:?}, {cidr:?}) to fail"
            );
        }
    }

    #[test]
    fn test_surrounding_whitespace_is_ignored() {
        assert_eq!(
            eval_bools(&Ipv4InRange, &[" 10.0.0.1 "], &[" 10.0.0.0/8 "]),
            vec![Some(true)]
        );
        assert_eq!(
            eval_bools(&Ipv6InRange, &[" ::1 "], &[" ::1/128 "]),
            vec![Some(true)]
        );
    }

    #[test]
    fn test_edge_cases() {
        // Prefix length 0 matches everything; 32 requires an exact match.
        let results = eval_bools(
            &Ipv4InRange,
            &["8.8.8.8", "192.168.1.1", "192.168.1.1"],
            &["0.0.0.0/0", "192.168.1.1/32", "192.168.1.0/32"],
        );
        assert_eq!(results, vec![Some(true), Some(true), Some(false)]);
    }

    #[test]
    fn test_ipv6_edge_cases() {
        // /0 matches everything, /128 requires an exact match, and /29 checks
        // a prefix that ends in the middle of a byte.
        let results = eval_bools(
            &Ipv6InRange,
            &["2001:db8::1", "::1", "::2", "2001:dbf::1", "2001:db0::1"],
            &["::/0", "::1/128", "::1/128", "2001:db8::/29", "2001:db8::/29"],
        );
        assert_eq!(
            results,
            vec![Some(true), Some(true), Some(false), Some(true), Some(false)]
        );
    }
}