common_function/scalars/vector/
impl_conv.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::borrow::Cow;
16
17use common_query::error::{InvalidFuncArgsSnafu, Result};
18use datafusion_common::ScalarValue;
19
20/// Convert a string or binary literal to a vector literal.
21pub fn as_veclit(arg: &ScalarValue) -> Result<Option<Cow<'_, [f32]>>> {
22    match arg {
23        ScalarValue::Binary(b) | ScalarValue::BinaryView(b) => {
24            b.as_ref().map(|x| binlit_as_veclit(x)).transpose()
25        }
26        ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) => s
27            .as_ref()
28            .map(|x| parse_veclit_from_strlit(x).map(Cow::Owned))
29            .transpose(),
30        _ => InvalidFuncArgsSnafu {
31            err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
32        }
33        .fail(),
34    }
35}
36
37/// Convert a u8 slice to a vector literal.
38pub fn binlit_as_veclit(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
39    if !bytes.len().is_multiple_of(size_of::<f32>()) {
40        return InvalidFuncArgsSnafu {
41            err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
42        }
43        .fail();
44    }
45
46    if cfg!(target_endian = "little") {
47        Ok(unsafe {
48            let vec = std::slice::from_raw_parts(
49                bytes.as_ptr() as *const f32,
50                bytes.len() / std::mem::size_of::<f32>(),
51            );
52            Cow::Borrowed(vec)
53        })
54    } else {
55        let v = bytes
56            .chunks_exact(std::mem::size_of::<f32>())
57            .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
58            .collect::<Vec<f32>>();
59        Ok(Cow::Owned(v))
60    }
61}
62
63/// Parse a string literal to a vector literal.
64/// Valid inputs are strings like "[1.0, 2.0, 3.0]".
65pub fn parse_veclit_from_strlit(s: &str) -> Result<Vec<f32>> {
66    let trimmed = s.trim();
67    if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
68        return InvalidFuncArgsSnafu {
69            err_msg: format!(
70                "Failed to parse {s} to Vector value: not properly enclosed in brackets"
71            ),
72        }
73        .fail();
74    }
75    let content = trimmed[1..trimmed.len() - 1].trim();
76    if content.is_empty() {
77        return Ok(Vec::new());
78    }
79
80    content
81        .split(',')
82        .map(|s| s.trim().parse::<f32>())
83        .collect::<std::result::Result<_, _>>()
84        .map_err(|e| {
85            InvalidFuncArgsSnafu {
86                err_msg: format!("Failed to parse {s} to Vector value: {e}"),
87            }
88            .build()
89        })
90}
91
/// Serialize a slice of `f32` values into bytes.
///
/// Each element is written as its 4-byte little-endian representation, in
/// order, so the output is `4 * vec.len()` bytes long.
pub fn veclit_to_binlit(vec: &[f32]) -> Vec<u8> {
    let byte_len = std::mem::size_of_val(vec);

    if !cfg!(target_endian = "little") {
        // Native layout differs from the wire format: convert element by element.
        let mut encoded = Vec::with_capacity(byte_len);
        encoded.extend(vec.iter().flat_map(|value| value.to_le_bytes()));
        return encoded;
    }

    // SAFETY: `vec` is a valid slice spanning exactly `byte_len` initialized
    // bytes, `u8` has an alignment of 1, and `f32` contains no padding, so the
    // same memory may be viewed as `byte_len` bytes for the duration of the
    // borrow.
    let native = unsafe { std::slice::from_raw_parts(vec.as_ptr().cast::<u8>(), byte_len) };
    native.to_vec()
}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers a well-formed literal, the empty literal and a literal with a
    /// non-numeric element.
    #[test]
    fn test_parse_veclit_from_strlit() {
        let parsed = parse_veclit_from_strlit("[1.0, 2.0, 3.0]").unwrap();
        assert_eq!(parsed, vec![1.0, 2.0, 3.0]);

        let parsed = parse_veclit_from_strlit("[]").unwrap();
        assert!(parsed.is_empty());

        assert!(parse_veclit_from_strlit("[1.0, a, 3.0]").is_err());
    }

    /// Round-trips a vector through its binary form and rejects a buffer
    /// whose length is not a multiple of four bytes.
    #[test]
    fn test_binlit_as_veclit() {
        let expected = [1.0_f32, 2.0, 3.0];
        let encoded = veclit_to_binlit(&expected);
        let decoded = binlit_as_veclit(&encoded).unwrap();
        assert_eq!(&*decoded, &expected[..]);

        let truncated = [0_u8, 0, 128];
        assert!(binlit_as_veclit(&truncated).is_err());
    }
}
135}