common_function/scalars/vector/
vector_kth_elem.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use common_query::prelude::Signature;
20use datatypes::prelude::ConcreteDataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23use snafu::ensure;
24
25use crate::function::{Function, FunctionContext};
26use crate::helper;
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
/// Name of the function; returned by `Function::name` and upper-cased by the `Display` impl.
const NAME: &str = "vec_kth_elem";
30
/// Scalar function `vec_kth_elem(vector, k)`.
///
/// Yields the element at 0-based position `k` of each input vector as a
/// `Float32` value. The type is a stateless marker; all behaviour lives in its
/// `Function` implementation.
#[derive(Clone, Debug, Default)]
pub struct VectorKthElemFunction;
49
50impl Function for VectorKthElemFunction {
51 fn name(&self) -> &str {
52 NAME
53 }
54
55 fn return_type(
56 &self,
57 _input_types: &[ConcreteDataType],
58 ) -> common_query::error::Result<ConcreteDataType> {
59 Ok(ConcreteDataType::float32_datatype())
60 }
61
62 fn signature(&self) -> Signature {
63 helper::one_of_sigs2(
64 vec![
65 ConcreteDataType::string_datatype(),
66 ConcreteDataType::binary_datatype(),
67 ],
68 vec![ConcreteDataType::int64_datatype()],
69 )
70 }
71
72 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
73 ensure!(
74 columns.len() == 2,
75 InvalidFuncArgsSnafu {
76 err_msg: format!(
77 "The length of the args is not correct, expect exactly two, have: {}",
78 columns.len()
79 ),
80 }
81 );
82
83 let arg0 = &columns[0];
84 let arg1 = &columns[1];
85
86 let len = arg0.len();
87 let mut result = Float32VectorBuilder::with_capacity(len);
88 if len == 0 {
89 return Ok(result.to_vector());
90 };
91
92 let arg0_const = as_veclit_if_const(arg0)?;
93
94 for i in 0..len {
95 let arg0 = match arg0_const.as_ref() {
96 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
97 None => as_veclit(arg0.get_ref(i))?,
98 };
99 let Some(arg0) = arg0 else {
100 result.push_null();
101 continue;
102 };
103
104 let arg1 = arg1.get(i).as_f64_lossy();
105 let Some(arg1) = arg1 else {
106 result.push_null();
107 continue;
108 };
109
110 ensure!(
111 arg1 >= 0.0 && arg1.fract() == 0.0,
112 InvalidFuncArgsSnafu {
113 err_msg: format!(
114 "Invalid argument: k must be a non-negative integer, but got k = {}.",
115 arg1
116 ),
117 }
118 );
119
120 let k = arg1 as usize;
121
122 ensure!(
123 k < arg0.len(),
124 InvalidFuncArgsSnafu {
125 err_msg: format!(
126 "Out of range: k must be in the range [0, {}], but got k = {}.",
127 arg0.len() - 1,
128 k
129 ),
130 }
131 );
132
133 let value = arg0[k];
134
135 result.push(Some(value));
136 }
137 Ok(result.to_vector())
138 }
139}
140
141impl Display for VectorKthElemFunction {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(f, "{}", NAME.to_ascii_uppercase())
144 }
145}
146
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error;
    use datatypes::vectors::{Int64Vector, StringVector};

    use super::*;

    /// Builds the two argument columns (vector literals, indices) for `eval`.
    fn columns(vectors: Vec<Option<&str>>, ks: Vec<Option<i64>>) -> [VectorRef; 2] {
        let literals: Vec<Option<String>> = vectors
            .into_iter()
            .map(|v| v.map(str::to_string))
            .collect();
        let vector_col: VectorRef = Arc::new(StringVector::from(literals));
        let k_col: VectorRef = Arc::new(Int64Vector::from(ks));
        [vector_col, k_col]
    }

    /// Runs `eval`, expecting it to fail with `InvalidFuncArgs`, and returns the message.
    fn invalid_args_message(func: &VectorKthElemFunction, args: &[VectorRef]) -> String {
        match func.eval(&FunctionContext::default(), args).unwrap_err() {
            error::Error::InvalidFuncArgs { err_msg, .. } => err_msg,
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_vec_kth_elem() {
        let func = VectorKthElemFunction;

        // Valid lookups, plus null propagation from either argument.
        let args = columns(
            vec![
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                Some("[7.0,8.0,9.0]"),
                None,
            ],
            vec![Some(0), Some(2), None, Some(1)],
        );
        let output = func.eval(&FunctionContext::default(), &args).unwrap();
        let output = output.as_ref();
        assert_eq!(output.len(), 4);
        assert_eq!(output.get_ref(0).as_f32().unwrap(), Some(1.0));
        assert_eq!(output.get_ref(1).as_f32().unwrap(), Some(6.0));
        assert!(output.get_ref(2).is_null());
        assert!(output.get_ref(3).is_null());

        // An index one past the last element is rejected.
        let args = columns(vec![Some("[1.0,2.0,3.0]")], vec![Some(3)]);
        assert_eq!(
            invalid_args_message(&func, &args),
            "Out of range: k must be in the range [0, 2], but got k = 3."
        );

        // A negative index is rejected.
        let args = columns(vec![Some("[1.0,2.0,3.0]")], vec![Some(-1)]);
        assert_eq!(
            invalid_args_message(&func, &args),
            "Invalid argument: k must be a non-negative integer, but got k = -1."
        );
    }
}