common_function/scalars/ip/
range.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::{Signature, TypeSignature};
20use datafusion::logical_expr::Volatility;
21use datatypes::prelude::{ConcreteDataType, Value};
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{BooleanVectorBuilder, MutableVector, VectorRef};
24use derive_more::Display;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28
29/// Function that checks if an IPv4 address is within a specified CIDR range.
30///
31/// Both the IP address and the CIDR range are provided as strings.
32/// Returns boolean result indicating whether the IP is in the range.
33///
34/// Examples:
35/// - ipv4_in_range('192.168.1.5', '192.168.1.0/24') -> true
36/// - ipv4_in_range('192.168.2.1', '192.168.1.0/24') -> false
37/// - ipv4_in_range('10.0.0.1', '10.0.0.0/8') -> true
38#[derive(Clone, Debug, Default, Display)]
39#[display("{}", self.name())]
40pub struct Ipv4InRange;
41
42impl Function for Ipv4InRange {
43    fn name(&self) -> &str {
44        "ipv4_in_range"
45    }
46
47    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
48        Ok(ConcreteDataType::boolean_datatype())
49    }
50
51    fn signature(&self) -> Signature {
52        Signature::new(
53            TypeSignature::Exact(vec![
54                ConcreteDataType::string_datatype(),
55                ConcreteDataType::string_datatype(),
56            ]),
57            Volatility::Immutable,
58        )
59    }
60
61    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
62        ensure!(
63            columns.len() == 2,
64            InvalidFuncArgsSnafu {
65                err_msg: format!("Expected 2 arguments, got {}", columns.len())
66            }
67        );
68
69        let ip_vec = &columns[0];
70        let range_vec = &columns[1];
71        let size = ip_vec.len();
72
73        ensure!(
74            range_vec.len() == size,
75            InvalidFuncArgsSnafu {
76                err_msg: "IP addresses and CIDR ranges must have the same number of rows"
77                    .to_string()
78            }
79        );
80
81        let mut results = BooleanVectorBuilder::with_capacity(size);
82
83        for i in 0..size {
84            let ip = ip_vec.get(i);
85            let range = range_vec.get(i);
86
87            let in_range = match (ip, range) {
88                (Value::String(ip_str), Value::String(range_str)) => {
89                    let ip_str = ip_str.as_utf8().trim();
90                    let range_str = range_str.as_utf8().trim();
91
92                    if ip_str.is_empty() || range_str.is_empty() {
93                        return InvalidFuncArgsSnafu {
94                            err_msg: "IP address and CIDR range cannot be empty".to_string(),
95                        }
96                        .fail();
97                    }
98
99                    // Parse the IP address
100                    let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
101                        InvalidFuncArgsSnafu {
102                            err_msg: format!("Invalid IPv4 address: {}", ip_str),
103                        }
104                        .build()
105                    })?;
106
107                    // Parse the CIDR range
108                    let (cidr_ip, cidr_prefix) = parse_ipv4_cidr(range_str)?;
109
110                    // Check if the IP is in the CIDR range
111                    is_ipv4_in_range(&ip_addr, &cidr_ip, cidr_prefix)
112                }
113                _ => None,
114            };
115
116            results.push(in_range);
117        }
118
119        Ok(results.to_vector())
120    }
121}
122
123/// Function that checks if an IPv6 address is within a specified CIDR range.
124///
125/// Both the IP address and the CIDR range are provided as strings.
126/// Returns boolean result indicating whether the IP is in the range.
127///
128/// Examples:
129/// - ipv6_in_range('2001:db8::1', '2001:db8::/32') -> true
130/// - ipv6_in_range('2001:db8:1::', '2001:db8::/32') -> true
131/// - ipv6_in_range('2001:db9::1', '2001:db8::/32') -> false
132/// - ipv6_in_range('::1', '::1/128') -> true
133#[derive(Clone, Debug, Default, Display)]
134#[display("{}", self.name())]
135pub struct Ipv6InRange;
136
137impl Function for Ipv6InRange {
138    fn name(&self) -> &str {
139        "ipv6_in_range"
140    }
141
142    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
143        Ok(ConcreteDataType::boolean_datatype())
144    }
145
146    fn signature(&self) -> Signature {
147        Signature::new(
148            TypeSignature::Exact(vec![
149                ConcreteDataType::string_datatype(),
150                ConcreteDataType::string_datatype(),
151            ]),
152            Volatility::Immutable,
153        )
154    }
155
156    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
157        ensure!(
158            columns.len() == 2,
159            InvalidFuncArgsSnafu {
160                err_msg: format!("Expected 2 arguments, got {}", columns.len())
161            }
162        );
163
164        let ip_vec = &columns[0];
165        let range_vec = &columns[1];
166        let size = ip_vec.len();
167
168        ensure!(
169            range_vec.len() == size,
170            InvalidFuncArgsSnafu {
171                err_msg: "IP addresses and CIDR ranges must have the same number of rows"
172                    .to_string()
173            }
174        );
175
176        let mut results = BooleanVectorBuilder::with_capacity(size);
177
178        for i in 0..size {
179            let ip = ip_vec.get(i);
180            let range = range_vec.get(i);
181
182            let in_range = match (ip, range) {
183                (Value::String(ip_str), Value::String(range_str)) => {
184                    let ip_str = ip_str.as_utf8().trim();
185                    let range_str = range_str.as_utf8().trim();
186
187                    if ip_str.is_empty() || range_str.is_empty() {
188                        return InvalidFuncArgsSnafu {
189                            err_msg: "IP address and CIDR range cannot be empty".to_string(),
190                        }
191                        .fail();
192                    }
193
194                    // Parse the IP address
195                    let ip_addr = Ipv6Addr::from_str(ip_str).map_err(|_| {
196                        InvalidFuncArgsSnafu {
197                            err_msg: format!("Invalid IPv6 address: {}", ip_str),
198                        }
199                        .build()
200                    })?;
201
202                    // Parse the CIDR range
203                    let (cidr_ip, cidr_prefix) = parse_ipv6_cidr(range_str)?;
204
205                    // Check if the IP is in the CIDR range
206                    is_ipv6_in_range(&ip_addr, &cidr_ip, cidr_prefix)
207                }
208                _ => None,
209            };
210
211            results.push(in_range);
212        }
213
214        Ok(results.to_vector())
215    }
216}
217
218// Helper functions
219
220fn parse_ipv4_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
221    // Split the CIDR string into IP and prefix parts
222    let parts: Vec<&str> = cidr.split('/').collect();
223    ensure!(
224        parts.len() == 2,
225        InvalidFuncArgsSnafu {
226            err_msg: format!("Invalid CIDR notation: {}", cidr),
227        }
228    );
229
230    // Parse the IP address part
231    let ip = Ipv4Addr::from_str(parts[0]).map_err(|_| {
232        InvalidFuncArgsSnafu {
233            err_msg: format!("Invalid IPv4 address in CIDR: {}", parts[0]),
234        }
235        .build()
236    })?;
237
238    // Parse the prefix length
239    let prefix = parts[1].parse::<u8>().map_err(|_| {
240        InvalidFuncArgsSnafu {
241            err_msg: format!("Invalid prefix length: {}", parts[1]),
242        }
243        .build()
244    })?;
245
246    ensure!(
247        prefix <= 32,
248        InvalidFuncArgsSnafu {
249            err_msg: format!("IPv4 prefix length must be <= 32, got {}", prefix),
250        }
251    );
252
253    Ok((ip, prefix))
254}
255
256fn parse_ipv6_cidr(cidr: &str) -> Result<(Ipv6Addr, u8)> {
257    // Split the CIDR string into IP and prefix parts
258    let parts: Vec<&str> = cidr.split('/').collect();
259    ensure!(
260        parts.len() == 2,
261        InvalidFuncArgsSnafu {
262            err_msg: format!("Invalid CIDR notation: {}", cidr),
263        }
264    );
265
266    // Parse the IP address part
267    let ip = Ipv6Addr::from_str(parts[0]).map_err(|_| {
268        InvalidFuncArgsSnafu {
269            err_msg: format!("Invalid IPv6 address in CIDR: {}", parts[0]),
270        }
271        .build()
272    })?;
273
274    // Parse the prefix length
275    let prefix = parts[1].parse::<u8>().map_err(|_| {
276        InvalidFuncArgsSnafu {
277            err_msg: format!("Invalid prefix length: {}", parts[1]),
278        }
279        .build()
280    })?;
281
282    ensure!(
283        prefix <= 128,
284        InvalidFuncArgsSnafu {
285            err_msg: format!("IPv6 prefix length must be <= 128, got {}", prefix),
286        }
287    );
288
289    Ok((ip, prefix))
290}
291
/// Tests whether `ip` lies inside the IPv4 network `cidr_base/prefix_len`.
///
/// Only the leading `prefix_len` bits of the two addresses are compared, so
/// any host bits set in `cidr_base` are ignored.
///
/// Returns `Some(true)` or `Some(false)` for a prefix length in `0..=32`, and
/// `None` when `prefix_len` is greater than 32, since no valid network mask
/// exists for such a prefix.
fn is_ipv4_in_range(ip: &Ipv4Addr, cidr_base: &Ipv4Addr, prefix_len: u8) -> Option<bool> {
    // Number of trailing host bits; `None` (propagated) if the prefix is too long.
    let host_bits = 32u32.checked_sub(u32::from(prefix_len))?;

    // A shift by the full width (prefix 0) is not representable, in which
    // case the mask is empty and every address matches.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);

    let ip_network = u32::from(*ip) & mask;
    let cidr_network = u32::from(*cidr_base) & mask;

    Some(ip_network == cidr_network)
}
310
/// Tests whether `ip` lies inside the IPv6 network `cidr_base/prefix_len`.
///
/// Only the leading `prefix_len` bits of the two addresses are compared, so
/// any host bits set in `cidr_base` are ignored.
///
/// Returns `Some(true)` or `Some(false)` for a prefix length in `0..=128`,
/// and `None` when `prefix_len` is greater than 128, since no valid network
/// mask exists for such a prefix.
fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Option<bool> {
    // Number of trailing host bits; `None` (propagated) if the prefix is too long.
    let host_bits = 128u32.checked_sub(u32::from(prefix_len))?;

    // A shift by the full width (prefix 0) is not representable, in which
    // case the mask is empty and every address matches.
    let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);

    // The `u128` conversion is big-endian, so the first octets of the
    // address end up in the most significant bits covered by the mask.
    let ip_network = u128::from(*ip) & mask;
    let cidr_network = u128::from(*cidr_base) & mask;

    Some(ip_network == cidr_network)
}
339
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{BooleanVector, StringVector};

    use super::*;

    /// Wraps string literals into a string column.
    fn string_column(values: &[&str]) -> VectorRef {
        Arc::new(StringVector::from_slice(values)) as VectorRef
    }

    /// Runs `func` over the paired address/CIDR rows and returns one result
    /// per row. Panics if the evaluation fails or the output is not boolean.
    fn eval_rows<F: Function>(func: &F, ips: &[&str], cidrs: &[&str]) -> Vec<Option<bool>> {
        let ctx = FunctionContext::default();
        let output = func
            .eval(&ctx, &[string_column(ips), string_column(cidrs)])
            .unwrap();
        let rows = output.len();
        let booleans = output.as_any().downcast_ref::<BooleanVector>().unwrap();
        (0..rows).map(|row| booleans.get_data(row)).collect()
    }

    /// Evaluates a single row and reports whether the evaluation failed.
    ///
    /// Invalid inputs are checked one row at a time: an error on any row
    /// aborts the whole evaluation, so batching several invalid rows would
    /// only ever exercise the first one.
    fn fails<F: Function>(func: &F, ip: &str, cidr: &str) -> bool {
        let ctx = FunctionContext::default();
        func.eval(&ctx, &[string_column(&[ip]), string_column(&[cidr])])
            .is_err()
    }

    #[test]
    fn test_ipv4_in_range() {
        let results = eval_rows(
            &Ipv4InRange,
            &[
                "192.168.1.5",
                "192.168.2.1",
                "10.0.0.1",
                "10.1.0.1",
                "172.16.0.1",
            ],
            &[
                "192.168.1.0/24",
                "192.168.1.0/24",
                "10.0.0.0/8",
                "10.0.0.0/8",
                "172.16.0.0/16",
            ],
        );

        assert_eq!(
            results,
            vec![
                Some(true),  // 192.168.1.5 is in 192.168.1.0/24
                Some(false), // 192.168.2.1 is not in 192.168.1.0/24
                Some(true),  // 10.0.0.1 is in 10.0.0.0/8
                Some(true),  // 10.1.0.1 is in 10.0.0.0/8
                Some(true),  // 172.16.0.1 is in 172.16.0.0/16
            ]
        );
    }

    #[test]
    fn test_ipv6_in_range() {
        let results = eval_rows(
            &Ipv6InRange,
            &[
                "2001:db8::1",
                "2001:db8:1::",
                "2001:db9::1",
                "::1",
                "fe80::1",
            ],
            &[
                "2001:db8::/32",
                "2001:db8::/32",
                "2001:db8::/32",
                "::1/128",
                "fe80::/16",
            ],
        );

        assert_eq!(
            results,
            vec![
                Some(true),  // 2001:db8::1 is in 2001:db8::/32
                Some(true),  // 2001:db8:1:: is in 2001:db8::/32
                Some(false), // 2001:db9::1 is not in 2001:db8::/32
                Some(true),  // ::1 is in ::1/128
                Some(true),  // fe80::1 is in fe80::/16
            ]
        );
    }

    #[test]
    fn test_invalid_inputs() {
        let ipv4_func = Ipv4InRange;
        let ipv6_func = Ipv6InRange;

        // Invalid IPv4 addresses
        assert!(fails(&ipv4_func, "not-an-ip", "192.168.1.0/24"));
        assert!(fails(&ipv4_func, "192.168.1.300", "192.168.1.0/24"));
        assert!(fails(&ipv4_func, "2001:db8::1", "192.168.1.0/24"));

        // Invalid IPv4 CIDR notation
        assert!(fails(&ipv4_func, "192.168.1.1", "192.168.1.0")); // no prefix
        assert!(fails(&ipv4_func, "192.168.1.1", "192.168.1.0/33")); // prefix too long
        assert!(fails(&ipv4_func, "192.168.1.1", "192.168.1.0/abc")); // non-numeric prefix
        assert!(fails(&ipv4_func, "192.168.1.1", "192.168.1.0/24/8")); // two separators
        assert!(fails(&ipv4_func, "192.168.1.1", "not-an-ip/24")); // bad base address

        // Empty (after trimming) arguments
        assert!(fails(&ipv4_func, "", "192.168.1.0/24"));
        assert!(fails(&ipv4_func, "192.168.1.1", "   "));

        // Invalid IPv6 addresses
        assert!(fails(&ipv6_func, "not-an-ip", "2001:db8::/32"));
        assert!(fails(&ipv6_func, "192.168.1.1", "2001:db8::/32"));

        // Invalid IPv6 CIDR notation
        assert!(fails(&ipv6_func, "2001:db8::1", "2001:db8::")); // no prefix
        assert!(fails(&ipv6_func, "2001:db8::1", "2001:db8::/129")); // prefix too long
        assert!(fails(&ipv6_func, "2001:db8::1", "2001:db8::/abc")); // non-numeric prefix
    }

    #[test]
    fn test_edge_cases() {
        // Prefix length 0 matches everything, 32 requires an exact match.
        let ipv4_results = eval_rows(
            &Ipv4InRange,
            &["8.8.8.8", "192.168.1.1", "192.168.1.1"],
            &["0.0.0.0/0", "192.168.1.1/32", "192.168.1.0/32"],
        );
        assert_eq!(
            ipv4_results,
            vec![
                Some(true),  // 8.8.8.8 is in 0.0.0.0/0 (matches everything)
                Some(true),  // 192.168.1.1 is in 192.168.1.1/32 (exact match)
                Some(false), // 192.168.1.1 is not in 192.168.1.0/32 (no match)
            ]
        );

        // Prefix length 0 and prefixes that end in the middle of an octet.
        let ipv6_results = eval_rows(
            &Ipv6InRange,
            &["fe80::1", "2001:db8:8000::1", "2001:db8::1", " ::1 "],
            &[
                "::/0",
                "2001:db8:8000::/33",
                "2001:db8:8000::/33",
                " ::1/128 ",
            ],
        );
        assert_eq!(
            ipv6_results,
            vec![
                Some(true),  // fe80::1 is in ::/0 (matches everything)
                Some(true),  // bit 33 is set in both address and network
                Some(false), // bit 33 differs
                Some(true),  // surrounding whitespace is trimmed
            ]
        );
    }
}