common_function/scalars/string/
format.rs1use std::fmt;
20use std::sync::Arc;
21
22use datafusion_common::DataFusionError;
23use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
24use datafusion_common::arrow::datatypes as arrow_types;
25use datafusion_common::arrow::datatypes::DataType;
26use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
27
28use crate::function::Function;
29use crate::function_registry::FunctionRegistry;
30
/// Function name returned by `Function::name`; the `Display` impl prints it upper-cased.
const NAME: &str = "format";
32
/// Scalar function `FORMAT(X, D)`: renders the number `X` with `,` thousands
/// separators and `D` decimal places, producing a `LargeUtf8` string.
///
/// `X` may be a float or any integer type and `D` any integer type. `D` is clamped
/// to `0..=30` at evaluation time, and a NULL in either argument yields NULL.
#[derive(Debug)]
pub struct FormatFunction {
    /// Accepted `(X, D)` argument type combinations, built in `Default::default`.
    signature: Signature,
}
45
46impl FormatFunction {
47 pub fn register(registry: &FunctionRegistry) {
48 registry.register_scalar(FormatFunction::default());
49 }
50}
51
52impl Default for FormatFunction {
53 fn default() -> Self {
54 let mut signatures = Vec::new();
55
56 let numeric_types = [
58 DataType::Float64,
59 DataType::Float32,
60 DataType::Int64,
61 DataType::Int32,
62 DataType::Int16,
63 DataType::Int8,
64 DataType::UInt64,
65 DataType::UInt32,
66 DataType::UInt16,
67 DataType::UInt8,
68 ];
69
70 let int_types = [
72 DataType::Int64,
73 DataType::Int32,
74 DataType::Int16,
75 DataType::Int8,
76 DataType::UInt64,
77 DataType::UInt32,
78 DataType::UInt16,
79 DataType::UInt8,
80 ];
81
82 for x_type in &numeric_types {
83 for d_type in &int_types {
84 signatures.push(TypeSignature::Exact(vec![x_type.clone(), d_type.clone()]));
85 }
86 }
87
88 Self {
89 signature: Signature::one_of(signatures, Volatility::Immutable),
90 }
91 }
92}
93
94impl fmt::Display for FormatFunction {
95 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96 write!(f, "{}", NAME.to_ascii_uppercase())
97 }
98}
99
100impl Function for FormatFunction {
101 fn name(&self) -> &str {
102 NAME
103 }
104
105 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
106 Ok(DataType::LargeUtf8)
107 }
108
109 fn signature(&self) -> &Signature {
110 &self.signature
111 }
112
113 fn invoke_with_args(
114 &self,
115 args: ScalarFunctionArgs,
116 ) -> datafusion_common::Result<ColumnarValue> {
117 if args.args.len() != 2 {
118 return Err(DataFusionError::Execution(
119 "FORMAT requires exactly 2 arguments: FORMAT(X, D)".to_string(),
120 ));
121 }
122
123 let arrays = ColumnarValue::values_to_arrays(&args.args)?;
124 let len = arrays[0].len();
125
126 let x_array = &arrays[0];
127 let d_array = &arrays[1];
128
129 let mut builder = LargeStringBuilder::with_capacity(len, len * 20);
130
131 for i in 0..len {
132 if x_array.is_null(i) || d_array.is_null(i) {
133 builder.append_null();
134 continue;
135 }
136
137 let decimal_places = get_decimal_places(d_array, i)?.clamp(0, 30) as usize;
138
139 let formatted = match x_array.data_type() {
140 DataType::Float64 | DataType::Float32 => {
141 format_number_float(get_float_value(x_array, i)?, decimal_places)
142 }
143 DataType::Int64
144 | DataType::Int32
145 | DataType::Int16
146 | DataType::Int8
147 | DataType::UInt64
148 | DataType::UInt32
149 | DataType::UInt16
150 | DataType::UInt8 => format_number_integer(x_array, i, decimal_places)?,
151 _ => {
152 return Err(DataFusionError::Execution(format!(
153 "FORMAT: unsupported type {:?}",
154 x_array.data_type()
155 )));
156 }
157 };
158 builder.append_value(&formatted);
159 }
160
161 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
162 }
163}
164
165fn get_float_value(
167 array: &datafusion_common::arrow::array::ArrayRef,
168 index: usize,
169) -> datafusion_common::Result<f64> {
170 match array.data_type() {
171 DataType::Float64 => Ok(array
172 .as_primitive::<arrow_types::Float64Type>()
173 .value(index)),
174 DataType::Float32 => Ok(array
175 .as_primitive::<arrow_types::Float32Type>()
176 .value(index) as f64),
177 _ => Err(DataFusionError::Execution(format!(
178 "FORMAT: unsupported type {:?}",
179 array.data_type()
180 ))),
181 }
182}
183
184fn get_decimal_places(
188 array: &datafusion_common::arrow::array::ArrayRef,
189 index: usize,
190) -> datafusion_common::Result<i64> {
191 match array.data_type() {
192 DataType::Int64 => Ok(array.as_primitive::<arrow_types::Int64Type>().value(index)),
193 DataType::Int32 => Ok(array.as_primitive::<arrow_types::Int32Type>().value(index) as i64),
194 DataType::Int16 => Ok(array.as_primitive::<arrow_types::Int16Type>().value(index) as i64),
195 DataType::Int8 => Ok(array.as_primitive::<arrow_types::Int8Type>().value(index) as i64),
196 DataType::UInt64 => {
197 let v = array.as_primitive::<arrow_types::UInt64Type>().value(index);
198 Ok(if v > i64::MAX as u64 {
199 i64::MAX
200 } else {
201 v as i64
202 })
203 }
204 DataType::UInt32 => Ok(array.as_primitive::<arrow_types::UInt32Type>().value(index) as i64),
205 DataType::UInt16 => Ok(array.as_primitive::<arrow_types::UInt16Type>().value(index) as i64),
206 DataType::UInt8 => Ok(array.as_primitive::<arrow_types::UInt8Type>().value(index) as i64),
207 _ => Err(DataFusionError::Execution(format!(
208 "FORMAT: unsupported type {:?}",
209 array.data_type()
210 ))),
211 }
212}
213
214fn format_number_integer(
215 array: &datafusion_common::arrow::array::ArrayRef,
216 index: usize,
217 decimal_places: usize,
218) -> datafusion_common::Result<String> {
219 let (is_negative, abs_digits) = match array.data_type() {
220 DataType::Int64 => {
221 let v = array.as_primitive::<arrow_types::Int64Type>().value(index) as i128;
222 (v.is_negative(), v.unsigned_abs().to_string())
223 }
224 DataType::Int32 => {
225 let v = array.as_primitive::<arrow_types::Int32Type>().value(index) as i128;
226 (v.is_negative(), v.unsigned_abs().to_string())
227 }
228 DataType::Int16 => {
229 let v = array.as_primitive::<arrow_types::Int16Type>().value(index) as i128;
230 (v.is_negative(), v.unsigned_abs().to_string())
231 }
232 DataType::Int8 => {
233 let v = array.as_primitive::<arrow_types::Int8Type>().value(index) as i128;
234 (v.is_negative(), v.unsigned_abs().to_string())
235 }
236 DataType::UInt64 => {
237 let v = array.as_primitive::<arrow_types::UInt64Type>().value(index) as u128;
238 (false, v.to_string())
239 }
240 DataType::UInt32 => {
241 let v = array.as_primitive::<arrow_types::UInt32Type>().value(index) as u128;
242 (false, v.to_string())
243 }
244 DataType::UInt16 => {
245 let v = array.as_primitive::<arrow_types::UInt16Type>().value(index) as u128;
246 (false, v.to_string())
247 }
248 DataType::UInt8 => {
249 let v = array.as_primitive::<arrow_types::UInt8Type>().value(index) as u128;
250 (false, v.to_string())
251 }
252 _ => {
253 return Err(DataFusionError::Execution(format!(
254 "FORMAT: unsupported type {:?}",
255 array.data_type()
256 )));
257 }
258 };
259
260 let mut result = String::new();
261 if is_negative {
262 result.push('-');
263 }
264 result.push_str(&add_thousand_separators(&abs_digits));
265
266 if decimal_places > 0 {
267 result.push('.');
268 result.push_str(&"0".repeat(decimal_places));
269 }
270
271 Ok(result)
272}
273
274fn format_number_float(x: f64, decimal_places: usize) -> String {
276 if x.is_nan() {
278 return "NaN".to_string();
279 }
280 if x.is_infinite() {
281 return if x.is_sign_positive() {
282 "Infinity".to_string()
283 } else {
284 "-Infinity".to_string()
285 };
286 }
287
288 let multiplier = 10f64.powi(decimal_places as i32);
290 let rounded = (x * multiplier).round() / multiplier;
291
292 let is_negative = rounded < 0.0;
294 let abs_value = rounded.abs();
295
296 let formatted = if decimal_places == 0 {
298 format!("{:.0}", abs_value)
299 } else {
300 format!("{:.prec$}", abs_value, prec = decimal_places)
301 };
302
303 let parts: Vec<&str> = formatted.split('.').collect();
305 let int_part = parts[0];
306 let dec_part = parts.get(1).copied();
307
308 let int_with_sep = add_thousand_separators(int_part);
310
311 let mut result = String::new();
313 if is_negative {
314 result.push('-');
315 }
316 result.push_str(&int_with_sep);
317 if let Some(dec) = dec_part {
318 result.push('.');
319 result.push_str(dec);
320 }
321
322 result
323}
324
/// Inserts a `,` between every group of three characters, counting from the right.
///
/// Intended for a plain run of decimal digits (no sign, no decimal point);
/// inputs of three characters or fewer are returned unchanged.
fn add_thousand_separators(s: &str) -> String {
    let digit_count = s.chars().count();
    if digit_count <= 3 {
        return s.to_string();
    }

    let mut grouped = String::with_capacity(digit_count + digit_count / 3);
    for (pos, ch) in s.chars().enumerate() {
        // A separator goes before any non-leading character that starts a
        // trailing run whose length is a multiple of three.
        let remaining = digit_count - pos;
        if pos != 0 && remaining % 3 == 0 {
            grouped.push(',');
        }
        grouped.push(ch);
    }
    grouped
}
351
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    /// Wraps `arrays` as array arguments with nullable fields named `arg_0`,
    /// `arg_1`, ... and a nullable `LargeUtf8` return field.
    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let number_rows = arrays[0].len();
        let arg_fields: Vec<_> = arrays
            .iter()
            .enumerate()
            .map(|(i, arr)| {
                Arc::new(Field::new(format!("arg_{i}"), arr.data_type().clone(), true))
            })
            .collect();

        ScalarFunctionArgs {
            args: arrays.into_iter().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows,
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    /// Invokes a default `FormatFunction` on `(x, d)` and unwraps the array result.
    fn run_format(x: ArrayRef, d: ArrayRef) -> ArrayRef {
        let result = FormatFunction::default()
            .invoke_with_args(create_args(vec![x, d]))
            .unwrap();
        match result {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_format_basic() {
        let array = run_format(
            Arc::new(Float64Array::from(vec![1234567.891, 1234.5, 1234567.0])),
            Arc::new(Int64Array::from(vec![2, 0, 3])),
        );
        let strings = array.as_string::<i64>();
        assert_eq!(strings.value(0), "1,234,567.89");
        assert_eq!(strings.value(1), "1,235");
        assert_eq!(strings.value(2), "1,234,567.000");
    }

    #[test]
    fn test_format_negative() {
        let array = run_format(
            Arc::new(Float64Array::from(vec![-1234567.891])),
            Arc::new(Int64Array::from(vec![2])),
        );
        let strings = array.as_string::<i64>();
        assert_eq!(strings.value(0), "-1,234,567.89");
    }

    #[test]
    fn test_format_small_numbers() {
        let array = run_format(
            Arc::new(Float64Array::from(vec![0.5, 12.345, 123.0])),
            Arc::new(Int64Array::from(vec![2, 2, 0])),
        );
        let strings = array.as_string::<i64>();
        assert_eq!(strings.value(0), "0.50");
        assert_eq!(strings.value(1), "12.35");
        assert_eq!(strings.value(2), "123");
    }

    #[test]
    fn test_format_with_nulls() {
        let array = run_format(
            Arc::new(Float64Array::from(vec![Some(1234.5), None])),
            Arc::new(Int64Array::from(vec![2, 2])),
        );
        let strings = array.as_string::<i64>();
        assert_eq!(strings.value(0), "1,234.50");
        assert!(strings.is_null(1));
    }

    #[test]
    fn test_add_thousand_separators() {
        let cases = [
            ("1", "1"),
            ("12", "12"),
            ("123", "123"),
            ("1234", "1,234"),
            ("12345", "12,345"),
            ("123456", "123,456"),
            ("1234567", "1,234,567"),
            ("12345678", "12,345,678"),
            ("123456789", "123,456,789"),
        ];
        for (input, expected) in cases {
            assert_eq!(add_thousand_separators(input), expected);
        }
    }

    #[test]
    fn test_format_large_int_no_float_precision_loss() {
        // 2^53 + 1 cannot be represented exactly as an f64.
        let array = run_format(
            Arc::new(Int64Array::from(vec![9_007_199_254_740_993i64])),
            Arc::new(Int64Array::from(vec![0])),
        );
        let strings = array.as_string::<i64>();
        assert_eq!(strings.value(0), "9,007,199,254,740,993");
    }

    #[test]
    fn test_format_decimal_places_u64_overflow_clamps() {
        let array = run_format(
            Arc::new(Int64Array::from(vec![1])),
            Arc::new(UInt64Array::from(vec![u64::MAX])),
        );
        let strings = array.as_string::<i64>();
        assert_eq!(strings.value(0), format!("1.{}", "0".repeat(30)));
    }
}