common_function/scalars/ip/
range.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::{Ipv4Addr, Ipv6Addr};
16use std::str::FromStr;
17use std::sync::Arc;
18
19use common_query::error::{InvalidFuncArgsSnafu, Result};
20use datafusion_common::DataFusionError;
21use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
22use datafusion_common::arrow::compute;
23use datafusion_common::arrow::datatypes::DataType;
24use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
25use derive_more::Display;
26use snafu::ensure;
27
28use crate::function::{Function, extract_args};
29
30/// Function that checks if an IPv4 address is within a specified CIDR range.
31///
32/// Both the IP address and the CIDR range are provided as strings.
33/// Returns boolean result indicating whether the IP is in the range.
34///
35/// Examples:
36/// - ipv4_in_range('192.168.1.5', '192.168.1.0/24') -> true
37/// - ipv4_in_range('192.168.2.1', '192.168.1.0/24') -> false
38/// - ipv4_in_range('10.0.0.1', '10.0.0.0/8') -> true
39#[derive(Clone, Debug, Display)]
40#[display("{}", self.name())]
41pub(crate) struct Ipv4InRange {
42    signature: Signature,
43}
44
45impl Default for Ipv4InRange {
46    fn default() -> Self {
47        Self {
48            signature: Signature::string(2, Volatility::Immutable),
49        }
50    }
51}
52
53impl Function for Ipv4InRange {
54    fn name(&self) -> &str {
55        "ipv4_in_range"
56    }
57
58    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
59        Ok(DataType::Boolean)
60    }
61
62    fn signature(&self) -> &Signature {
63        &self.signature
64    }
65
66    fn invoke_with_args(
67        &self,
68        args: ScalarFunctionArgs,
69    ) -> datafusion_common::Result<ColumnarValue> {
70        let [arg0, arg1] = extract_args(self.name(), &args)?;
71
72        let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
73        let ip_vec = arg0.as_string_view();
74        let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
75        let ranges = arg1.as_string_view();
76
77        let size = ip_vec.len();
78
79        let mut builder = BooleanBuilder::with_capacity(size);
80
81        for i in 0..size {
82            let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i));
83            let range = ranges.is_valid(i).then(|| ranges.value(i));
84
85            let in_range = match (ip, range) {
86                (Some(ip_str), Some(range_str)) => {
87                    if ip_str.is_empty() || range_str.is_empty() {
88                        return Err(DataFusionError::Execution(
89                            "IP address or CIDR range cannot be empty".to_string(),
90                        ));
91                    }
92
93                    // Parse the IP address
94                    let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
95                        InvalidFuncArgsSnafu {
96                            err_msg: format!("Invalid IPv4 address: {}", ip_str),
97                        }
98                        .build()
99                    })?;
100
101                    // Parse the CIDR range
102                    let (cidr_ip, cidr_prefix) = parse_ipv4_cidr(range_str)?;
103
104                    // Check if the IP is in the CIDR range
105                    is_ipv4_in_range(&ip_addr, &cidr_ip, cidr_prefix)
106                }
107                _ => None,
108            };
109
110            builder.append_option(in_range);
111        }
112
113        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
114    }
115}
116
117/// Function that checks if an IPv6 address is within a specified CIDR range.
118///
119/// Both the IP address and the CIDR range are provided as strings.
120/// Returns boolean result indicating whether the IP is in the range.
121///
122/// Examples:
123/// - ipv6_in_range('2001:db8::1', '2001:db8::/32') -> true
124/// - ipv6_in_range('2001:db8:1::', '2001:db8::/32') -> true
125/// - ipv6_in_range('2001:db9::1', '2001:db8::/32') -> false
126/// - ipv6_in_range('::1', '::1/128') -> true
127#[derive(Clone, Debug, Display)]
128#[display("{}", self.name())]
129pub(crate) struct Ipv6InRange {
130    signature: Signature,
131}
132
133impl Default for Ipv6InRange {
134    fn default() -> Self {
135        Self {
136            signature: Signature::string(2, Volatility::Immutable),
137        }
138    }
139}
140
141impl Function for Ipv6InRange {
142    fn name(&self) -> &str {
143        "ipv6_in_range"
144    }
145
146    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
147        Ok(DataType::Boolean)
148    }
149
150    fn signature(&self) -> &Signature {
151        &self.signature
152    }
153
154    fn invoke_with_args(
155        &self,
156        args: ScalarFunctionArgs,
157    ) -> datafusion_common::Result<ColumnarValue> {
158        let [arg0, arg1] = extract_args(self.name(), &args)?;
159
160        let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
161        let ip_vec = arg0.as_string_view();
162        let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
163        let ranges = arg1.as_string_view();
164        let size = ip_vec.len();
165        let mut builder = BooleanBuilder::with_capacity(size);
166
167        for i in 0..size {
168            let ip = ip_vec.is_valid(i).then(|| ip_vec.value(i));
169            let range = ranges.is_valid(i).then(|| ranges.value(i));
170
171            let in_range = match (ip, range) {
172                (Some(ip_str), Some(range_str)) => {
173                    if ip_str.is_empty() || range_str.is_empty() {
174                        return Err(DataFusionError::Execution(
175                            "IP address or CIDR range cannot be empty".to_string(),
176                        ));
177                    }
178
179                    // Parse the IP address
180                    let ip_addr = Ipv6Addr::from_str(ip_str).map_err(|_| {
181                        InvalidFuncArgsSnafu {
182                            err_msg: format!("Invalid IPv6 address: {}", ip_str),
183                        }
184                        .build()
185                    })?;
186
187                    // Parse the CIDR range
188                    let (cidr_ip, cidr_prefix) = parse_ipv6_cidr(range_str)?;
189
190                    // Check if the IP is in the CIDR range
191                    is_ipv6_in_range(&ip_addr, &cidr_ip, cidr_prefix)
192                }
193                _ => None,
194            };
195
196            builder.append_option(in_range);
197        }
198
199        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
200    }
201}
202
203// Helper functions
204
205fn parse_ipv4_cidr(cidr: &str) -> Result<(Ipv4Addr, u8)> {
206    // Split the CIDR string into IP and prefix parts
207    let parts: Vec<&str> = cidr.split('/').collect();
208    ensure!(
209        parts.len() == 2,
210        InvalidFuncArgsSnafu {
211            err_msg: format!("Invalid CIDR notation: {}", cidr),
212        }
213    );
214
215    // Parse the IP address part
216    let ip = Ipv4Addr::from_str(parts[0]).map_err(|_| {
217        InvalidFuncArgsSnafu {
218            err_msg: format!("Invalid IPv4 address in CIDR: {}", parts[0]),
219        }
220        .build()
221    })?;
222
223    // Parse the prefix length
224    let prefix = parts[1].parse::<u8>().map_err(|_| {
225        InvalidFuncArgsSnafu {
226            err_msg: format!("Invalid prefix length: {}", parts[1]),
227        }
228        .build()
229    })?;
230
231    ensure!(
232        prefix <= 32,
233        InvalidFuncArgsSnafu {
234            err_msg: format!("IPv4 prefix length must be <= 32, got {}", prefix),
235        }
236    );
237
238    Ok((ip, prefix))
239}
240
241fn parse_ipv6_cidr(cidr: &str) -> Result<(Ipv6Addr, u8)> {
242    // Split the CIDR string into IP and prefix parts
243    let parts: Vec<&str> = cidr.split('/').collect();
244    ensure!(
245        parts.len() == 2,
246        InvalidFuncArgsSnafu {
247            err_msg: format!("Invalid CIDR notation: {}", cidr),
248        }
249    );
250
251    // Parse the IP address part
252    let ip = Ipv6Addr::from_str(parts[0]).map_err(|_| {
253        InvalidFuncArgsSnafu {
254            err_msg: format!("Invalid IPv6 address in CIDR: {}", parts[0]),
255        }
256        .build()
257    })?;
258
259    // Parse the prefix length
260    let prefix = parts[1].parse::<u8>().map_err(|_| {
261        InvalidFuncArgsSnafu {
262            err_msg: format!("Invalid prefix length: {}", parts[1]),
263        }
264        .build()
265    })?;
266
267    ensure!(
268        prefix <= 128,
269        InvalidFuncArgsSnafu {
270            err_msg: format!("IPv6 prefix length must be <= 128, got {}", prefix),
271        }
272    );
273
274    Ok((ip, prefix))
275}
276
/// Reports whether `ip` lies inside the IPv4 network `cidr_base/prefix_len`.
///
/// Only the leading `prefix_len` bits of the two addresses are compared, so
/// any host bits set in `cidr_base` are ignored. A prefix of 0 matches every
/// address; a prefix of 32 requires an exact match.
///
/// Returns `None` when `prefix_len` exceeds 32. `parse_ipv4_cidr` already
/// rejects such prefixes, but answering `None` here keeps the helper free of
/// arithmetic overflow if it is ever called with an unvalidated value.
fn is_ipv4_in_range(ip: &Ipv4Addr, cidr_base: &Ipv4Addr, prefix_len: u8) -> Option<bool> {
    // Number of trailing host bits; `None` (propagated) when the prefix is too long.
    let host_bits = 32_u32.checked_sub(u32::from(prefix_len))?;

    // `checked_shl` refuses a full-width shift, which is exactly the prefix-0
    // case where the network mask must be empty.
    let mask = u32::MAX.checked_shl(host_bits).unwrap_or(0);

    // The addresses share a network iff they agree on every masked bit.
    let differing_bits = u32::from(*ip) ^ u32::from(*cidr_base);
    Some((differing_bits & mask) == 0)
}
295
/// Reports whether `ip` lies inside the IPv6 network `cidr_base/prefix_len`.
///
/// Only the leading `prefix_len` bits of the two addresses are compared, so
/// any host bits set in `cidr_base` are ignored. A prefix of 0 matches every
/// address; a prefix of 128 requires an exact match.
///
/// Returns `None` when `prefix_len` exceeds 128. `parse_ipv6_cidr` already
/// rejects such prefixes, but answering `None` here avoids treating an
/// over-long prefix as /128 or indexing past the 16 address octets if the
/// helper is ever called with an unvalidated value.
fn is_ipv6_in_range(ip: &Ipv6Addr, cidr_base: &Ipv6Addr, prefix_len: u8) -> Option<bool> {
    // Number of trailing host bits; `None` (propagated) when the prefix is too long.
    let host_bits = 128_u32.checked_sub(u32::from(prefix_len))?;

    // `checked_shl` refuses a full-width shift, which is exactly the prefix-0
    // case where the network mask must be empty.
    let mask = u128::MAX.checked_shl(host_bits).unwrap_or(0);

    // The `u128` form keeps the first octet in the most significant byte, so
    // the high bits of the mask line up with the start of the address.
    let differing_bits = u128::from(*ip) ^ u128::from(*cidr_base);
    Some((differing_bits & mask) == 0)
}
324
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::StringViewArray;

    use super::*;

    /// Wraps string literals in a `StringViewArray` column argument.
    fn string_column(values: &[&str]) -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(values)))
    }

    /// Invokes `func` on paired address / CIDR columns, one row per address.
    fn invoke<F: Function>(
        func: &F,
        ips: &[&str],
        cidrs: &[&str],
    ) -> datafusion_common::Result<ColumnarValue> {
        func.invoke_with_args(ScalarFunctionArgs {
            args: vec![string_column(ips), string_column(cidrs)],
            arg_fields: vec![],
            number_rows: ips.len(),
            return_field: Arc::new(Field::new("x", DataType::Boolean, false)),
            config_options: Arc::new(Default::default()),
        })
    }

    /// Runs `func` and returns the boolean value stored for every row,
    /// panicking if the invocation fails.
    fn evaluate<F: Function>(func: &F, ips: &[&str], cidrs: &[&str]) -> Vec<bool> {
        let output = invoke(func, ips, cidrs)
            .unwrap()
            .to_array(ips.len())
            .unwrap();
        let flags = output.as_boolean();
        (0..ips.len()).map(|row| flags.value(row)).collect()
    }

    #[test]
    fn test_ipv4_in_range() {
        // (address, CIDR range, expected membership)
        let cases = [
            ("192.168.1.5", "192.168.1.0/24", true),
            ("192.168.2.1", "192.168.1.0/24", false),
            ("10.0.0.1", "10.0.0.0/8", true),
            ("10.1.0.1", "10.0.0.0/8", true),
            ("172.16.0.1", "172.16.0.0/16", true),
        ];
        let ips: Vec<&str> = cases.iter().map(|case| case.0).collect();
        let cidrs: Vec<&str> = cases.iter().map(|case| case.1).collect();
        let expected: Vec<bool> = cases.iter().map(|case| case.2).collect();

        let actual = evaluate(&Ipv4InRange::default(), &ips, &cidrs);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_ipv6_in_range() {
        // (address, CIDR range, expected membership)
        let cases = [
            ("2001:db8::1", "2001:db8::/32", true),
            ("2001:db8:1::", "2001:db8::/32", true),
            ("2001:db9::1", "2001:db8::/32", false),
            ("::1", "::1/128", true),
            ("fe80::1", "fe80::/16", true),
        ];
        let ips: Vec<&str> = cases.iter().map(|case| case.0).collect();
        let cidrs: Vec<&str> = cases.iter().map(|case| case.1).collect();
        let expected: Vec<bool> = cases.iter().map(|case| case.2).collect();

        let actual = evaluate(&Ipv6InRange::default(), &ips, &cidrs);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_invalid_inputs() {
        let ipv4_func = Ipv4InRange::default();
        let ipv6_func = Ipv6InRange::default();

        // Unparsable IPv4 addresses against a valid range.
        let bad_address_result = invoke(
            &ipv4_func,
            &["not-an-ip", "192.168.1.300"],
            &["192.168.1.0/24", "192.168.1.0/24"],
        );
        assert!(bad_address_result.is_err());

        // A range without a prefix and a range whose prefix is too long; the
        // same columns are fed to both functions.
        let ips = ["192.168.1.1", "2001:db8::1"];
        let bad_cidrs = ["192.168.1.0", "2001:db8::/129"];

        let ipv4_result = invoke(&ipv4_func, &ips, &bad_cidrs);
        let ipv6_result = invoke(&ipv6_func, &ips, &bad_cidrs);

        assert!(ipv4_result.is_err());
        assert!(ipv6_result.is_err());
    }

    #[test]
    fn test_edge_cases() {
        // Prefix length 0 matches everything; prefix length 32 is an exact match.
        let cases = [
            ("8.8.8.8", "0.0.0.0/0", true),
            ("192.168.1.1", "192.168.1.1/32", true),
            ("192.168.1.1", "192.168.1.0/32", false),
        ];
        let ips: Vec<&str> = cases.iter().map(|case| case.0).collect();
        let cidrs: Vec<&str> = cases.iter().map(|case| case.1).collect();
        let expected: Vec<bool> = cases.iter().map(|case| case.2).collect();

        let actual = evaluate(&Ipv4InRange::default(), &ips, &cidrs);

        assert_eq!(actual, expected);
    }
}