common_function/scalars/vector/convert/vector_to_string.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16use std::sync::Arc;
17
18use datafusion_common::DataFusionError;
19use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder};
20use datafusion_common::arrow::compute;
21use datafusion_common::arrow::datatypes::DataType;
22use datafusion_expr::type_coercion::aggregates::BINARYS;
23use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
24use datatypes::types::vector_type_value_to_string;
25
26use crate::function::{Function, extract_args};
27
28const NAME: &str = "vec_to_string";
29
30#[derive(Debug, Clone)]
31pub struct VectorToStringFunction {
32    signature: Signature,
33}
34
35impl Default for VectorToStringFunction {
36    fn default() -> Self {
37        Self {
38            signature: Signature::one_of(
39                vec![
40                    TypeSignature::Uniform(1, vec![DataType::BinaryView]),
41                    TypeSignature::Uniform(1, BINARYS.to_vec()),
42                ],
43                Volatility::Immutable,
44            ),
45        }
46    }
47}
48
49impl Function for VectorToStringFunction {
50    fn name(&self) -> &str {
51        NAME
52    }
53
54    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
55        Ok(DataType::Utf8View)
56    }
57
58    fn signature(&self) -> &Signature {
59        &self.signature
60    }
61
62    fn invoke_with_args(
63        &self,
64        args: ScalarFunctionArgs,
65    ) -> datafusion_common::Result<ColumnarValue> {
66        let [arg0] = extract_args(self.name(), &args)?;
67        let arg0 = compute::cast(&arg0, &DataType::BinaryView)?;
68        let column = arg0.as_binary_view();
69
70        let size = column.len();
71
72        let mut builder = StringViewBuilder::with_capacity(size);
73        for i in 0..size {
74            let value = column.is_valid(i).then(|| column.value(i));
75            match value {
76                Some(bytes) => {
77                    let len = bytes.len();
78                    if len % std::mem::size_of::<f32>() != 0 {
79                        return Err(DataFusionError::Execution(format!(
80                            "Invalid binary length of vector: {len}"
81                        )));
82                    }
83
84                    let dim = len / std::mem::size_of::<f32>();
85                    // Safety: `dim` is calculated from the length of `bytes` and is guaranteed to be valid
86                    let result = vector_type_value_to_string(bytes, dim as _).unwrap();
87                    builder.append_value(result);
88                }
89                None => {
90                    builder.append_null();
91                }
92            }
93        }
94
95        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
96    }
97}
98
99impl Display for VectorToStringFunction {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        write!(f, "{}", NAME.to_ascii_uppercase())
102    }
103}
104
#[cfg(test)]
mod tests {
    use arrow_schema::Field;
    use datafusion_common::arrow::array::BinaryViewBuilder;

    use super::*;

    /// Serializes `values` as consecutive little-endian `f32` bytes.
    fn encode(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Wraps a single array argument into `ScalarFunctionArgs` covering `rows` rows.
    fn make_args(array: Arc<dyn Array>, rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(array)],
            arg_fields: vec![],
            number_rows: rows,
            // The result may contain nulls, so the declared return field is nullable.
            return_field: Arc::new(Field::new("", DataType::Utf8View, true)),
            config_options: Arc::new(Default::default()),
        }
    }

    #[test]
    fn test_vector_to_string() {
        let func = VectorToStringFunction::default();

        let mut builder = BinaryViewBuilder::with_capacity(3);
        builder.append_value(encode(&[1.0, 2.0, 3.0]));
        builder.append_value(encode(&[4.0, 5.0, 6.0]));
        builder.append_null();

        let result = func
            .invoke_with_args(make_args(Arc::new(builder.finish()), 3))
            .and_then(|x| x.to_array(3))
            .unwrap();
        let result = result.as_string_view();

        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), "[1,2,3]");
        assert_eq!(result.value(1), "[4,5,6]");
        assert!(result.is_null(2));
    }

    #[test]
    fn test_vector_to_string_invalid_length() {
        let func = VectorToStringFunction::default();

        // 5 bytes is not a whole number of f32 elements.
        let mut builder = BinaryViewBuilder::with_capacity(1);
        builder.append_value([0u8; 5]);

        let err = func
            .invoke_with_args(make_args(Arc::new(builder.finish()), 1))
            .unwrap_err();
        assert!(
            matches!(err, DataFusionError::Execution(_)),
            "unexpected error: {err}"
        );
    }
}
151}