common_function/scalars/vector/
impl_conv.rs1use std::borrow::Cow;
16use std::sync::Arc;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datatypes::prelude::ConcreteDataType;
20use datatypes::value::ValueRef;
21use datatypes::vectors::Vector;
22
23pub fn as_veclit_if_const(arg: &Arc<dyn Vector>) -> Result<Option<Cow<'_, [f32]>>> {
25 if !arg.is_const() {
26 return Ok(None);
27 }
28 if arg.data_type() != ConcreteDataType::string_datatype()
29 && arg.data_type() != ConcreteDataType::binary_datatype()
30 {
31 return Ok(None);
32 }
33 as_veclit(arg.get_ref(0))
34}
35
36pub fn as_veclit(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
38 match arg.data_type() {
39 ConcreteDataType::Binary(_) => arg
40 .as_binary()
41 .unwrap() .map(binlit_as_veclit)
43 .transpose(),
44 ConcreteDataType::String(_) => arg
45 .as_string()
46 .unwrap() .map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?)))
48 .transpose(),
49 ConcreteDataType::Null(_) => Ok(None),
50 _ => InvalidFuncArgsSnafu {
51 err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
52 }
53 .fail(),
54 }
55}
56
57pub fn binlit_as_veclit(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
59 if bytes.len() % std::mem::size_of::<f32>() != 0 {
60 return InvalidFuncArgsSnafu {
61 err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
62 }
63 .fail();
64 }
65
66 if cfg!(target_endian = "little") {
67 Ok(unsafe {
68 let vec = std::slice::from_raw_parts(
69 bytes.as_ptr() as *const f32,
70 bytes.len() / std::mem::size_of::<f32>(),
71 );
72 Cow::Borrowed(vec)
73 })
74 } else {
75 let v = bytes
76 .chunks_exact(std::mem::size_of::<f32>())
77 .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
78 .collect::<Vec<f32>>();
79 Ok(Cow::Owned(v))
80 }
81}
82
83pub fn parse_veclit_from_strlit(s: &str) -> Result<Vec<f32>> {
86 let trimmed = s.trim();
87 if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
88 return InvalidFuncArgsSnafu {
89 err_msg: format!(
90 "Failed to parse {s} to Vector value: not properly enclosed in brackets"
91 ),
92 }
93 .fail();
94 }
95 let content = trimmed[1..trimmed.len() - 1].trim();
96 if content.is_empty() {
97 return Ok(Vec::new());
98 }
99
100 content
101 .split(',')
102 .map(|s| s.trim().parse::<f32>())
103 .collect::<std::result::Result<_, _>>()
104 .map_err(|e| {
105 InvalidFuncArgsSnafu {
106 err_msg: format!("Failed to parse {s} to Vector value: {e}"),
107 }
108 .build()
109 })
110}
111
/// Serializes `vec` into its packed little-endian byte representation
/// (`size_of::<f32>()` bytes per element, in order).
pub fn veclit_to_binlit(vec: &[f32]) -> Vec<u8> {
    let byte_len = std::mem::size_of_val(vec);

    if !cfg!(target_endian = "little") {
        // Native layout differs from the wire format: encode element by element.
        let mut encoded = Vec::with_capacity(byte_len);
        vec.iter()
            .for_each(|value| encoded.extend_from_slice(&value.to_le_bytes()));
        return encoded;
    }

    // SAFETY: `vec.as_ptr()` is valid for `byte_len` bytes (the whole slice),
    // `u8` has alignment 1, and `f32` has no padding, so viewing the slice's
    // memory as bytes is sound. On little-endian targets this native layout is
    // already the wire format.
    let native_bytes = unsafe { std::slice::from_raw_parts(vec.as_ptr().cast::<u8>(), byte_len) };
    native_bytes.to_vec()
}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_veclit_from_strlit() {
        let result = parse_veclit_from_strlit("[1.0, 2.0, 3.0]").unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0]);

        let result = parse_veclit_from_strlit("[]").unwrap();
        assert_eq!(result, Vec::<f32>::new());

        // Whitespace around the literal and around elements is ignored.
        let result = parse_veclit_from_strlit("  [ 4.5 ,-1 ]  ").unwrap();
        assert_eq!(result, vec![4.5, -1.0]);

        let result = parse_veclit_from_strlit("[1.0, a, 3.0]");
        assert!(result.is_err());

        // Missing either bracket is rejected.
        assert!(parse_veclit_from_strlit("1.0, 2.0]").is_err());
        assert!(parse_veclit_from_strlit("[1.0, 2.0").is_err());
        assert!(parse_veclit_from_strlit("[").is_err());
    }

    #[test]
    fn test_binlit_as_veclit() {
        let vec = &[1.0, 2.0, 3.0];
        let bytes = veclit_to_binlit(vec);
        let result = binlit_as_veclit(&bytes).unwrap();
        assert_eq!(result.as_ref(), vec);

        let invalid_bytes = [0, 0, 128];
        let result = binlit_as_veclit(&invalid_bytes);
        assert!(result.is_err());

        let empty: [u8; 0] = [];
        let result = binlit_as_veclit(&empty).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_binlit_as_veclit_misaligned_input() {
        let vec = [1.0f32, -2.5, 3.25];
        let bytes = veclit_to_binlit(&vec);
        // Offset the payload by one byte so that, with the usual allocator
        // alignment, the slice handed to `binlit_as_veclit` is not f32-aligned.
        let mut shifted = Vec::with_capacity(bytes.len() + 1);
        shifted.push(0u8);
        shifted.extend_from_slice(&bytes);
        let result = binlit_as_veclit(&shifted[1..]).unwrap();
        assert_eq!(result.as_ref(), &vec[..]);
    }

    #[test]
    fn test_veclit_to_binlit_is_little_endian() {
        // 1.0f32 == 0x3F80_0000
        assert_eq!(veclit_to_binlit(&[1.0]), vec![0x00u8, 0x00, 0x80, 0x3f]);
    }
}