common_function/scalars/vector/
vector_kth_elem.rs1use std::borrow::Cow;
16use std::fmt::Display;
17
18use common_query::error::{InvalidFuncArgsSnafu, Result};
19use datafusion_expr::Signature;
20use datatypes::arrow::datatypes::DataType;
21use datatypes::scalars::ScalarVectorBuilder;
22use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
23use snafu::ensure;
24
25use crate::function::{Function, FunctionContext};
26use crate::helper;
27use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
28
/// Function name reported by `VectorKthElemFunction::name`.
const NAME: &str = "vec_kth_elem";
30
/// Scalar function `vec_kth_elem(vector, k)`: picks the element at 0-based index `k`
/// out of a vector literal and returns it as a `Float32` value.
#[derive(Clone, Debug, Default)]
pub struct VectorKthElemFunction;
49
50impl Function for VectorKthElemFunction {
51 fn name(&self) -> &str {
52 NAME
53 }
54
55 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
56 Ok(DataType::Float32)
57 }
58
59 fn signature(&self) -> Signature {
60 helper::one_of_sigs2(
61 vec![DataType::Utf8, DataType::Binary],
62 vec![DataType::Int64],
63 )
64 }
65
66 fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
67 ensure!(
68 columns.len() == 2,
69 InvalidFuncArgsSnafu {
70 err_msg: format!(
71 "The length of the args is not correct, expect exactly two, have: {}",
72 columns.len()
73 ),
74 }
75 );
76
77 let arg0 = &columns[0];
78 let arg1 = &columns[1];
79
80 let len = arg0.len();
81 let mut result = Float32VectorBuilder::with_capacity(len);
82 if len == 0 {
83 return Ok(result.to_vector());
84 };
85
86 let arg0_const = as_veclit_if_const(arg0)?;
87
88 for i in 0..len {
89 let arg0 = match arg0_const.as_ref() {
90 Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
91 None => as_veclit(arg0.get_ref(i))?,
92 };
93 let Some(arg0) = arg0 else {
94 result.push_null();
95 continue;
96 };
97
98 let arg1 = arg1.get(i).as_f64_lossy();
99 let Some(arg1) = arg1 else {
100 result.push_null();
101 continue;
102 };
103
104 ensure!(
105 arg1 >= 0.0 && arg1.fract() == 0.0,
106 InvalidFuncArgsSnafu {
107 err_msg: format!(
108 "Invalid argument: k must be a non-negative integer, but got k = {}.",
109 arg1
110 ),
111 }
112 );
113
114 let k = arg1 as usize;
115
116 ensure!(
117 k < arg0.len(),
118 InvalidFuncArgsSnafu {
119 err_msg: format!(
120 "Out of range: k must be in the range [0, {}], but got k = {}.",
121 arg0.len() - 1,
122 k
123 ),
124 }
125 );
126
127 let value = arg0[k];
128
129 result.push(Some(value));
130 }
131 Ok(result.to_vector())
132 }
133}
134
135impl Display for VectorKthElemFunction {
136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 write!(f, "{}", NAME.to_ascii_uppercase())
138 }
139}
140
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_query::error;
    use datatypes::vectors::{Int64Vector, StringVector};

    use super::*;

    /// Evaluates the function on a single row and returns the `InvalidFuncArgs`
    /// message it fails with; any other outcome fails the test with the actual error.
    fn eval_single_row_err(vector: &str, k: i64) -> String {
        let input0 = Arc::new(StringVector::from(vec![Some(vector.to_string())]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(k)]));

        let err = VectorKthElemFunction
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap_err();
        match err {
            error::Error::InvalidFuncArgs { err_msg, .. } => err_msg,
            other => panic!("expected InvalidFuncArgs, got: {other:?}"),
        }
    }

    #[test]
    fn test_vec_kth_elem() {
        let func = VectorKthElemFunction;

        let input0 = Arc::new(StringVector::from(vec![
            Some("[1.0,2.0,3.0]".to_string()),
            Some("[4.0,5.0,6.0]".to_string()),
            Some("[7.0,8.0,9.0]".to_string()),
            None,
        ]));
        let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)]));

        let result = func
            .eval(&FunctionContext::default(), &[input0, input1])
            .unwrap();

        let result = result.as_ref();
        assert_eq!(result.len(), 4);
        assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0));
        assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0));
        // A NULL `k` and a NULL vector each yield a NULL row.
        assert!(result.get_ref(2).is_null());
        assert!(result.get_ref(3).is_null());

        assert_eq!(
            eval_single_row_err("[1.0,2.0,3.0]", 3),
            "Out of range: k must be in the range [0, 2], but got k = 3."
        );
        assert_eq!(
            eval_single_row_err("[1.0,2.0,3.0]", -1),
            "Invalid argument: k must be a non-negative integer, but got k = -1."
        );
    }
}