common_function/scalars/vector/convert/parse_vector.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Display;
16use std::sync::Arc;
17
18use common_query::error::InvalidVectorStringSnafu;
19use datafusion_common::arrow::array::{Array, AsArray, BinaryViewBuilder};
20use datafusion_common::arrow::compute;
21use datafusion_common::arrow::datatypes::DataType;
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23use datatypes::types::parse_string_to_vector_type_value;
24use snafu::ResultExt;
25
26use crate::function::{Function, extract_args};
27
/// Function name returned by `ParseVectorFunction::name`; its upper-cased
/// form is what the `Display` impl prints.
const NAME: &str = "parse_vec";
29
30#[derive(Debug, Clone)]
31pub struct ParseVectorFunction {
32    signature: Signature,
33}
34
35impl Default for ParseVectorFunction {
36    fn default() -> Self {
37        Self {
38            signature: Signature::string(1, Volatility::Immutable),
39        }
40    }
41}
42
43impl Function for ParseVectorFunction {
44    fn name(&self) -> &str {
45        NAME
46    }
47
48    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
49        Ok(DataType::BinaryView)
50    }
51
52    fn signature(&self) -> &Signature {
53        &self.signature
54    }
55
56    fn invoke_with_args(
57        &self,
58        args: ScalarFunctionArgs,
59    ) -> datafusion_common::Result<ColumnarValue> {
60        let [arg0] = extract_args(self.name(), &args)?;
61        let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
62        let column = arg0.as_string_view();
63
64        let size = column.len();
65
66        let mut builder = BinaryViewBuilder::with_capacity(size);
67        for i in 0..size {
68            let value = column.is_valid(i).then(|| column.value(i));
69            if let Some(value) = value {
70                let result = parse_string_to_vector_type_value(value, None)
71                    .context(InvalidVectorStringSnafu { vec_str: value })?;
72                builder.append_value(result);
73            } else {
74                builder.append_null();
75            }
76        }
77
78        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
79    }
80}
81
82impl Display for ParseVectorFunction {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        write!(f, "{}", NAME.to_ascii_uppercase())
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion_common::arrow::array::{StringArray, StringViewArray};

    use super::*;

    /// Wraps a single input column into the arguments `invoke_with_args`
    /// expects, sizing `number_rows` from the column itself.
    ///
    /// The return field is nullable because `parse_vec` maps null inputs to
    /// null outputs.
    fn build_args(input: Arc<dyn Array>) -> ScalarFunctionArgs {
        let number_rows = input.len();
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(input)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("", DataType::BinaryView, true)),
            config_options: Arc::new(Default::default()),
        }
    }

    /// Expected encoding of a parsed vector: consecutive little-endian `f32`s.
    fn f32_le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    #[test]
    fn test_parse_vector() {
        let func = ParseVectorFunction::default();

        let input = StringViewArray::from_iter([
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ]);
        let rows = input.len();

        let result = func
            .invoke_with_args(build_args(Arc::new(input)))
            .and_then(|x| x.to_array(rows))
            .unwrap();
        let result = result.as_binary_view();

        assert_eq!(result.len(), 3);
        assert_eq!(result.value(0), f32_le_bytes(&[1.0, 2.0, 3.0]).as_slice());
        assert_eq!(result.value(1), f32_le_bytes(&[4.0, 5.0, 6.0]).as_slice());
        assert!(result.is_null(2));
    }

    /// Plain `Utf8` input must go through the `Utf8View` cast and still parse.
    #[test]
    fn test_parse_vector_from_utf8() {
        let func = ParseVectorFunction::default();

        let input = StringArray::from(vec![Some("[1.5,2.5]"), None]);
        let rows = input.len();

        let result = func
            .invoke_with_args(build_args(Arc::new(input)))
            .and_then(|x| x.to_array(rows))
            .unwrap();
        let result = result.as_binary_view();

        assert_eq!(result.len(), 2);
        assert_eq!(result.value(0), f32_le_bytes(&[1.5, 2.5]).as_slice());
        assert!(result.is_null(1));
    }

    #[test]
    fn test_parse_vector_error() {
        let func = ParseVectorFunction::default();

        let inputs = [
            StringViewArray::from_iter([
                Some("[1.0,2.0,3.0]".to_string()),
                Some("[4.0,5.0,6.0]".to_string()),
                Some("[7.0,8.0,9.0".to_string()),
            ]),
            StringViewArray::from_iter([
                Some("[1.0,2.0,3.0]".to_string()),
                Some("[4.0,5.0,6.0]".to_string()),
                Some("7.0,8.0,9.0]".to_string()),
            ]),
            StringViewArray::from_iter([
                Some("[1.0,2.0,3.0]".to_string()),
                Some("[4.0,5.0,6.0]".to_string()),
                Some("[7.0,hello,9.0]".to_string()),
            ]),
        ];
        let expected = [
            "External error: Invalid vector string: [7.0,8.0,9.0",
            "External error: Invalid vector string: 7.0,8.0,9.0]",
            "External error: Invalid vector string: [7.0,hello,9.0]",
        ];

        for (input, expected) in inputs.into_iter().zip(expected) {
            let result = func.invoke_with_args(build_args(Arc::new(input)));
            assert_eq!(result.unwrap_err().to_string(), expected);
        }
    }
}