common_function/scalars/vector/convert/
parse_vector.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use common_query::error::{InvalidFuncArgsSnafu, InvalidVectorStringSnafu, Result};
18use datafusion::arrow::datatypes::DataType;
19use datafusion_expr::{Signature, Volatility};
20use datatypes::scalars::ScalarVectorBuilder;
21use datatypes::types::parse_string_to_vector_type_value;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use snafu::{ResultExt, ensure};
24
25use crate::function::{Function, FunctionContext};
26
/// Name of the function as returned by `Function::name`; the `Display` impl
/// prints it in ASCII upper case.
const NAME: &str = "parse_vec";
28
/// Scalar function that parses the textual form of a vector (e.g. `[1.0,2.0,3.0]`)
/// into a binary value.
///
/// The tests in this file show the binary output for `[1.0,2.0,3.0]` being the
/// concatenated little-endian bytes of the `f32` elements.
#[derive(Debug, Clone, Default)]
pub struct ParseVectorFunction;
31
32impl Function for ParseVectorFunction {
33    fn name(&self) -> &str {
34        NAME
35    }
36
37    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
38        Ok(DataType::Binary)
39    }
40
41    fn signature(&self) -> Signature {
42        Signature::string(1, Volatility::Immutable)
43    }
44
45    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
46        ensure!(
47            columns.len() == 1,
48            InvalidFuncArgsSnafu {
49                err_msg: format!(
50                    "The length of the args is not correct, expect exactly one, have: {}",
51                    columns.len()
52                ),
53            }
54        );
55
56        let column = &columns[0];
57        let size = column.len();
58
59        let mut result = BinaryVectorBuilder::with_capacity(size);
60        for i in 0..size {
61            let value = column.get(i).as_string();
62            if let Some(value) = value {
63                let res = parse_string_to_vector_type_value(&value, None)
64                    .context(InvalidVectorStringSnafu { vec_str: &value })?;
65                result.push(Some(&res));
66            } else {
67                result.push_null();
68            }
69        }
70
71        Ok(result.to_vector())
72    }
73}
74
75impl Display for ParseVectorFunction {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        write!(f, "{}", NAME.to_ascii_uppercase())
78    }
79}
80
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_base::bytes::Bytes;
    use datatypes::value::Value;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Builds a string column from optional string slices.
    fn string_column(rows: &[Option<&str>]) -> VectorRef {
        let owned: Vec<Option<String>> = rows
            .iter()
            .map(|row| row.map(|text| text.to_string()))
            .collect();
        Arc::new(StringVector::from(owned))
    }

    /// Expected binary value for `floats`: each `f32` as little-endian bytes, concatenated.
    fn expected_binary(floats: &[f32]) -> Value {
        let mut raw: Vec<u8> = Vec::with_capacity(floats.len() * 4);
        for float in floats {
            raw.extend_from_slice(&float.to_le_bytes());
        }
        Value::Binary(Bytes::from(raw))
    }

    #[test]
    fn test_parse_vector() {
        let input = string_column(&[Some("[1.0,2.0,3.0]"), Some("[4.0,5.0,6.0]"), None]);

        let output = ParseVectorFunction
            .eval(&FunctionContext::default(), &[input])
            .unwrap();

        assert_eq!(output.len(), 3);
        assert_eq!(output.get(0), expected_binary(&[1.0, 2.0, 3.0]));
        assert_eq!(output.get(1), expected_binary(&[4.0, 5.0, 6.0]));
        assert!(output.get(2).is_null());
    }

    #[test]
    fn test_parse_vector_error() {
        // Missing closing bracket, missing opening bracket, non-numeric element.
        let malformed_rows = ["[7.0,8.0,9.0", "7.0,8.0,9.0]", "[7.0,hello,9.0]"];

        for malformed in malformed_rows {
            let input = string_column(&[
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                Some(malformed),
            ]);

            let outcome = ParseVectorFunction.eval(&FunctionContext::default(), &[input]);
            assert!(outcome.is_err(), "expected an error for {malformed:?}");
        }
    }
}