common_function/scalars/ip/
ipv4.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::Ipv4Addr;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use common_query::error::InvalidFuncArgsSnafu;
20use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder, UInt32Builder};
21use datafusion_common::arrow::compute;
22use datafusion_common::arrow::datatypes::{DataType, UInt32Type};
23use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
24use derive_more::Display;
25
26use crate::function::{Function, extract_args};
27
28/// Function that converts a UInt32 number to an IPv4 address string.
29///
30/// Interprets the number as an IPv4 address in big endian and returns
31/// a string in the format A.B.C.D (dot-separated numbers in decimal form).
32///
33/// For example:
34/// - 167772160 (0x0A000000) returns "10.0.0.0"
35/// - 3232235521 (0xC0A80001) returns "192.168.0.1"
36#[derive(Clone, Debug, Display)]
37#[display("{}", self.name())]
38pub struct Ipv4NumToString {
39    signature: Signature,
40    aliases: [String; 1],
41}
42
43impl Default for Ipv4NumToString {
44    fn default() -> Self {
45        Self {
46            signature: Signature::new(
47                TypeSignature::Exact(vec![DataType::UInt32]),
48                Volatility::Immutable,
49            ),
50            aliases: ["inet_ntoa".to_string()],
51        }
52    }
53}
54
55impl Function for Ipv4NumToString {
56    fn name(&self) -> &str {
57        "ipv4_num_to_string"
58    }
59
60    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
61        Ok(DataType::Utf8View)
62    }
63
64    fn signature(&self) -> &Signature {
65        &self.signature
66    }
67
68    fn invoke_with_args(
69        &self,
70        args: ScalarFunctionArgs,
71    ) -> datafusion_common::Result<ColumnarValue> {
72        let [arg0] = extract_args(self.name(), &args)?;
73        let uint_vec = arg0.as_primitive::<UInt32Type>();
74
75        let size = uint_vec.len();
76        let mut builder = StringViewBuilder::with_capacity(size);
77
78        for i in 0..size {
79            let ip_num = uint_vec.is_valid(i).then(|| uint_vec.value(i));
80            let ip_str = match ip_num {
81                Some(num) => {
82                    // Convert UInt32 to IPv4 string (A.B.C.D format)
83                    let a = (num >> 24) & 0xFF;
84                    let b = (num >> 16) & 0xFF;
85                    let c = (num >> 8) & 0xFF;
86                    let d = num & 0xFF;
87                    Some(format!("{}.{}.{}.{}", a, b, c, d))
88                }
89                _ => None,
90            };
91
92            builder.append_option(ip_str.as_deref());
93        }
94
95        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
96    }
97
98    fn aliases(&self) -> &[String] {
99        &self.aliases
100    }
101}
102
103/// Function that converts a string representation of an IPv4 address to a UInt32 number.
104///
105/// For example:
106/// - "10.0.0.1" returns 167772161
107/// - "192.168.0.1" returns 3232235521
108/// - Invalid IPv4 format throws an exception
109#[derive(Clone, Debug, Display)]
110#[display("{}", self.name())]
111pub(crate) struct Ipv4StringToNum {
112    signature: Signature,
113}
114
115impl Default for Ipv4StringToNum {
116    fn default() -> Self {
117        Self {
118            signature: Signature::string(1, Volatility::Immutable),
119        }
120    }
121}
122
123impl Function for Ipv4StringToNum {
124    fn name(&self) -> &str {
125        "ipv4_string_to_num"
126    }
127
128    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
129        Ok(DataType::UInt32)
130    }
131
132    fn signature(&self) -> &Signature {
133        &self.signature
134    }
135
136    fn invoke_with_args(
137        &self,
138        args: ScalarFunctionArgs,
139    ) -> datafusion_common::Result<ColumnarValue> {
140        let [arg0] = extract_args(self.name(), &args)?;
141
142        let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
143        let ip_vec = arg0.as_string_view();
144        let size = ip_vec.len();
145        let mut builder = UInt32Builder::with_capacity(size);
146
147        for i in 0..size {
148            let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
149            let ip_num = match ip_str {
150                Some(ip_str) => {
151                    let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
152                        InvalidFuncArgsSnafu {
153                            err_msg: format!("Invalid IPv4 address format: {}", ip_str),
154                        }
155                        .build()
156                    })?;
157                    Some(u32::from(ip_addr))
158                }
159                _ => None,
160            };
161
162            builder.append_option(ip_num);
163        }
164
165        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
166    }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{StringViewArray, UInt32Array};

    use super::*;

    /// Number of rows in every fixture below.
    const ROWS: usize = 4;
    /// Textual addresses; `IP_NUMBERS[i]` is the numeric form of `IP_STRINGS[i]`.
    const IP_STRINGS: [&str; ROWS] = ["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
    const IP_NUMBERS: [u32; ROWS] = [167772161, 3232235521, 0, 4294967295];

    /// Runs `func` on a single input column and returns the raw result.
    fn invoke(
        func: &impl Function,
        input: ColumnarValue,
        return_type: DataType,
    ) -> ColumnarValue {
        let args = ScalarFunctionArgs {
            args: vec![input],
            arg_fields: vec![],
            number_rows: ROWS,
            return_field: Arc::new(Field::new("x", return_type, false)),
            config_options: Arc::new(Default::default()),
        };
        func.invoke_with_args(args).unwrap()
    }

    /// Fixture column holding `IP_STRINGS`.
    fn string_input() -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(IP_STRINGS)))
    }

    #[test]
    fn test_ipv4_num_to_string() {
        let input = ColumnarValue::Array(Arc::new(UInt32Array::from(IP_NUMBERS.to_vec())));

        let output = invoke(&Ipv4NumToString::default(), input, DataType::Utf8View);
        let array = output.to_array(ROWS).unwrap();
        let strings = array.as_string_view();

        for (row, expected) in IP_STRINGS.iter().enumerate() {
            assert_eq!(strings.value(row), *expected);
        }
    }

    #[test]
    fn test_ipv4_string_to_num() {
        let output = invoke(&Ipv4StringToNum::default(), string_input(), DataType::UInt32);
        let array = output.to_array(ROWS).unwrap();
        let numbers = array.as_primitive::<UInt32Type>();

        for (row, expected) in IP_NUMBERS.iter().enumerate() {
            assert_eq!(numbers.value(row), *expected);
        }
    }

    #[test]
    fn test_ipv4_conversions_roundtrip() {
        // string -> num -> string must reproduce the original text.
        let as_numbers = invoke(&Ipv4StringToNum::default(), string_input(), DataType::UInt32);
        let as_strings = invoke(&Ipv4NumToString::default(), as_numbers, DataType::Utf8View);

        let array = as_strings.to_array(ROWS).unwrap();
        let strings = array.as_string_view();

        for (row, expected) in IP_STRINGS.iter().enumerate() {
            assert_eq!(strings.value(row), *expected);
        }
    }
}