common_function/scalars/vector/convert/
parse_vector.rs1use std::fmt::Display;
16
17use common_query::error::{InvalidFuncArgsSnafu, InvalidVectorStringSnafu, Result};
18use datafusion::arrow::datatypes::DataType;
19use datafusion_expr::{Signature, Volatility};
20use datatypes::scalars::ScalarVectorBuilder;
21use datatypes::types::parse_string_to_vector_type_value;
22use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
23use snafu::{ResultExt, ensure};
24
25use crate::function::{Function, FunctionContext};
26
/// Name under which the function is exposed; also the basis of its `Display` output.
const NAME: &str = "parse_vec";

/// Stateless marker type for the `parse_vec` scalar function.
///
/// The type carries no data; all behavior lives in its trait implementations.
#[derive(Clone, Debug, Default)]
pub struct ParseVectorFunction;
31
32impl Function for ParseVectorFunction {
33 fn name(&self) -> &str {
34 NAME
35 }
36
37 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
38 Ok(DataType::Binary)
39 }
40
41 fn signature(&self) -> Signature {
42 Signature::string(1, Volatility::Immutable)
43 }
44
45 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
46 ensure!(
47 columns.len() == 1,
48 InvalidFuncArgsSnafu {
49 err_msg: format!(
50 "The length of the args is not correct, expect exactly one, have: {}",
51 columns.len()
52 ),
53 }
54 );
55
56 let column = &columns[0];
57 let size = column.len();
58
59 let mut result = BinaryVectorBuilder::with_capacity(size);
60 for i in 0..size {
61 let value = column.get(i).as_string();
62 if let Some(value) = value {
63 let res = parse_string_to_vector_type_value(&value, None)
64 .context(InvalidVectorStringSnafu { vec_str: &value })?;
65 result.push(Some(&res));
66 } else {
67 result.push_null();
68 }
69 }
70
71 Ok(result.to_vector())
72 }
73}
74
75impl Display for ParseVectorFunction {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 write!(f, "{}", NAME.to_ascii_uppercase())
78 }
79}
80
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_base::bytes::Bytes;
    use datatypes::value::Value;
    use datatypes::vectors::StringVector;

    use super::*;

    /// Builds a nullable string column to pass to `eval`.
    fn string_column(rows: &[Option<&str>]) -> VectorRef {
        let owned: Vec<Option<String>> = rows.iter().map(|cell| cell.map(String::from)).collect();
        Arc::new(StringVector::from(owned))
    }

    /// Builds the expected binary cell: the little-endian bytes of each `f32`, concatenated.
    fn f32_le_binary(values: &[f32]) -> Value {
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        Value::Binary(Bytes::from(bytes))
    }

    /// Well-formed rows are converted to binary and a null row stays null.
    #[test]
    fn test_parse_vector() {
        let func = ParseVectorFunction;
        let input = string_column(&[Some("[1.0,2.0,3.0]"), Some("[4.0,5.0,6.0]"), None]);

        let output = func.eval(&FunctionContext::default(), &[input]).unwrap();

        assert_eq!(output.len(), 3);
        assert_eq!(output.get(0), f32_le_binary(&[1.0, 2.0, 3.0]));
        assert_eq!(output.get(1), f32_le_binary(&[4.0, 5.0, 6.0]));
        assert!(output.get(2).is_null());
    }

    /// A single malformed row (missing bracket or non-numeric element) fails the whole call.
    #[test]
    fn test_parse_vector_error() {
        let func = ParseVectorFunction;
        let malformed_rows = ["[7.0,8.0,9.0", "7.0,8.0,9.0]", "[7.0,hello,9.0]"];

        for bad in malformed_rows {
            let input = string_column(&[Some("[1.0,2.0,3.0]"), Some("[4.0,5.0,6.0]"), Some(bad)]);
            assert!(func.eval(&FunctionContext::default(), &[input]).is_err());
        }
    }
}