1use arrow::datatypes::DataType as ArrowDataType;
16use common_base::bytes::Bytes;
17use serde::{Deserialize, Serialize};
18
19use crate::data_type::DataType;
20use crate::error::{InvalidVectorSnafu, Result};
21use crate::scalars::ScalarVectorBuilder;
22use crate::type_id::LogicalTypeId;
23use crate::value::Value;
24use crate::vectors::{BinaryVectorBuilder, MutableVector};
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
30pub struct VectorType {
31 pub dim: u32,
32}
33
34impl VectorType {
35 pub fn new(dim: u32) -> Self {
36 Self { dim }
37 }
38}
39
40impl DataType for VectorType {
41 fn name(&self) -> String {
42 format!("Vector({})", self.dim)
43 }
44
45 fn logical_type_id(&self) -> LogicalTypeId {
46 LogicalTypeId::Vector
47 }
48
49 fn default_value(&self) -> Value {
50 Bytes::default().into()
51 }
52
53 fn as_arrow_type(&self) -> ArrowDataType {
54 ArrowDataType::Binary
55 }
56
57 fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
58 Box::new(BinaryVectorBuilder::with_capacity(capacity))
59 }
60
61 fn try_cast(&self, from: Value) -> Option<Value> {
62 match from {
63 Value::Binary(v) => Some(Value::Binary(v)),
64 _ => None,
65 }
66 }
67}
68
69pub fn vector_type_value_to_string(val: &[u8], dim: u32) -> Result<String> {
72 let expected_len = dim as usize * std::mem::size_of::<f32>();
73 if val.len() != expected_len {
74 return InvalidVectorSnafu {
75 msg: format!(
76 "Failed to convert Vector value to string: wrong byte size, expected {}, got {}",
77 expected_len,
78 val.len()
79 ),
80 }
81 .fail();
82 }
83
84 if dim == 0 {
85 return Ok("[]".to_string());
86 }
87
88 let elements = val
89 .chunks_exact(std::mem::size_of::<f32>())
90 .map(|e| f32::from_le_bytes(e.try_into().unwrap()));
91
92 let mut s = String::from("[");
93 for (i, e) in elements.enumerate() {
94 if i > 0 {
95 s.push(',');
96 }
97 s.push_str(&e.to_string());
98 }
99 s.push(']');
100 Ok(s)
101}
102
103pub fn parse_string_to_vector_type_value(s: &str, dim: Option<u32>) -> Result<Vec<u8>> {
106 let trimmed = s.trim();
108 if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
109 return InvalidVectorSnafu {
110 msg: format!("Failed to parse {s} to Vector value: not properly enclosed in brackets"),
111 }
112 .fail();
113 }
114 let content = trimmed[1..trimmed.len() - 1].trim();
116
117 if content.is_empty() {
118 if dim.is_some_and(|d| d != 0) {
119 return InvalidVectorSnafu {
120 msg: format!("Failed to parse {s} to Vector value: wrong dimension"),
121 }
122 .fail();
123 }
124 return Ok(vec![]);
125 }
126
127 let elements = content
128 .split(',')
129 .map(|e| {
130 e.trim().parse::<f32>().map_err(|_| {
131 InvalidVectorSnafu {
132 msg: format!(
133 "Failed to parse {s} to Vector value: elements are not all float32"
134 ),
135 }
136 .build()
137 })
138 })
139 .collect::<Result<Vec<f32>>>()?;
140
141 if dim.is_some_and(|d| d as usize != elements.len()) {
143 return InvalidVectorSnafu {
144 msg: format!("Failed to parse {s} to Vector value: wrong dimension"),
145 }
146 .fail();
147 }
148
149 let bytes = if cfg!(target_endian = "little") {
151 unsafe {
152 std::slice::from_raw_parts(
153 elements.as_ptr() as *const u8,
154 elements.len() * std::mem::size_of::<f32>(),
155 )
156 .to_vec()
157 }
158 } else {
159 elements
160 .iter()
161 .flat_map(|e| e.to_le_bytes())
162 .collect::<Vec<u8>>()
163 };
164
165 Ok(bytes)
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `values` the way vector values are expected to be laid out:
    /// little-endian `f32`s, concatenated.
    fn le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    #[test]
    fn test_conversion_between_string_and_vector_type_value() {
        let dim = 3;

        let cases = [
            ("[1.0,2.0,3]", "[1,2,3]"),
            ("[0.0 , 0.0 , 0.0]", "[0,0,0]"),
            ("[3.4028235e38, -3.4028235e38, 1.1754944e-38]", "[340282350000000000000000000000000000000,-340282350000000000000000000000000000000,0.000000000000000000000000000000000000011754944]"),
        ];

        for (s, expected) in cases.iter() {
            let val = parse_string_to_vector_type_value(s, Some(dim)).unwrap();
            let s = vector_type_value_to_string(&val, dim).unwrap();
            assert_eq!(s, *expected);
        }

        // Each case exercises a different amount/kind of inner and outer whitespace.
        let dim = 0;
        let cases = [
            ("[]", "[]"),
            ("[ ]", "[]"),
            ("[  ]", "[]"),
            (" [\t] ", "[]"),
        ];
        for (s, expected) in cases.iter() {
            let val = parse_string_to_vector_type_value(s, Some(dim)).unwrap();
            let s = vector_type_value_to_string(&val, dim).unwrap();
            assert_eq!(s, *expected);
        }
    }

    #[test]
    fn test_parse_string_to_vector_type_value_byte_layout() {
        let val = parse_string_to_vector_type_value("[1.0, -2.5, 3]", Some(3)).unwrap();
        assert_eq!(val, le_bytes(&[1.0, -2.5, 3.0]));
    }

    #[test]
    fn test_parse_string_to_vector_type_value_without_dim() {
        let val = parse_string_to_vector_type_value("[1, 2, 3, 4]", None).unwrap();
        assert_eq!(val, le_bytes(&[1.0, 2.0, 3.0, 4.0]));
        assert_eq!(vector_type_value_to_string(&val, 4).unwrap(), "[1,2,3,4]");

        let val = parse_string_to_vector_type_value("[]", None).unwrap();
        assert!(val.is_empty());
    }

    #[test]
    fn test_vector_type_value_to_string_wrong_byte_size() {
        let dim = 3;
        let val = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let res = vector_type_value_to_string(&val, dim);
        assert!(res.is_err());

        let dim = 0;
        let val = vec![1];
        let res = vector_type_value_to_string(&val, dim);
        assert!(res.is_err());
    }

    #[test]
    fn test_parse_string_to_vector_type_value_not_properly_enclosed_in_brackets() {
        let dim = 3;
        let s = "1.0,2.0,3.0";
        let res = parse_string_to_vector_type_value(s, Some(dim));
        assert!(res.is_err());

        let s = "[1.0,2.0,3.0";
        let res = parse_string_to_vector_type_value(s, Some(dim));
        assert!(res.is_err());

        let s = "1.0,2.0,3.0]";
        let res = parse_string_to_vector_type_value(s, Some(dim));
        assert!(res.is_err());
    }

    #[test]
    fn test_parse_string_to_vector_type_value_wrong_dimension() {
        let dim = 3;
        for s in ["[1.0,2.0]", "[1.0,2.0,3.0,4.0]", "[]"] {
            let res = parse_string_to_vector_type_value(s, Some(dim));
            assert!(res.is_err(), "expected a dimension error for {s}");
        }
    }

    #[test]
    fn test_parse_string_to_vector_type_value_elements_are_not_all_float32() {
        let dim = 3;
        for s in ["[1.0,2.0,ah]", "[1.0,,3.0]", "[1.0,2.0,3.0,]"] {
            let res = parse_string_to_vector_type_value(s, Some(dim));
            assert!(res.is_err(), "expected a float32 parse error for {s}");
        }
    }
}