common_function/scalars/vector/
impl_conv.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datatypes::prelude::ConcreteDataType;
20use datatypes::value::ValueRef;
21use datatypes::vectors::Vector;
22
23/// Convert a constant string or binary literal to a vector literal.
24pub fn as_veclit_if_const(arg: &Arc<dyn Vector>) -> Result<Option<Cow<'_, [f32]>>> {
25    if !arg.is_const() {
26        return Ok(None);
27    }
28    if arg.data_type() != ConcreteDataType::string_datatype()
29        && arg.data_type() != ConcreteDataType::binary_datatype()
30    {
31        return Ok(None);
32    }
33    as_veclit(arg.get_ref(0))
34}
35
36/// Convert a string or binary literal to a vector literal.
37pub fn as_veclit(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
38    match arg.data_type() {
39        ConcreteDataType::Binary(_) => arg
40            .as_binary()
41            .unwrap() // Safe: checked if it is a binary
42            .map(binlit_as_veclit)
43            .transpose(),
44        ConcreteDataType::String(_) => arg
45            .as_string()
46            .unwrap() // Safe: checked if it is a string
47            .map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?)))
48            .transpose(),
49        ConcreteDataType::Null(_) => Ok(None),
50        _ => InvalidFuncArgsSnafu {
51            err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
52        }
53        .fail(),
54    }
55}
56
57/// Convert a u8 slice to a vector literal.
58pub fn binlit_as_veclit(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
59    if bytes.len() % std::mem::size_of::<f32>() != 0 {
60        return InvalidFuncArgsSnafu {
61            err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
62        }
63        .fail();
64    }
65
66    if cfg!(target_endian = "little") {
67        Ok(unsafe {
68            let vec = std::slice::from_raw_parts(
69                bytes.as_ptr() as *const f32,
70                bytes.len() / std::mem::size_of::<f32>(),
71            );
72            Cow::Borrowed(vec)
73        })
74    } else {
75        let v = bytes
76            .chunks_exact(std::mem::size_of::<f32>())
77            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
78            .collect::<Vec<f32>>();
79        Ok(Cow::Owned(v))
80    }
81}
82
83/// Parse a string literal to a vector literal.
84/// Valid inputs are strings like "[1.0, 2.0, 3.0]".
85pub fn parse_veclit_from_strlit(s: &str) -> Result<Vec<f32>> {
86    let trimmed = s.trim();
87    if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
88        return InvalidFuncArgsSnafu {
89            err_msg: format!(
90                "Failed to parse {s} to Vector value: not properly enclosed in brackets"
91            ),
92        }
93        .fail();
94    }
95    let content = trimmed[1..trimmed.len() - 1].trim();
96    if content.is_empty() {
97        return Ok(Vec::new());
98    }
99
100    content
101        .split(',')
102        .map(|s| s.trim().parse::<f32>())
103        .collect::<std::result::Result<_, _>>()
104        .map_err(|e| {
105            InvalidFuncArgsSnafu {
106                err_msg: format!("Failed to parse {s} to Vector value: {e}"),
107            }
108            .build()
109        })
110}
111
/// Encodes a vector literal as a binary literal.
///
/// The output holds each `f32` of `vec` as 4 little-endian bytes, in order,
/// so its length is `4 * vec.len()`.
pub fn veclit_to_binlit(vec: &[f32]) -> Vec<u8> {
    let byte_len = std::mem::size_of_val(vec);

    if !cfg!(target_endian = "little") {
        // Big-endian target: every element has to be byte-swapped explicitly.
        let mut encoded = Vec::with_capacity(byte_len);
        encoded.extend(vec.iter().flat_map(|value| value.to_le_bytes()));
        return encoded;
    }

    // On little-endian targets the in-memory representation already is the
    // desired encoding, so the bytes can be copied as one block.
    //
    // SAFETY: `vec` is a valid slice spanning exactly `byte_len` bytes, `u8`
    // has an alignment of 1, and any initialized byte is a valid `u8`.
    let raw = unsafe { std::slice::from_raw_parts(vec.as_ptr().cast::<u8>(), byte_len) };
    raw.to_vec()
}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers a well-formed literal, the empty literal, and a literal with a
    /// non-numeric element.
    #[test]
    fn test_parse_veclit_from_strlit() {
        let parsed = parse_veclit_from_strlit("[1.0, 2.0, 3.0]").unwrap();
        assert_eq!(parsed, vec![1.0, 2.0, 3.0]);

        let empty = parse_veclit_from_strlit("[]").unwrap();
        assert_eq!(empty, Vec::<f32>::new());

        // A non-numeric element must be rejected.
        assert!(parse_veclit_from_strlit("[1.0, a, 3.0]").is_err());
    }

    /// Round-trips a vector through its binary form and checks that a byte
    /// length which is not a multiple of 4 is rejected.
    #[test]
    fn test_binlit_as_veclit() {
        let expected: [f32; 3] = [1.0, 2.0, 3.0];
        let encoded = veclit_to_binlit(&expected);
        let decoded = binlit_as_veclit(&encoded).unwrap();
        assert_eq!(decoded.as_ref(), &expected);

        // Three bytes cannot hold a whole number of f32 values.
        let truncated = [0u8, 0, 128];
        assert!(binlit_as_veclit(&truncated).is_err());
    }
}
155}