common_function/scalars/vector/
impl_conv.rs1use std::borrow::Cow;
16
17use common_query::error::{InvalidFuncArgsSnafu, Result};
18use datafusion_common::ScalarValue;
19
20pub fn as_veclit(arg: &ScalarValue) -> Result<Option<Cow<'_, [f32]>>> {
22 match arg {
23 ScalarValue::Binary(b) | ScalarValue::BinaryView(b) => {
24 b.as_ref().map(|x| binlit_as_veclit(x)).transpose()
25 }
26 ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) => s
27 .as_ref()
28 .map(|x| parse_veclit_from_strlit(x).map(Cow::Owned))
29 .transpose(),
30 _ => InvalidFuncArgsSnafu {
31 err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
32 }
33 .fail(),
34 }
35}
36
37pub fn binlit_as_veclit(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
39 if !bytes.len().is_multiple_of(size_of::<f32>()) {
40 return InvalidFuncArgsSnafu {
41 err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
42 }
43 .fail();
44 }
45
46 if cfg!(target_endian = "little") {
47 Ok(unsafe {
48 let vec = std::slice::from_raw_parts(
49 bytes.as_ptr() as *const f32,
50 bytes.len() / std::mem::size_of::<f32>(),
51 );
52 Cow::Borrowed(vec)
53 })
54 } else {
55 let v = bytes
56 .chunks_exact(std::mem::size_of::<f32>())
57 .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
58 .collect::<Vec<f32>>();
59 Ok(Cow::Owned(v))
60 }
61}
62
63pub fn parse_veclit_from_strlit(s: &str) -> Result<Vec<f32>> {
66 let trimmed = s.trim();
67 if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
68 return InvalidFuncArgsSnafu {
69 err_msg: format!(
70 "Failed to parse {s} to Vector value: not properly enclosed in brackets"
71 ),
72 }
73 .fail();
74 }
75 let content = trimmed[1..trimmed.len() - 1].trim();
76 if content.is_empty() {
77 return Ok(Vec::new());
78 }
79
80 content
81 .split(',')
82 .map(|s| s.trim().parse::<f32>())
83 .collect::<std::result::Result<_, _>>()
84 .map_err(|e| {
85 InvalidFuncArgsSnafu {
86 err_msg: format!("Failed to parse {s} to Vector value: {e}"),
87 }
88 .build()
89 })
90}
91
/// Encodes a vector of `f32` values as a binary literal.
///
/// Each element is written, in order, as its 4-byte little-endian IEEE-754 representation.
/// The output is therefore always `size_of_val(vec)` bytes long and uses the layout
/// expected by [`binlit_as_veclit`].
pub fn veclit_to_binlit(vec: &[f32]) -> Vec<u8> {
    let byte_len = std::mem::size_of_val(vec);

    if cfg!(target_endian = "big") {
        // The native layout is big-endian, so convert element by element.
        let mut encoded = Vec::with_capacity(byte_len);
        encoded.extend(vec.iter().flat_map(|value| value.to_le_bytes()));
        return encoded;
    }

    // SAFETY: `vec` is a valid, initialized `[f32]` spanning `byte_len` bytes. `f32` has no
    // padding and `u8` has alignment 1, so viewing the same memory as bytes is sound. On a
    // little-endian target this view already is the little-endian encoding.
    let native = unsafe { std::slice::from_raw_parts(vec.as_ptr().cast::<u8>(), byte_len) };
    native.to_vec()
}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_veclit_from_strlit() {
        let result = parse_veclit_from_strlit("[1.0, 2.0, 3.0]").unwrap();
        assert_eq!(result, vec![1.0, 2.0, 3.0]);

        let result = parse_veclit_from_strlit("[]").unwrap();
        assert_eq!(result, Vec::<f32>::new());

        let result = parse_veclit_from_strlit("[1.0, a, 3.0]");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_veclit_from_strlit_whitespace() {
        // Whitespace around the literal and around each element is ignored.
        let result = parse_veclit_from_strlit("  [ 1.5 ,-2 ]  ").unwrap();
        assert_eq!(result, vec![1.5, -2.0]);

        // Whitespace-only content is an empty vector.
        let result = parse_veclit_from_strlit("[   ]").unwrap();
        assert_eq!(result, Vec::<f32>::new());
    }

    #[test]
    fn test_parse_veclit_from_strlit_malformed() {
        // Missing brackets on either side (or entirely) are rejected.
        assert!(parse_veclit_from_strlit("").is_err());
        assert!(parse_veclit_from_strlit("1.0, 2.0").is_err());
        assert!(parse_veclit_from_strlit("[1.0, 2.0").is_err());
        assert!(parse_veclit_from_strlit("1.0, 2.0]").is_err());

        // Empty elements are rejected rather than silently skipped.
        assert!(parse_veclit_from_strlit("[1.0,]").is_err());
        assert!(parse_veclit_from_strlit("[1.0,,2.0]").is_err());
    }

    #[test]
    fn test_binlit_as_veclit() {
        let vec = &[1.0, 2.0, 3.0];
        let bytes = veclit_to_binlit(vec);
        let result = binlit_as_veclit(&bytes).unwrap();
        assert_eq!(result.as_ref(), vec);

        // Lengths that are not a multiple of size_of::<f32>() are rejected.
        let invalid_bytes = [0, 0, 128];
        let result = binlit_as_veclit(&invalid_bytes);
        assert!(result.is_err());
        assert!(binlit_as_veclit(&[0u8; 5]).is_err());
    }

    #[test]
    fn test_veclit_to_binlit_is_little_endian() {
        // 1.0f32 is 0x3F800000; the encoding is little-endian regardless of the host.
        assert_eq!(veclit_to_binlit(&[1.0]), vec![0x00, 0x00, 0x80, 0x3F]);
        assert!(veclit_to_binlit(&[]).is_empty());
    }
}