common_function/scalars/vector/convert/
parse_vector.rs1use std::fmt::Display;
16use std::sync::Arc;
17
18use common_query::error::InvalidVectorStringSnafu;
19use datafusion_common::arrow::array::{Array, AsArray, BinaryViewBuilder};
20use datafusion_common::arrow::compute;
21use datafusion_common::arrow::datatypes::DataType;
22use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
23use datatypes::types::parse_string_to_vector_type_value;
24use snafu::ResultExt;
25
26use crate::function::{Function, extract_args};
27
/// Name returned by `Function::name` for this function; the `Display` impl
/// prints the same string upper-cased.
const NAME: &str = "parse_vec";
29
/// Scalar function `parse_vec`: parses each input string (for example
/// `"[1.0,2.0,3.0]"`) with `parse_string_to_vector_type_value` and returns the
/// resulting bytes as a `BinaryView` column. Null inputs produce null outputs.
#[derive(Debug, Clone)]
pub struct ParseVectorFunction {
    /// Argument signature handed back by `Function::signature`.
    signature: Signature,
}
34
35impl Default for ParseVectorFunction {
36 fn default() -> Self {
37 Self {
38 signature: Signature::string(1, Volatility::Immutable),
39 }
40 }
41}
42
43impl Function for ParseVectorFunction {
44 fn name(&self) -> &str {
45 NAME
46 }
47
48 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
49 Ok(DataType::BinaryView)
50 }
51
52 fn signature(&self) -> &Signature {
53 &self.signature
54 }
55
56 fn invoke_with_args(
57 &self,
58 args: ScalarFunctionArgs,
59 ) -> datafusion_common::Result<ColumnarValue> {
60 let [arg0] = extract_args(self.name(), &args)?;
61 let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
62 let column = arg0.as_string_view();
63
64 let size = column.len();
65
66 let mut builder = BinaryViewBuilder::with_capacity(size);
67 for i in 0..size {
68 let value = column.is_valid(i).then(|| column.value(i));
69 if let Some(value) = value {
70 let result = parse_string_to_vector_type_value(value, None)
71 .context(InvalidVectorStringSnafu { vec_str: value })?;
72 builder.append_value(result);
73 } else {
74 builder.append_null();
75 }
76 }
77
78 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
79 }
80}
81
82impl Display for ParseVectorFunction {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}", NAME.to_ascii_uppercase())
85 }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use common_base::bytes::Bytes;
    use datafusion_common::arrow::array::StringViewArray;

    use super::*;

    /// Wraps a three-row string column into the arguments expected by
    /// `invoke_with_args`.
    fn args_for(input: StringViewArray) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(input))],
            arg_fields: vec![],
            number_rows: 3,
            return_field: Arc::new(Field::new("", DataType::BinaryView, false)),
            config_options: Arc::new(Default::default()),
        }
    }

    /// Concatenates the little-endian bytes of each `f32` in `values`.
    fn f32_le_bytes(values: &[f32]) -> Vec<u8> {
        values
            .iter()
            .flat_map(|e| e.to_le_bytes())
            .collect::<Vec<u8>>()
    }

    /// Valid inputs are converted to little-endian `f32` bytes; nulls stay null.
    #[test]
    fn test_parse_vector() {
        let func = ParseVectorFunction::default();

        let input = StringViewArray::from_iter([
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            None,
        ]);

        let output = func
            .invoke_with_args(args_for(input))
            .and_then(|x| x.to_array(3))
            .unwrap();
        let output = output.as_binary_view();

        assert_eq!(output.len(), 3);
        assert_eq!(
            output.value(0),
            &Bytes::from(f32_le_bytes(&[1.0f32, 2.0, 3.0]))
        );
        assert_eq!(
            output.value(1),
            &Bytes::from(f32_le_bytes(&[4.0f32, 5.0, 6.0]))
        );
        assert!(output.is_null(2));
    }

    /// A malformed third row makes the whole invocation fail with a message
    /// naming the offending string.
    #[test]
    fn test_parse_vector_error() {
        let func = ParseVectorFunction::default();

        // (malformed third row, expected error message)
        let cases = [
            (
                "[7.0,8.0,9.0",
                "External error: Invalid vector string: [7.0,8.0,9.0",
            ),
            (
                "7.0,8.0,9.0]",
                "External error: Invalid vector string: 7.0,8.0,9.0]",
            ),
            (
                "[7.0,hello,9.0]",
                "External error: Invalid vector string: [7.0,hello,9.0]",
            ),
        ];

        for (malformed, expected) in cases {
            let input = StringViewArray::from_iter([
                Some("[1.0,2.0,3.0]".to_string()),
                Some("[4.0,5.0,6.0]".to_string()),
                Some(malformed.to_string()),
            ]);
            let outcome = func.invoke_with_args(args_for(input));
            assert_eq!(outcome.unwrap_err().to_string(), expected);
        }
    }
}