common_function/scalars/vector/
vector_kth_elem.rs1use std::fmt::Display;
16
17use datafusion::logical_expr::ColumnarValue;
18use datafusion_common::{DataFusionError, ScalarValue};
19use datafusion_expr::{ScalarFunctionArgs, Signature};
20use datatypes::arrow::datatypes::DataType;
21
22use crate::function::Function;
23use crate::helper;
24use crate::scalars::vector::VectorCalculator;
25use crate::scalars::vector::impl_conv::as_veclit;
26
27const NAME: &str = "vec_kth_elem";
28
29#[derive(Debug, Clone)]
46pub(crate) struct VectorKthElemFunction {
47 signature: Signature,
48}
49
50impl Default for VectorKthElemFunction {
51 fn default() -> Self {
52 Self {
53 signature: helper::one_of_sigs2(
54 vec![DataType::Utf8, DataType::Binary],
55 vec![DataType::Int64],
56 ),
57 }
58 }
59}
60
61impl Function for VectorKthElemFunction {
62 fn name(&self) -> &str {
63 NAME
64 }
65
66 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
67 Ok(DataType::Float32)
68 }
69
70 fn signature(&self) -> &Signature {
71 &self.signature
72 }
73
74 fn invoke_with_args(
75 &self,
76 args: ScalarFunctionArgs,
77 ) -> datafusion_common::Result<ColumnarValue> {
78 let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
79 let v0 = as_veclit(v0)?;
80
81 let v1 = match v1 {
82 ScalarValue::Int64(None) => return Ok(ScalarValue::Float32(None)),
83 ScalarValue::Int64(Some(v1)) if *v1 >= 0 => *v1 as usize,
84 _ => {
85 return Err(DataFusionError::Execution(format!(
86 "2nd argument not a valid index or expected datatype: {}",
87 self.name()
88 )));
89 }
90 };
91
92 let result = v0
93 .map(|v0| {
94 if v1 >= v0.len() {
95 Err(DataFusionError::Execution(format!(
96 "index out of bound: {}",
97 self.name()
98 )))
99 } else {
100 Ok(v0[v1])
101 }
102 })
103 .transpose()?;
104 Ok(ScalarValue::Float32(result))
105 };
106
107 let calculator = VectorCalculator {
108 name: self.name(),
109 func: body,
110 };
111 calculator.invoke_with_args(args)
112 }
113}
114
115impl Display for VectorKthElemFunction {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 write!(f, "{}", NAME.to_ascii_uppercase())
118 }
119}
120
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringViewArray};
    use datafusion::arrow::datatypes::Float32Type;
    use datafusion_common::config::ConfigOptions;

    use super::*;

    /// Runs `vec_kth_elem` over a column of vector literals and a column of
    /// indices.
    ///
    /// `number_rows` is derived from the input length so it always agrees
    /// with the arrays handed to the function, and the return field is marked
    /// nullable because the function can produce NULL results.
    fn invoke(
        vectors: Vec<Option<&str>>,
        indices: Vec<Option<i64>>,
    ) -> datafusion_common::Result<ColumnarValue> {
        assert_eq!(vectors.len(), indices.len());
        let number_rows = vectors.len();

        let vectors: ArrayRef = Arc::new(StringViewArray::from(vectors));
        let indices: ArrayRef = Arc::new(Int64Array::from(indices));

        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(vectors), ColumnarValue::Array(indices)],
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("x", DataType::Float32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        VectorKthElemFunction::default().invoke_with_args(args)
    }

    /// Valid indices select the element; a NULL index or NULL vector gives NULL.
    #[test]
    fn test_vec_kth_elem() {
        let result = invoke(
            vec![
                Some("[1.0,2.0,3.0]"),
                Some("[4.0,5.0,6.0]"),
                Some("[7.0,8.0,9.0]"),
                None,
            ],
            vec![Some(0), Some(2), None, Some(1)],
        )
        .and_then(|x| x.to_array(4))
        .unwrap();

        let result = result.as_primitive::<Float32Type>();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), 1.0);
        assert_eq!(result.value(1), 6.0);
        assert!(result.is_null(2));
        assert!(result.is_null(3));
    }

    /// An index equal to the vector length is rejected.
    #[test]
    fn test_vec_kth_elem_index_out_of_bound() {
        let e = invoke(vec![Some("[1.0,2.0,3.0]")], vec![Some(3)]).unwrap_err();
        assert!(
            e.to_string()
                .starts_with("Execution error: index out of bound: vec_kth_elem")
        );
    }

    /// A negative index is rejected as an invalid index.
    #[test]
    fn test_vec_kth_elem_negative_index() {
        let e = invoke(vec![Some("[1.0,2.0,3.0]")], vec![Some(-1)]).unwrap_err();
        assert!(e.to_string().starts_with(
            "Execution error: 2nd argument not a valid index or expected datatype: vec_kth_elem"
        ));
    }
}