Skip to main content

common_function/scalars/ip/
ipv4.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::Ipv4Addr;
16use std::str::FromStr;
17use std::sync::Arc;
18
19use common_query::error::InvalidFuncArgsSnafu;
20use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder, UInt32Builder};
21use datafusion_common::arrow::compute;
22use datafusion_common::arrow::datatypes::{DataType, UInt32Type};
23use datafusion_expr::{
24    Coercion, ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, TypeSignatureClass,
25    Volatility,
26};
27use derive_more::Display;
28
29use crate::function::{Function, extract_args};
30
31/// Function that converts a UInt32 number to an IPv4 address string.
32///
33/// Interprets the number as an IPv4 address in big endian and returns
34/// a string in the format A.B.C.D (dot-separated numbers in decimal form).
35///
36/// For example:
37/// - 167772160 (0x0A000000) returns "10.0.0.0"
38/// - 3232235521 (0xC0A80001) returns "192.168.0.1"
39#[derive(Clone, Debug, Display)]
40#[display("{}", self.name())]
41pub struct Ipv4NumToString {
42    signature: Signature,
43    aliases: [String; 1],
44}
45
46impl Default for Ipv4NumToString {
47    fn default() -> Self {
48        Self {
49            signature: Signature::new(
50                TypeSignature::Coercible(vec![Coercion::new_exact(TypeSignatureClass::Integer)]),
51                Volatility::Immutable,
52            ),
53            aliases: ["inet_ntoa".to_string()],
54        }
55    }
56}
57
58impl Function for Ipv4NumToString {
59    fn name(&self) -> &str {
60        "ipv4_num_to_string"
61    }
62
63    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
64        Ok(DataType::Utf8View)
65    }
66
67    fn signature(&self) -> &Signature {
68        &self.signature
69    }
70
71    fn invoke_with_args(
72        &self,
73        args: ScalarFunctionArgs,
74    ) -> datafusion_common::Result<ColumnarValue> {
75        let [arg0] = extract_args(self.name(), &args)?;
76        let arg0 = compute::cast_with_options(
77            &arg0,
78            &DataType::UInt32,
79            &compute::CastOptions {
80                safe: false,
81                ..Default::default()
82            },
83        )?;
84        let uint_vec = arg0.as_primitive::<UInt32Type>();
85
86        let size = uint_vec.len();
87        let mut builder = StringViewBuilder::with_capacity(size);
88
89        for i in 0..size {
90            let ip_num = uint_vec.is_valid(i).then(|| uint_vec.value(i));
91            let ip_str = match ip_num {
92                Some(num) => {
93                    // Convert UInt32 to IPv4 string (A.B.C.D format)
94                    let a = (num >> 24) & 0xFF;
95                    let b = (num >> 16) & 0xFF;
96                    let c = (num >> 8) & 0xFF;
97                    let d = num & 0xFF;
98                    Some(format!("{}.{}.{}.{}", a, b, c, d))
99                }
100                _ => None,
101            };
102
103            builder.append_option(ip_str.as_deref());
104        }
105
106        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
107    }
108
109    fn aliases(&self) -> &[String] {
110        &self.aliases
111    }
112}
113
114/// Function that converts a string representation of an IPv4 address to a UInt32 number.
115///
116/// For example:
117/// - "10.0.0.1" returns 167772161
118/// - "192.168.0.1" returns 3232235521
119/// - Invalid IPv4 format throws an exception
120#[derive(Clone, Debug, Display)]
121#[display("{}", self.name())]
122pub(crate) struct Ipv4StringToNum {
123    signature: Signature,
124}
125
126impl Default for Ipv4StringToNum {
127    fn default() -> Self {
128        Self {
129            signature: Signature::string(1, Volatility::Immutable),
130        }
131    }
132}
133
134impl Function for Ipv4StringToNum {
135    fn name(&self) -> &str {
136        "ipv4_string_to_num"
137    }
138
139    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
140        Ok(DataType::UInt32)
141    }
142
143    fn signature(&self) -> &Signature {
144        &self.signature
145    }
146
147    fn invoke_with_args(
148        &self,
149        args: ScalarFunctionArgs,
150    ) -> datafusion_common::Result<ColumnarValue> {
151        let [arg0] = extract_args(self.name(), &args)?;
152
153        let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
154        let ip_vec = arg0.as_string_view();
155        let size = ip_vec.len();
156        let mut builder = UInt32Builder::with_capacity(size);
157
158        for i in 0..size {
159            let ip_str = ip_vec.is_valid(i).then(|| ip_vec.value(i));
160            let ip_num = match ip_str {
161                Some(ip_str) => {
162                    let ip_addr = Ipv4Addr::from_str(ip_str).map_err(|_| {
163                        InvalidFuncArgsSnafu {
164                            err_msg: format!("Invalid IPv4 address format: {}", ip_str),
165                        }
166                        .build()
167                    })?;
168                    Some(u32::from(ip_addr))
169                }
170                _ => None,
171            };
172
173            builder.append_option(ip_num);
174        }
175
176        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
177    }
178}
179
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{Int64Array, StringViewArray, UInt32Array};

    use super::*;

    /// Textual addresses shared by the tests; `NUMBERS[i]` is the value of `ADDRESSES[i]`.
    const ADDRESSES: [&str; 4] = ["10.0.0.1", "192.168.0.1", "0.0.0.0", "255.255.255.255"];
    /// Numeric counterparts of `ADDRESSES`.
    const NUMBERS: [u32; 4] = [167772161, 3232235521, 0, 4294967295];

    /// Invokes `func` with `input` as its only argument, declaring `return_type`
    /// as the type of the (non-nullable) return field.
    fn invoke_unary(
        func: &impl Function,
        input: ColumnarValue,
        number_rows: usize,
        return_type: DataType,
    ) -> datafusion_common::Result<ColumnarValue> {
        func.invoke_with_args(ScalarFunctionArgs {
            args: vec![input],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", return_type, false)),
            config_options: Arc::new(Default::default()),
        })
    }

    /// Wraps the shared textual addresses into a string-view array argument.
    fn address_input() -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringViewArray::from_iter_values(ADDRESSES)))
    }

    #[test]
    fn test_ipv4_num_to_string() {
        let func = Ipv4NumToString::default();
        let input = ColumnarValue::Array(Arc::new(UInt32Array::from(NUMBERS.to_vec())));

        let output = invoke_unary(&func, input, 4, DataType::Utf8View).unwrap();
        let output = output.to_array(4).unwrap();
        let output = output.as_string_view();

        for (row, expected) in ADDRESSES.iter().enumerate() {
            assert_eq!(output.value(row), *expected);
        }
    }

    #[test]
    fn test_ipv4_num_to_string_accepts_int64() {
        let func = Ipv4NumToString::default();
        let widened: Vec<i64> = NUMBERS.iter().map(|&n| i64::from(n)).collect();
        let input = ColumnarValue::Array(Arc::new(Int64Array::from(widened)));

        let output = invoke_unary(&func, input, 4, DataType::Utf8View).unwrap();
        let output = output.to_array(4).unwrap();
        let output = output.as_string_view();

        for (row, expected) in ADDRESSES.iter().enumerate() {
            assert_eq!(output.value(row), *expected);
        }
    }

    #[test]
    fn test_ipv4_num_to_string_rejects_negative_int64() {
        let func = Ipv4NumToString::default();
        let input = ColumnarValue::Array(Arc::new(Int64Array::from(vec![-1i64])));

        let outcome = invoke_unary(&func, input, 1, DataType::Utf8View);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_ipv4_string_to_num() {
        let func = Ipv4StringToNum::default();

        let output = invoke_unary(&func, address_input(), 4, DataType::UInt32).unwrap();
        let output = output.to_array(4).unwrap();
        let output = output.as_primitive::<UInt32Type>();

        for (row, expected) in NUMBERS.iter().enumerate() {
            assert_eq!(output.value(row), *expected);
        }
    }

    #[test]
    fn test_ipv4_conversions_roundtrip() {
        let to_num = Ipv4StringToNum::default();
        let to_string = Ipv4NumToString::default();

        // string -> number -> string must reproduce the original addresses
        let numbers = invoke_unary(&to_num, address_input(), 4, DataType::UInt32).unwrap();
        let strings = invoke_unary(&to_string, numbers, 4, DataType::Utf8View).unwrap();
        let strings = strings.to_array(4).unwrap();
        let strings = strings.as_string_view();

        for (row, expected) in ADDRESSES.iter().enumerate() {
            assert_eq!(strings.value(row), *expected);
        }
    }
}