datatypes/types/
vector_type.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use arrow::datatypes::DataType as ArrowDataType;
16use common_base::bytes::Bytes;
17use serde::{Deserialize, Serialize};
18
19use crate::data_type::DataType;
20use crate::error::{InvalidVectorSnafu, Result};
21use crate::scalars::ScalarVectorBuilder;
22use crate::type_id::LogicalTypeId;
23use crate::value::Value;
24use crate::vectors::{BinaryVectorBuilder, MutableVector};
25
26/// `VectorType` is a data type for vector data with a fixed dimension.
27/// The type of items in the vector is float32.
28/// It is stored as binary data that contains the concatenated float32 values.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
30pub struct VectorType {
31    pub dim: u32,
32}
33
34impl VectorType {
35    pub fn new(dim: u32) -> Self {
36        Self { dim }
37    }
38}
39
40impl DataType for VectorType {
41    fn name(&self) -> String {
42        format!("Vector({})", self.dim)
43    }
44
45    fn logical_type_id(&self) -> LogicalTypeId {
46        LogicalTypeId::Vector
47    }
48
49    fn default_value(&self) -> Value {
50        Bytes::default().into()
51    }
52
53    fn as_arrow_type(&self) -> ArrowDataType {
54        ArrowDataType::Binary
55    }
56
57    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
58        Box::new(BinaryVectorBuilder::with_capacity(capacity))
59    }
60
61    fn try_cast(&self, from: Value) -> Option<Value> {
62        match from {
63            Value::Binary(v) => Some(Value::Binary(v)),
64            _ => None,
65        }
66    }
67}
68
69/// Converts a vector type value to string
70/// for example: [1.0, 2.0, 3.0] -> "[1.0,2.0,3.0]"
71pub fn vector_type_value_to_string(val: &[u8], dim: u32) -> Result<String> {
72    let expected_len = dim as usize * std::mem::size_of::<f32>();
73    if val.len() != expected_len {
74        return InvalidVectorSnafu {
75            msg: format!(
76                "Failed to convert Vector value to string: wrong byte size, expected {}, got {}",
77                expected_len,
78                val.len()
79            ),
80        }
81        .fail();
82    }
83
84    if dim == 0 {
85        return Ok("[]".to_string());
86    }
87
88    let elements = val
89        .chunks_exact(std::mem::size_of::<f32>())
90        .map(|e| f32::from_le_bytes(e.try_into().unwrap()));
91
92    let mut s = String::from("[");
93    for (i, e) in elements.enumerate() {
94        if i > 0 {
95            s.push(',');
96        }
97        s.push_str(&e.to_string());
98    }
99    s.push(']');
100    Ok(s)
101}
102
103/// Parses a string to a vector type value
104/// Valid input format: "[1.0,2.0,3.0]", "[1.0, 2.0, 3.0]"
105pub fn parse_string_to_vector_type_value(s: &str, dim: Option<u32>) -> Result<Vec<u8>> {
106    // Trim the brackets
107    let trimmed = s.trim();
108    if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
109        return InvalidVectorSnafu {
110            msg: format!("Failed to parse {s} to Vector value: not properly enclosed in brackets"),
111        }
112        .fail();
113    }
114    // Remove the brackets
115    let content = trimmed[1..trimmed.len() - 1].trim();
116
117    if content.is_empty() {
118        if dim.is_some_and(|d| d != 0) {
119            return InvalidVectorSnafu {
120                msg: format!("Failed to parse {s} to Vector value: wrong dimension"),
121            }
122            .fail();
123        }
124        return Ok(vec![]);
125    }
126
127    let elements = content
128        .split(',')
129        .map(|e| {
130            e.trim().parse::<f32>().map_err(|_| {
131                InvalidVectorSnafu {
132                    msg: format!(
133                        "Failed to parse {s} to Vector value: elements are not all float32"
134                    ),
135                }
136                .build()
137            })
138        })
139        .collect::<Result<Vec<f32>>>()?;
140
141    // Check dimension
142    if dim.is_some_and(|d| d as usize != elements.len()) {
143        return InvalidVectorSnafu {
144            msg: format!("Failed to parse {s} to Vector value: wrong dimension"),
145        }
146        .fail();
147    }
148
149    // Convert Vec<f32> to Vec<u8>
150    let bytes = if cfg!(target_endian = "little") {
151        unsafe {
152            std::slice::from_raw_parts(
153                elements.as_ptr() as *const u8,
154                elements.len() * std::mem::size_of::<f32>(),
155            )
156            .to_vec()
157        }
158    } else {
159        elements
160            .iter()
161            .flat_map(|e| e.to_le_bytes())
162            .collect::<Vec<u8>>()
163    };
164
165    Ok(bytes)
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the byte layout of a vector value: each element encoded as little-endian
    /// `f32` bytes, concatenated in order.
    fn le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    #[test]
    fn test_conversion_between_string_and_vector_type_value() {
        let dim = 3;

        let cases = [
            ("[1.0,2.0,3]", "[1,2,3]"),
            ("[0.0 , 0.0 , 0.0]", "[0,0,0]"),
            ("[3.4028235e38, -3.4028235e38, 1.1754944e-38]", "[340282350000000000000000000000000000000,-340282350000000000000000000000000000000,0.000000000000000000000000000000000000011754944]"),
        ];

        for (s, expected) in cases.iter() {
            let val = parse_string_to_vector_type_value(s, Some(dim)).unwrap();
            let s = vector_type_value_to_string(&val, dim).unwrap();
            assert_eq!(s, *expected);
        }

        let dim = 0;
        let cases = [("[]", "[]"), ("[ ]", "[]"), ("[  ]", "[]")];
        for (s, expected) in cases.iter() {
            let val = parse_string_to_vector_type_value(s, Some(dim)).unwrap();
            let s = vector_type_value_to_string(&val, dim).unwrap();
            assert_eq!(s, *expected);
        }
    }

    #[test]
    fn test_parse_string_to_vector_type_value_byte_layout() {
        // Pins the little-endian encoding independently of the host's endianness.
        let val = parse_string_to_vector_type_value("[1.0, -2.5, 3.25]", Some(3)).unwrap();
        assert_eq!(val, le_bytes(&[1.0, -2.5, 3.25]));
    }

    #[test]
    fn test_parse_string_to_vector_type_value_without_dim() {
        // `None` accepts any number of elements, including zero.
        let val = parse_string_to_vector_type_value("  [0.5, 1.5]  ", None).unwrap();
        assert_eq!(val, le_bytes(&[0.5, 1.5]));

        let val = parse_string_to_vector_type_value("[]", None).unwrap();
        assert!(val.is_empty());
    }

    #[test]
    fn test_vector_type_value_to_string_decodes_le_bytes() {
        let s = vector_type_value_to_string(&le_bytes(&[0.5, -1.25]), 2).unwrap();
        assert_eq!(s, "[0.5,-1.25]");
    }

    #[test]
    fn test_vector_type_value_to_string_wrong_byte_size() {
        let dim = 3;
        let val = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let res = vector_type_value_to_string(&val, dim);
        assert!(res.is_err());

        let dim = 0;
        let val = vec![1];
        let res = vector_type_value_to_string(&val, dim);
        assert!(res.is_err());
    }

    #[test]
    fn test_parse_string_to_vector_type_value_not_properly_enclosed_in_brackets() {
        let dim = 3;
        let s = "1.0,2.0,3.0";
        let res = parse_string_to_vector_type_value(s, Some(dim));
        assert!(res.is_err());

        let s = "[1.0,2.0,3.0";
        let res = parse_string_to_vector_type_value(s, Some(dim));
        assert!(res.is_err());

        let s = "1.0,2.0,3.0]";
        let res = parse_string_to_vector_type_value(s, Some(dim));
        assert!(res.is_err());

        // A lone bracket must not be treated as an (empty) enclosed list.
        assert!(parse_string_to_vector_type_value("[", None).is_err());
        assert!(parse_string_to_vector_type_value("]", None).is_err());
    }

    #[test]
    fn test_parse_string_to_vector_type_value_wrong_dimension() {
        let dim = 3;
        let s = "[1.0,2.0]";
        let res = parse_string_to_vector_type_value(s, Some(dim));
        assert!(res.is_err());

        // An empty list only matches a zero dimension.
        let res = parse_string_to_vector_type_value("[]", Some(2));
        assert!(res.is_err());
    }

    #[test]
    fn test_parse_string_to_vector_type_value_elements_are_not_all_float32() {
        let dim = 3;
        let s = "[1.0,2.0,ah]";
        let res = parse_string_to_vector_type_value(s, Some(dim));
        assert!(res.is_err());

        // A trailing comma produces an empty element, which is not a valid float.
        let res = parse_string_to_vector_type_value("[1.0,2.0,]", Some(dim));
        assert!(res.is_err());
    }
}