common_function/scalars/vector/convert/
vector_to_string.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16use std::sync::Arc;
17
18use datafusion_common::DataFusionError;
19use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder};
20use datafusion_common::arrow::compute;
21use datafusion_common::arrow::datatypes::DataType;
22use datafusion_common::types::logical_binary;
23use datafusion_expr::{
24    Coercion, ColumnarValue, ScalarFunctionArgs, Signature, TypeSignatureClass, Volatility,
25};
26use datatypes::types::vector_type_value_to_string;
27
28use crate::function::{Function, extract_args};
29
30const NAME: &str = "vec_to_string";
31
32#[derive(Debug, Clone)]
33pub struct VectorToStringFunction {
34    signature: Signature,
35}
36
37impl Default for VectorToStringFunction {
38    fn default() -> Self {
39        Self {
40            signature: Signature::coercible(
41                vec![Coercion::new_exact(TypeSignatureClass::Native(
42                    logical_binary(),
43                ))],
44                Volatility::Immutable,
45            ),
46        }
47    }
48}
49
50impl Function for VectorToStringFunction {
51    fn name(&self) -> &str {
52        NAME
53    }
54
55    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
56        Ok(DataType::Utf8View)
57    }
58
59    fn signature(&self) -> &Signature {
60        &self.signature
61    }
62
63    fn invoke_with_args(
64        &self,
65        args: ScalarFunctionArgs,
66    ) -> datafusion_common::Result<ColumnarValue> {
67        let [arg0] = extract_args(self.name(), &args)?;
68        let arg0 = compute::cast(&arg0, &DataType::BinaryView)?;
69        let column = arg0.as_binary_view();
70
71        let size = column.len();
72
73        let mut builder = StringViewBuilder::with_capacity(size);
74        for i in 0..size {
75            let value = column.is_valid(i).then(|| column.value(i));
76            match value {
77                Some(bytes) => {
78                    let len = bytes.len();
79                    if len % std::mem::size_of::<f32>() != 0 {
80                        return Err(DataFusionError::Execution(format!(
81                            "Invalid binary length of vector: {len}"
82                        )));
83                    }
84
85                    let dim = len / std::mem::size_of::<f32>();
86                    // Safety: `dim` is calculated from the length of `bytes` and is guaranteed to be valid
87                    let result = vector_type_value_to_string(bytes, dim as _).unwrap();
88                    builder.append_value(result);
89                }
90                None => {
91                    builder.append_null();
92                }
93            }
94        }
95
96        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
97    }
98}
99
100impl Display for VectorToStringFunction {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        write!(f, "{}", NAME.to_ascii_uppercase())
103    }
104}
105
#[cfg(test)]
mod tests {
    use arrow_schema::Field;
    use datafusion_common::arrow::array::BinaryViewBuilder;

    use super::*;

    /// Concatenates the little-endian bytes of every element in `elements`.
    fn encode(elements: &[f32]) -> Vec<u8> {
        elements.iter().flat_map(|e| e.to_le_bytes()).collect()
    }

    /// Two valid vectors and one null go in; two strings and one null come out.
    #[test]
    fn test_vector_to_string() {
        let mut input = BinaryViewBuilder::with_capacity(3);
        input.append_value(encode(&[1.0, 2.0, 3.0]));
        input.append_value(encode(&[4.0, 5.0, 6.0]));
        input.append_null();

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(input.finish()))],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("", DataType::Utf8View, false)),
            config_options: Arc::new(Default::default()),
        };

        let func = VectorToStringFunction::default();
        let output = func.invoke_with_args(args).unwrap().to_array(3).unwrap();
        let strings = output.as_string_view();

        assert_eq!(strings.len(), 3);
        assert_eq!(strings.value(0), "[1,2,3]");
        assert_eq!(strings.value(1), "[4,5,6]");
        assert!(strings.is_null(2));
    }
}