common_function/scalars/vector/convert/
parse_vector.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use common_query::error::{InvalidFuncArgsSnafu, InvalidVectorStringSnafu, Result};
18use common_query::prelude::{Signature, Volatility};
19use datatypes::prelude::ConcreteDataType;
20use datatypes::scalars::ScalarVectorBuilder;
21use datatypes::types::parse_string_to_vector_type_value;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use snafu::{ensure, ResultExt};
24
25use crate::function::{Function, FunctionContext};
26
/// Name of the function: returned by `Function::name` and printed in upper
/// case by the `Display` implementation of `ParseVectorFunction`.
const NAME: &str = "parse_vec";
28
/// Scalar function that takes a single string column and converts each row
/// into the binary vector value produced by
/// `parse_string_to_vector_type_value`.
///
/// The type is stateless; all behavior lives in its `Function` implementation.
#[derive(Debug, Clone, Default)]
pub struct ParseVectorFunction;
31
32impl Function for ParseVectorFunction {
33    fn name(&self) -> &str {
34        NAME
35    }
36
37    fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
38        Ok(ConcreteDataType::binary_datatype())
39    }
40
41    fn signature(&self) -> Signature {
42        Signature::exact(
43            vec![ConcreteDataType::string_datatype()],
44            Volatility::Immutable,
45        )
46    }
47
48    fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
49        ensure!(
50            columns.len() == 1,
51            InvalidFuncArgsSnafu {
52                err_msg: format!(
53                    "The length of the args is not correct, expect exactly one, have: {}",
54                    columns.len()
55                ),
56            }
57        );
58
59        let column = &columns[0];
60        let size = column.len();
61
62        let mut result = BinaryVectorBuilder::with_capacity(size);
63        for i in 0..size {
64            let value = column.get(i).as_string();
65            if let Some(value) = value {
66                let res = parse_string_to_vector_type_value(&value, None)
67                    .context(InvalidVectorStringSnafu { vec_str: &value })?;
68                result.push(Some(&res));
69            } else {
70                result.push_null();
71            }
72        }
73
74        Ok(result.to_vector())
75    }
76}
77
78impl Display for ParseVectorFunction {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        write!(f, "{}", NAME.to_ascii_uppercase())
81    }
82}
83
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_base::bytes::Bytes;
    use datatypes::value::Value;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Runs `parse_vec` over a single string column built from `rows`.
    fn eval_strings(rows: &[Option<&str>]) -> Result<VectorRef> {
        let owned: Vec<Option<String>> = rows.iter().map(|row| row.map(str::to_string)).collect();
        let input = Arc::new(StringVector::from(owned));
        ParseVectorFunction.eval(&FunctionContext::default(), &[input])
    }

    /// Binary value holding the little-endian bytes of each `f32` in `elements`.
    fn expected_binary(elements: &[f32]) -> Value {
        let raw = elements
            .iter()
            .flat_map(|e| e.to_le_bytes())
            .collect::<Vec<u8>>();
        Value::Binary(Bytes::from(raw))
    }

    #[test]
    fn test_parse_vector() {
        let parsed = eval_strings(&[Some("[1.0,2.0,3.0]"), Some("[4.0,5.0,6.0]"), None]).unwrap();

        assert_eq!(parsed.len(), 3);
        assert_eq!(parsed.get(0), expected_binary(&[1.0, 2.0, 3.0]));
        assert_eq!(parsed.get(1), expected_binary(&[4.0, 5.0, 6.0]));
        // A null input row must stay null in the output.
        assert!(parsed.get(2).is_null());
    }

    #[test]
    fn test_parse_vector_error() {
        // Missing closing bracket.
        let outcome = eval_strings(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            Some("[7.0,8.0,9.0"),
        ]);
        assert!(outcome.is_err());

        // Missing opening bracket.
        let outcome = eval_strings(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            Some("7.0,8.0,9.0]"),
        ]);
        assert!(outcome.is_err());

        // Non-numeric element.
        let outcome = eval_strings(&[
            Some("[1.0,2.0,3.0]"),
            Some("[4.0,5.0,6.0]"),
            Some("[7.0,hello,9.0]"),
        ]);
        assert!(outcome.is_err());
    }
}