common_function/scalars/ip/
ipv4.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::Ipv4Addr;
16use std::str::FromStr;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::{Signature, TypeSignature};
20use datafusion::logical_expr::Volatility;
21use datatypes::prelude::ConcreteDataType;
22use datatypes::scalars::ScalarVectorBuilder;
23use datatypes::vectors::{MutableVector, StringVectorBuilder, UInt32VectorBuilder, VectorRef};
24use derive_more::Display;
25use snafu::ensure;
26
27use crate::function::{Function, FunctionContext};
28
29/// Function that converts a UInt32 number to an IPv4 address string.
30///
31/// Interprets the number as an IPv4 address in big endian and returns
32/// a string in the format A.B.C.D (dot-separated numbers in decimal form).
33///
34/// For example:
35/// - 167772160 (0x0A000000) returns "10.0.0.0"
36/// - 3232235521 (0xC0A80001) returns "192.168.0.1"
37#[derive(Clone, Debug, Default, Display)]
38#[display("{}", self.name())]
39pub struct Ipv4NumToString;
40
41impl Function for Ipv4NumToString {
42    fn name(&self) -> &str {
43        "ipv4_num_to_string"
44    }
45
46    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
47        Ok(ConcreteDataType::string_datatype())
48    }
49
50    fn signature(&self) -> Signature {
51        Signature::new(
52            TypeSignature::Exact(vec![ConcreteDataType::uint32_datatype()]),
53            Volatility::Immutable,
54        )
55    }
56
57    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
58        ensure!(
59            columns.len() == 1,
60            InvalidFuncArgsSnafu {
61                err_msg: format!("Expected 1 argument, got {}", columns.len())
62            }
63        );
64
65        let uint_vec = &columns[0];
66        let size = uint_vec.len();
67        let mut results = StringVectorBuilder::with_capacity(size);
68
69        for i in 0..size {
70            let ip_num = uint_vec.get(i);
71            let ip_str = match ip_num {
72                datatypes::value::Value::UInt32(num) => {
73                    // Convert UInt32 to IPv4 string (A.B.C.D format)
74                    let a = (num >> 24) & 0xFF;
75                    let b = (num >> 16) & 0xFF;
76                    let c = (num >> 8) & 0xFF;
77                    let d = num & 0xFF;
78                    Some(format!("{}.{}.{}.{}", a, b, c, d))
79                }
80                _ => None,
81            };
82
83            results.push(ip_str.as_deref());
84        }
85
86        Ok(results.to_vector())
87    }
88}
89
90/// Function that converts a string representation of an IPv4 address to a UInt32 number.
91///
92/// For example:
93/// - "10.0.0.1" returns 167772161
94/// - "192.168.0.1" returns 3232235521
95/// - Invalid IPv4 format throws an exception
96#[derive(Clone, Debug, Default, Display)]
97#[display("{}", self.name())]
98pub struct Ipv4StringToNum;
99
100impl Function for Ipv4StringToNum {
101    fn name(&self) -> &str {
102        "ipv4_string_to_num"
103    }
104
105    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
106        Ok(ConcreteDataType::uint32_datatype())
107    }
108
109    fn signature(&self) -> Signature {
110        Signature::new(
111            TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
112            Volatility::Immutable,
113        )
114    }
115
116    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
117        ensure!(
118            columns.len() == 1,
119            InvalidFuncArgsSnafu {
120                err_msg: format!("Expected 1 argument, got {}", columns.len())
121            }
122        );
123
124        let ip_vec = &columns[0];
125        let size = ip_vec.len();
126        let mut results = UInt32VectorBuilder::with_capacity(size);
127
128        for i in 0..size {
129            let ip_str = ip_vec.get(i);
130            let ip_num = match ip_str {
131                datatypes::value::Value::String(s) => {
132                    let ip_str = s.as_utf8();
133                    let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
134                        InvalidFuncArgsSnafu {
135                            err_msg: format!("Invalid IPv4 address format: {}", ip_str),
136                        }
137                        .build()
138                    })?;
139                    Some(u32::from(ip_addr))
140                }
141                _ => None,
142            };
143
144            results.push(ip_num);
145        }
146
147        Ok(results.to_vector())
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{StringVector, UInt32Vector};

    use super::*;

    /// Numbers are rendered as dotted-decimal strings, including both extremes
    /// of the `u32` range.
    #[test]
    fn test_ipv4_num_to_string() {
        let func = Ipv4NumToString;
        let ctx = FunctionContext::default();

        // Test data
        let values = vec![167772161u32, 3232235521u32, 0u32, 4294967295u32];
        let input = Arc::new(UInt32Vector::from_vec(values)) as VectorRef;

        let result = func.eval(&ctx, &[input]).unwrap();
        let result = result.as_any().downcast_ref::<StringVector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), "10.0.0.1");
        assert_eq!(result.get_data(1).unwrap(), "192.168.0.1");
        assert_eq!(result.get_data(2).unwrap(), "0.0.0.0");
        assert_eq!(result.get_data(3).unwrap(), "255.255.255.255");
    }

    /// Well-formed addresses are parsed into their big-endian numeric value.
    #[test]
    fn test_ipv4_string_to_num() {
        let func = Ipv4StringToNum;
        let ctx = FunctionContext::default();

        // Test data
        let values = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
        let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;

        let result = func.eval(&ctx, &[input]).unwrap();
        let result = result.as_any().downcast_ref::<UInt32Vector>().unwrap();

        assert_eq!(result.get_data(0).unwrap(), 167772161);
        assert_eq!(result.get_data(1).unwrap(), 3232235521);
        assert_eq!(result.get_data(2).unwrap(), 0);
        assert_eq!(result.get_data(3).unwrap(), 4294967295);
    }

    /// string -> num -> string must give back the original text.
    #[test]
    fn test_ipv4_conversions_roundtrip() {
        let to_num = Ipv4StringToNum;
        let to_string = Ipv4NumToString;
        let ctx = FunctionContext::default();

        // Test data for string to num to string
        let values = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
        let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;

        let num_result = to_num.eval(&ctx, &[input]).unwrap();
        let back_to_string = to_string.eval(&ctx, &[num_result]).unwrap();
        let str_result = back_to_string
            .as_any()
            .downcast_ref::<StringVector>()
            .unwrap();

        for (i, expected) in values.iter().enumerate() {
            assert_eq!(str_result.get_data(i).unwrap(), *expected);
        }
    }

    /// Malformed address strings must make `eval` fail instead of producing a value.
    #[test]
    fn test_ipv4_string_to_num_invalid_input() {
        let func = Ipv4StringToNum;
        let ctx = FunctionContext::default();

        // Octet out of range, too few / too many octets, non-numeric text, empty string.
        for bad in ["256.0.0.1", "10.0.0", "10.0.0.1.5", "not-an-ip", ""] {
            let values = vec![bad];
            let input = Arc::new(StringVector::from_slice(&values)) as VectorRef;

            assert!(
                func.eval(&ctx, &[input]).is_err(),
                "expected an error for input {:?}",
                bad
            );
        }
    }

    /// Both functions take exactly one column; any other arity is rejected.
    #[test]
    fn test_ipv4_functions_reject_wrong_argument_count() {
        let ctx = FunctionContext::default();

        assert!(Ipv4NumToString.eval(&ctx, &[]).is_err());
        assert!(Ipv4StringToNum.eval(&ctx, &[]).is_err());

        let nums = Arc::new(UInt32Vector::from_vec(vec![1u32])) as VectorRef;
        assert!(Ipv4NumToString
            .eval(&ctx, &[nums.clone(), nums])
            .is_err());

        let values = vec!["10.0.0.1"];
        let strs = Arc::new(StringVector::from_slice(&values)) as VectorRef;
        assert!(Ipv4StringToNum
            .eval(&ctx, &[strs.clone(), strs])
            .is_err());
    }
}