common_function/scalars/ip/
ipv4.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::net::Ipv4Addr;
16use std::str::FromStr;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::{Signature, TypeSignature, Volatility};
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{MutableVector, StringVectorBuilder, UInt32VectorBuilder, VectorRef};
23use derive_more::Display;
24use snafu::ensure;
25
26use crate::function::{Function, FunctionContext};
27
28/// Function that converts a UInt32 number to an IPv4 address string.
29///
30/// Interprets the number as an IPv4 address in big endian and returns
31/// a string in the format A.B.C.D (dot-separated numbers in decimal form).
32///
33/// For example:
34/// - 167772160 (0x0A000000) returns "10.0.0.0"
35/// - 3232235521 (0xC0A80001) returns "192.168.0.1"
36#[derive(Clone, Debug, Display)]
37#[display("{}", self.name())]
38pub struct Ipv4NumToString {
39    aliases: [String; 1],
40}
41
42impl Default for Ipv4NumToString {
43    fn default() -> Self {
44        Self {
45            aliases: ["inet_ntoa".to_string()],
46        }
47    }
48}
49
50impl Function for Ipv4NumToString {
51    fn name(&self) -> &str {
52        "ipv4_num_to_string"
53    }
54
55    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
56        Ok(DataType::Utf8)
57    }
58
59    fn signature(&self) -> Signature {
60        Signature::new(
61            TypeSignature::Exact(vec![DataType::UInt32]),
62            Volatility::Immutable,
63        )
64    }
65
66    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
67        ensure!(
68            columns.len() == 1,
69            InvalidFuncArgsSnafu {
70                err_msg: format!("Expected 1 argument, got {}", columns.len())
71            }
72        );
73
74        let uint_vec = &columns[0];
75        let size = uint_vec.len();
76        let mut results = StringVectorBuilder::with_capacity(size);
77
78        for i in 0..size {
79            let ip_num = uint_vec.get(i);
80            let ip_str = match ip_num {
81                datatypes::value::Value::UInt32(num) => {
82                    // Convert UInt32 to IPv4 string (A.B.C.D format)
83                    let a = (num >> 24) & 0xFF;
84                    let b = (num >> 16) & 0xFF;
85                    let c = (num >> 8) & 0xFF;
86                    let d = num & 0xFF;
87                    Some(format!("{}.{}.{}.{}", a, b, c, d))
88                }
89                _ => None,
90            };
91
92            results.push(ip_str.as_deref());
93        }
94
95        Ok(results.to_vector())
96    }
97
98    fn aliases(&self) -> &[String] {
99        &self.aliases
100    }
101}
102
103/// Function that converts a string representation of an IPv4 address to a UInt32 number.
104///
105/// For example:
106/// - "10.0.0.1" returns 167772161
107/// - "192.168.0.1" returns 3232235521
108/// - Invalid IPv4 format throws an exception
109#[derive(Clone, Debug, Default, Display)]
110#[display("{}", self.name())]
111pub struct Ipv4StringToNum;
112
113impl Function for Ipv4StringToNum {
114    fn name(&self) -> &str {
115        "ipv4_string_to_num"
116    }
117
118    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
119        Ok(DataType::UInt32)
120    }
121
122    fn signature(&self) -> Signature {
123        Signature::string(1, Volatility::Immutable)
124    }
125
126    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
127        ensure!(
128            columns.len() == 1,
129            InvalidFuncArgsSnafu {
130                err_msg: format!("Expected 1 argument, got {}", columns.len())
131            }
132        );
133
134        let ip_vec = &columns[0];
135        let size = ip_vec.len();
136        let mut results = UInt32VectorBuilder::with_capacity(size);
137
138        for i in 0..size {
139            let ip_str = ip_vec.get(i);
140            let ip_num = match ip_str {
141                datatypes::value::Value::String(s) => {
142                    let ip_str = s.as_utf8();
143                    let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
144                        InvalidFuncArgsSnafu {
145                            err_msg: format!("Invalid IPv4 address format: {}", ip_str),
146                        }
147                        .build()
148                    })?;
149                    Some(u32::from(ip_addr))
150                }
151                _ => None,
152            };
153
154            results.push(ip_num);
155        }
156
157        Ok(results.to_vector())
158    }
159}
160
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{StringVector, UInt32Vector};

    use super::*;

    /// Numbers, including both ends of the `u32` range, map to the expected text.
    #[test]
    fn test_ipv4_num_to_string() {
        let expected = ["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
        let numbers = vec![167772161u32, 3232235521u32, 0u32, 4294967295u32];
        let input = Arc::new(UInt32Vector::from_vec(numbers)) as VectorRef;

        let ctx = FunctionContext::default();
        let output = Ipv4NumToString::default().eval(&ctx, &[input]).unwrap();
        let strings = output.as_any().downcast_ref::<StringVector>().unwrap();

        for (row, want) in expected.iter().enumerate() {
            assert_eq!(strings.get_data(row).unwrap(), *want);
        }
    }

    /// Dotted-decimal text maps to the expected numbers.
    #[test]
    fn test_ipv4_string_to_num() {
        let expected = [167772161u32, 3232235521u32, 0u32, 4294967295u32];
        let addresses = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
        let input = Arc::new(StringVector::from_slice(&addresses)) as VectorRef;

        let ctx = FunctionContext::default();
        let output = Ipv4StringToNum.eval(&ctx, &[input]).unwrap();
        let numbers = output.as_any().downcast_ref::<UInt32Vector>().unwrap();

        for (row, want) in expected.iter().enumerate() {
            assert_eq!(numbers.get_data(row).unwrap(), *want);
        }
    }

    /// Text -> number -> text gives back the original addresses.
    #[test]
    fn test_ipv4_conversions_roundtrip() {
        let ctx = FunctionContext::default();
        let addresses = vec!["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
        let input = Arc::new(StringVector::from_slice(&addresses)) as VectorRef;

        // Chain the two functions: the numeric column feeds straight into the formatter.
        let as_numbers = Ipv4StringToNum.eval(&ctx, &[input]).unwrap();
        let as_text = Ipv4NumToString::default()
            .eval(&ctx, &[as_numbers])
            .unwrap();
        let strings = as_text.as_any().downcast_ref::<StringVector>().unwrap();

        for (row, want) in addresses.iter().enumerate() {
            assert_eq!(strings.get_data(row).unwrap(), *want);
        }
    }
}