common_function/scalars/string/
format.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! MySQL-compatible FORMAT function implementation.
16//!
17//! FORMAT(X, D) - Formats the number X with D decimal places using thousand separators.
18
19use std::fmt;
20use std::sync::Arc;
21
22use datafusion_common::DataFusionError;
23use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
24use datafusion_common::arrow::datatypes as arrow_types;
25use datafusion_common::arrow::datatypes::DataType;
26use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
27
28use crate::function::Function;
29use crate::function_registry::FunctionRegistry;
30
/// Name returned by `FormatFunction::name`; the `Display` impl prints it upper-cased.
const NAME: &str = "format";
32
/// MySQL-compatible FORMAT function.
///
/// Syntax: `FORMAT(X, D)`
///
/// Formats the number `X` to a format like `#,###,###.##`, rounded to `D` decimal places.
/// `D` is clamped to the range `0..=30` before formatting, so out-of-range values are
/// accepted rather than rejected.
///
/// Note: This implementation uses the en_US locale (comma as thousand separator,
/// period as decimal separator).
#[derive(Debug)]
pub struct FormatFunction {
    /// Accepted `(X, D)` argument type combinations.
    signature: Signature,
}
45
46impl FormatFunction {
47    pub fn register(registry: &FunctionRegistry) {
48        registry.register_scalar(FormatFunction::default());
49    }
50}
51
52impl Default for FormatFunction {
53    fn default() -> Self {
54        let mut signatures = Vec::new();
55
56        // Support various numeric types for X
57        let numeric_types = [
58            DataType::Float64,
59            DataType::Float32,
60            DataType::Int64,
61            DataType::Int32,
62            DataType::Int16,
63            DataType::Int8,
64            DataType::UInt64,
65            DataType::UInt32,
66            DataType::UInt16,
67            DataType::UInt8,
68        ];
69
70        // D can be various integer types
71        let int_types = [
72            DataType::Int64,
73            DataType::Int32,
74            DataType::Int16,
75            DataType::Int8,
76            DataType::UInt64,
77            DataType::UInt32,
78            DataType::UInt16,
79            DataType::UInt8,
80        ];
81
82        for x_type in &numeric_types {
83            for d_type in &int_types {
84                signatures.push(TypeSignature::Exact(vec![x_type.clone(), d_type.clone()]));
85            }
86        }
87
88        Self {
89            signature: Signature::one_of(signatures, Volatility::Immutable),
90        }
91    }
92}
93
94impl fmt::Display for FormatFunction {
95    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96        write!(f, "{}", NAME.to_ascii_uppercase())
97    }
98}
99
100impl Function for FormatFunction {
101    fn name(&self) -> &str {
102        NAME
103    }
104
105    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
106        Ok(DataType::LargeUtf8)
107    }
108
109    fn signature(&self) -> &Signature {
110        &self.signature
111    }
112
113    fn invoke_with_args(
114        &self,
115        args: ScalarFunctionArgs,
116    ) -> datafusion_common::Result<ColumnarValue> {
117        if args.args.len() != 2 {
118            return Err(DataFusionError::Execution(
119                "FORMAT requires exactly 2 arguments: FORMAT(X, D)".to_string(),
120            ));
121        }
122
123        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
124        let len = arrays[0].len();
125
126        let x_array = &arrays[0];
127        let d_array = &arrays[1];
128
129        let mut builder = LargeStringBuilder::with_capacity(len, len * 20);
130
131        for i in 0..len {
132            if x_array.is_null(i) || d_array.is_null(i) {
133                builder.append_null();
134                continue;
135            }
136
137            let decimal_places = get_decimal_places(d_array, i)?.clamp(0, 30) as usize;
138
139            let formatted = match x_array.data_type() {
140                DataType::Float64 | DataType::Float32 => {
141                    format_number_float(get_float_value(x_array, i)?, decimal_places)
142                }
143                DataType::Int64
144                | DataType::Int32
145                | DataType::Int16
146                | DataType::Int8
147                | DataType::UInt64
148                | DataType::UInt32
149                | DataType::UInt16
150                | DataType::UInt8 => format_number_integer(x_array, i, decimal_places)?,
151                _ => {
152                    return Err(DataFusionError::Execution(format!(
153                        "FORMAT: unsupported type {:?}",
154                        x_array.data_type()
155                    )));
156                }
157            };
158            builder.append_value(&formatted);
159        }
160
161        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
162    }
163}
164
165/// Get float value from various numeric types.
166fn get_float_value(
167    array: &datafusion_common::arrow::array::ArrayRef,
168    index: usize,
169) -> datafusion_common::Result<f64> {
170    match array.data_type() {
171        DataType::Float64 => Ok(array
172            .as_primitive::<arrow_types::Float64Type>()
173            .value(index)),
174        DataType::Float32 => Ok(array
175            .as_primitive::<arrow_types::Float32Type>()
176            .value(index) as f64),
177        _ => Err(DataFusionError::Execution(format!(
178            "FORMAT: unsupported type {:?}",
179            array.data_type()
180        ))),
181    }
182}
183
184/// Get decimal places from various integer types.
185///
186/// MySQL clamps decimal places to `0..=30`. This function returns an `i64` so the caller can clamp.
187fn get_decimal_places(
188    array: &datafusion_common::arrow::array::ArrayRef,
189    index: usize,
190) -> datafusion_common::Result<i64> {
191    match array.data_type() {
192        DataType::Int64 => Ok(array.as_primitive::<arrow_types::Int64Type>().value(index)),
193        DataType::Int32 => Ok(array.as_primitive::<arrow_types::Int32Type>().value(index) as i64),
194        DataType::Int16 => Ok(array.as_primitive::<arrow_types::Int16Type>().value(index) as i64),
195        DataType::Int8 => Ok(array.as_primitive::<arrow_types::Int8Type>().value(index) as i64),
196        DataType::UInt64 => {
197            let v = array.as_primitive::<arrow_types::UInt64Type>().value(index);
198            Ok(if v > i64::MAX as u64 {
199                i64::MAX
200            } else {
201                v as i64
202            })
203        }
204        DataType::UInt32 => Ok(array.as_primitive::<arrow_types::UInt32Type>().value(index) as i64),
205        DataType::UInt16 => Ok(array.as_primitive::<arrow_types::UInt16Type>().value(index) as i64),
206        DataType::UInt8 => Ok(array.as_primitive::<arrow_types::UInt8Type>().value(index) as i64),
207        _ => Err(DataFusionError::Execution(format!(
208            "FORMAT: unsupported type {:?}",
209            array.data_type()
210        ))),
211    }
212}
213
214fn format_number_integer(
215    array: &datafusion_common::arrow::array::ArrayRef,
216    index: usize,
217    decimal_places: usize,
218) -> datafusion_common::Result<String> {
219    let (is_negative, abs_digits) = match array.data_type() {
220        DataType::Int64 => {
221            let v = array.as_primitive::<arrow_types::Int64Type>().value(index) as i128;
222            (v.is_negative(), v.unsigned_abs().to_string())
223        }
224        DataType::Int32 => {
225            let v = array.as_primitive::<arrow_types::Int32Type>().value(index) as i128;
226            (v.is_negative(), v.unsigned_abs().to_string())
227        }
228        DataType::Int16 => {
229            let v = array.as_primitive::<arrow_types::Int16Type>().value(index) as i128;
230            (v.is_negative(), v.unsigned_abs().to_string())
231        }
232        DataType::Int8 => {
233            let v = array.as_primitive::<arrow_types::Int8Type>().value(index) as i128;
234            (v.is_negative(), v.unsigned_abs().to_string())
235        }
236        DataType::UInt64 => {
237            let v = array.as_primitive::<arrow_types::UInt64Type>().value(index) as u128;
238            (false, v.to_string())
239        }
240        DataType::UInt32 => {
241            let v = array.as_primitive::<arrow_types::UInt32Type>().value(index) as u128;
242            (false, v.to_string())
243        }
244        DataType::UInt16 => {
245            let v = array.as_primitive::<arrow_types::UInt16Type>().value(index) as u128;
246            (false, v.to_string())
247        }
248        DataType::UInt8 => {
249            let v = array.as_primitive::<arrow_types::UInt8Type>().value(index) as u128;
250            (false, v.to_string())
251        }
252        _ => {
253            return Err(DataFusionError::Execution(format!(
254                "FORMAT: unsupported type {:?}",
255                array.data_type()
256            )));
257        }
258    };
259
260    let mut result = String::new();
261    if is_negative {
262        result.push('-');
263    }
264    result.push_str(&add_thousand_separators(&abs_digits));
265
266    if decimal_places > 0 {
267        result.push('.');
268        result.push_str(&"0".repeat(decimal_places));
269    }
270
271    Ok(result)
272}
273
274/// Format a float with thousand separators and `decimal_places` digits after decimal point.
275fn format_number_float(x: f64, decimal_places: usize) -> String {
276    // Handle special cases
277    if x.is_nan() {
278        return "NaN".to_string();
279    }
280    if x.is_infinite() {
281        return if x.is_sign_positive() {
282            "Infinity".to_string()
283        } else {
284            "-Infinity".to_string()
285        };
286    }
287
288    // Round to decimal_places
289    let multiplier = 10f64.powi(decimal_places as i32);
290    let rounded = (x * multiplier).round() / multiplier;
291
292    // Split into integer and fractional parts
293    let is_negative = rounded < 0.0;
294    let abs_value = rounded.abs();
295
296    // Format with the specified decimal places
297    let formatted = if decimal_places == 0 {
298        format!("{:.0}", abs_value)
299    } else {
300        format!("{:.prec$}", abs_value, prec = decimal_places)
301    };
302
303    // Split at decimal point
304    let parts: Vec<&str> = formatted.split('.').collect();
305    let int_part = parts[0];
306    let dec_part = parts.get(1).copied();
307
308    // Add thousand separators to integer part
309    let int_with_sep = add_thousand_separators(int_part);
310
311    // Build result
312    let mut result = String::new();
313    if is_negative {
314        result.push('-');
315    }
316    result.push_str(&int_with_sep);
317    if let Some(dec) = dec_part {
318        result.push('.');
319        result.push_str(dec);
320    }
321
322    result
323}
324
/// Inserts a comma between every group of three characters, counted from the right.
///
/// The input is expected to be a plain run of digits (no sign, no decimal point). Inputs of
/// three characters or fewer are returned unchanged.
fn add_thousand_separators(s: &str) -> String {
    let total = s.chars().count();
    if total <= 3 {
        return s.to_owned();
    }

    let mut grouped = String::with_capacity(total + total / 3);
    for (pos, ch) in s.chars().enumerate() {
        // A comma goes before a character when a whole number of three-character
        // groups remains from here to the end (never before the first character).
        let remaining = total - pos;
        if pos > 0 && remaining % 3 == 0 {
            grouped.push(',');
        }
        grouped.push(ch);
    }

    grouped
}
351
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` into `ScalarFunctionArgs`, declaring one nullable field per array.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let number_rows = arrays[0].len();

        let mut arg_fields = Vec::with_capacity(arrays.len());
        for (i, arr) in arrays.iter().enumerate() {
            arg_fields.push(Arc::new(Field::new(
                format!("arg_{}", i),
                arr.data_type().clone(),
                true,
            )));
        }

        ScalarFunctionArgs {
            args: arrays.into_iter().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Evaluates `FORMAT(x, d)` and returns the resulting array, panicking on a scalar result.
    fn run_format(x: ArrayRef, d: ArrayRef) -> ArrayRef {
        let function = FormatFunction::default();
        match function.invoke_with_args(create_args(vec![x, d])).unwrap() {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_format_basic() {
        let x = Arc::new(Float64Array::from(vec![1234567.891, 1234.5, 1234567.0]));
        let d = Arc::new(Int64Array::from(vec![2, 0, 3]));

        let array = run_format(x, d);
        let str_array = array.as_string::<i64>();

        assert_eq!(str_array.value(0), "1,234,567.89");
        assert_eq!(str_array.value(1), "1,235"); // rounded
        assert_eq!(str_array.value(2), "1,234,567.000");
    }

    #[test]
    fn test_format_negative() {
        let x = Arc::new(Float64Array::from(vec![-1234567.891]));
        let d = Arc::new(Int64Array::from(vec![2]));

        let array = run_format(x, d);
        let str_array = array.as_string::<i64>();

        assert_eq!(str_array.value(0), "-1,234,567.89");
    }

    #[test]
    fn test_format_small_numbers() {
        let x = Arc::new(Float64Array::from(vec![0.5, 12.345, 123.0]));
        let d = Arc::new(Int64Array::from(vec![2, 2, 0]));

        let array = run_format(x, d);
        let str_array = array.as_string::<i64>();

        assert_eq!(str_array.value(0), "0.50");
        assert_eq!(str_array.value(1), "12.35"); // rounded
        assert_eq!(str_array.value(2), "123");
    }

    #[test]
    fn test_format_with_nulls() {
        let x = Arc::new(Float64Array::from(vec![Some(1234.5), None]));
        let d = Arc::new(Int64Array::from(vec![2, 2]));

        let array = run_format(x, d);
        let str_array = array.as_string::<i64>();

        assert_eq!(str_array.value(0), "1,234.50");
        assert!(str_array.is_null(1));
    }

    #[test]
    fn test_add_thousand_separators() {
        let cases = [
            ("1", "1"),
            ("12", "12"),
            ("123", "123"),
            ("1234", "1,234"),
            ("12345", "12,345"),
            ("123456", "123,456"),
            ("1234567", "1,234,567"),
            ("12345678", "12,345,678"),
            ("123456789", "123,456,789"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(add_thousand_separators(input), *expected);
        }
    }

    #[test]
    fn test_format_large_int_no_float_precision_loss() {
        // 2^53 + 1 cannot be represented exactly as f64.
        let x = Arc::new(Int64Array::from(vec![9_007_199_254_740_993i64]));
        let d = Arc::new(Int64Array::from(vec![0]));

        let array = run_format(x, d);
        let str_array = array.as_string::<i64>();

        assert_eq!(str_array.value(0), "9,007,199,254,740,993");
    }

    #[test]
    fn test_format_decimal_places_u64_overflow_clamps() {
        // u64::MAX does not fit in i64; it must saturate and then clamp to 30 places.
        let x = Arc::new(Int64Array::from(vec![1]));
        let d = Arc::new(UInt64Array::from(vec![u64::MAX]));

        let array = run_format(x, d);
        let str_array = array.as_string::<i64>();

        assert_eq!(str_array.value(0), format!("1.{}", "0".repeat(30)));
    }
}