1use std::sync::Arc;
16
17use datafusion::arrow::array::{Array, ArrayRef};
18use datafusion::common::cast::as_primitive_array;
19use datafusion::error::{DataFusionError, Result as DfResult};
20use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility};
21use datafusion::prelude::create_udaf;
22use datafusion_common::cast::{as_list_array, as_struct_array};
23use datafusion_common::utils::SingleRowListArrayBuilder;
24use datafusion_common::ScalarValue;
25use datatypes::arrow::array::{Float64Array, Int64Array, ListArray, StructArray};
26use datatypes::arrow::datatypes::{
27 DataType, Field, Float64Type, Int64Type, TimeUnit, TimestampNanosecondType,
28};
29use datatypes::compute::{self, sort_to_indices};
30
/// Name under which the aggregate function is registered.
pub const GEO_PATH_NAME: &str = "geo_path";

/// Struct field name of the latitude list in both the result and the intermediate state.
const LATITUDE_FIELD: &str = "lat";
/// Struct field name of the longitude list in both the result and the intermediate state.
const LONGITUDE_FIELD: &str = "lng";
/// Struct field name of the timestamp list; only present in the intermediate state.
const TIMESTAMP_FIELD: &str = "timestamp";
/// Name given to the element field of every list type declared by this function.
const DEFAULT_LIST_FIELD_NAME: &str = "item";
37
/// Accumulator backing the `geo_path` aggregate.
///
/// The three vectors are parallel: index `i` of each one describes the same
/// input row. Nulls are kept as `None` rather than dropped, so the vectors
/// always stay the same length.
#[derive(Default, Debug)]
pub struct GeoPathAccumulator {
    /// Latitude of every row seen so far, in arrival order.
    lat: Vec<Option<f64>>,
    /// Longitude of every row seen so far, in arrival order.
    lng: Vec<Option<f64>>,
    /// Raw timestamp value of every row seen so far; used as the sort key on evaluation.
    timestamp: Vec<Option<i64>>,
}
44
45impl GeoPathAccumulator {
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn udf_impl() -> AggregateUDF {
51 create_udaf(
52 GEO_PATH_NAME,
53 vec![
55 DataType::Float64,
56 DataType::Float64,
57 DataType::Timestamp(TimeUnit::Nanosecond, None),
58 ],
59 Arc::new(DataType::Struct(
61 vec![
62 Field::new(
63 LATITUDE_FIELD,
64 DataType::List(Arc::new(Field::new(
65 DEFAULT_LIST_FIELD_NAME,
66 DataType::Float64,
67 true,
68 ))),
69 false,
70 ),
71 Field::new(
72 LONGITUDE_FIELD,
73 DataType::List(Arc::new(Field::new(
74 DEFAULT_LIST_FIELD_NAME,
75 DataType::Float64,
76 true,
77 ))),
78 false,
79 ),
80 ]
81 .into(),
82 )),
83 Volatility::Immutable,
84 Arc::new(|_| Ok(Box::new(GeoPathAccumulator::new()))),
86 Arc::new(vec![DataType::Struct(
88 vec![
89 Field::new(
90 LATITUDE_FIELD,
91 DataType::List(Arc::new(Field::new(
92 DEFAULT_LIST_FIELD_NAME,
93 DataType::Float64,
94 true,
95 ))),
96 false,
97 ),
98 Field::new(
99 LONGITUDE_FIELD,
100 DataType::List(Arc::new(Field::new(
101 DEFAULT_LIST_FIELD_NAME,
102 DataType::Float64,
103 true,
104 ))),
105 false,
106 ),
107 Field::new(
108 TIMESTAMP_FIELD,
109 DataType::List(Arc::new(Field::new(
110 DEFAULT_LIST_FIELD_NAME,
111 DataType::Int64,
112 true,
113 ))),
114 false,
115 ),
116 ]
117 .into(),
118 )]),
119 )
120 }
121}
122
123impl DfAccumulator for GeoPathAccumulator {
124 fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
125 if values.len() != 3 {
126 return Err(DataFusionError::Internal(format!(
127 "Expected 3 columns for geo_path, got {}",
128 values.len()
129 )));
130 }
131
132 let lat_array = as_primitive_array::<Float64Type>(&values[0])?;
133 let lng_array = as_primitive_array::<Float64Type>(&values[1])?;
134 let ts_array = as_primitive_array::<TimestampNanosecondType>(&values[2])?;
135
136 let size = lat_array.len();
137 self.lat.reserve(size);
138 self.lng.reserve(size);
139
140 for idx in 0..size {
141 self.lat.push(if lat_array.is_null(idx) {
142 None
143 } else {
144 Some(lat_array.value(idx))
145 });
146
147 self.lng.push(if lng_array.is_null(idx) {
148 None
149 } else {
150 Some(lng_array.value(idx))
151 });
152
153 self.timestamp.push(if ts_array.is_null(idx) {
154 None
155 } else {
156 Some(ts_array.value(idx))
157 });
158 }
159
160 Ok(())
161 }
162
163 fn evaluate(&mut self) -> DfResult<ScalarValue> {
164 let unordered_lng_array = Float64Array::from(self.lng.clone());
165 let unordered_lat_array = Float64Array::from(self.lat.clone());
166 let ts_array = Int64Array::from(self.timestamp.clone());
167
168 let ordered_indices = sort_to_indices(&ts_array, None, None)?;
169 let lat_array = compute::take(&unordered_lat_array, &ordered_indices, None)?;
170 let lng_array = compute::take(&unordered_lng_array, &ordered_indices, None)?;
171
172 let lat_list = Arc::new(SingleRowListArrayBuilder::new(lat_array).build_list_array());
173 let lng_list = Arc::new(SingleRowListArrayBuilder::new(lng_array).build_list_array());
174
175 let result = ScalarValue::Struct(Arc::new(StructArray::new(
176 vec![
177 Field::new(
178 LATITUDE_FIELD,
179 DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
180 false,
181 ),
182 Field::new(
183 LONGITUDE_FIELD,
184 DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
185 false,
186 ),
187 ]
188 .into(),
189 vec![lat_list, lng_list],
190 None,
191 )));
192
193 Ok(result)
194 }
195
196 fn size(&self) -> usize {
197 let mut total_size = std::mem::size_of::<Self>();
199
200 total_size += self.lat.capacity() * std::mem::size_of::<Option<f64>>();
202 total_size += self.lng.capacity() * std::mem::size_of::<Option<f64>>();
203 total_size += self.timestamp.capacity() * std::mem::size_of::<Option<i64>>();
204
205 total_size
206 }
207
208 fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
209 let lat_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
210 Some(self.lat.clone()),
211 ]));
212 let lng_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
213 Some(self.lng.clone()),
214 ]));
215 let ts_array = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
216 Some(self.timestamp.clone()),
217 ]));
218
219 let state_struct = StructArray::new(
220 vec![
221 Field::new(
222 LATITUDE_FIELD,
223 DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
224 false,
225 ),
226 Field::new(
227 LONGITUDE_FIELD,
228 DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
229 false,
230 ),
231 Field::new(
232 TIMESTAMP_FIELD,
233 DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
234 false,
235 ),
236 ]
237 .into(),
238 vec![lat_array, lng_array, ts_array],
239 None,
240 );
241
242 Ok(vec![ScalarValue::Struct(Arc::new(state_struct))])
243 }
244
245 fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
246 if states.len() != 1 {
247 return Err(DataFusionError::Internal(format!(
248 "Expected 1 states for geo_path, got {}",
249 states.len()
250 )));
251 }
252
253 for state in states {
254 let state = as_struct_array(state)?;
255 let lat_list = as_list_array(state.column(0))?.value(0);
256 let lat_array = as_primitive_array::<Float64Type>(&lat_list)?;
257 let lng_list = as_list_array(state.column(1))?.value(0);
258 let lng_array = as_primitive_array::<Float64Type>(&lng_list)?;
259 let ts_list = as_list_array(state.column(2))?.value(0);
260 let ts_array = as_primitive_array::<Int64Type>(&ts_list)?;
261
262 self.lat.extend(lat_array);
263 self.lng.extend(lng_array);
264 self.timestamp.extend(ts_array);
265 }
266
267 Ok(())
268 }
269}
270
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Float64Array, TimestampNanosecondArray};
    use datafusion::scalar::ScalarValue;

    use super::*;

    /// Builds an accumulator that has consumed a single batch of points.
    fn accumulator_with(lat: Vec<f64>, lng: Vec<f64>, ts: Vec<i64>) -> GeoPathAccumulator {
        let mut acc = GeoPathAccumulator::new();
        let lat = Arc::new(Float64Array::from(lat));
        let lng = Arc::new(Float64Array::from(lng));
        let ts = Arc::new(TimestampNanosecondArray::from(ts));
        acc.update_batch(&[lat, lng, ts]).unwrap();
        acc
    }

    /// Evaluates the accumulator and unwraps the resulting struct scalar.
    fn evaluate_struct(acc: &mut GeoPathAccumulator) -> Arc<StructArray> {
        match acc.evaluate().unwrap() {
            ScalarValue::Struct(array) => array,
            _ => panic!("Expected Struct scalar value"),
        }
    }

    /// Extracts the serialized state of the accumulator as a plain array.
    fn state_array(acc: &mut GeoPathAccumulator) -> ArrayRef {
        let state = acc.state().unwrap();
        match &state[0] {
            ScalarValue::Struct(array) => {
                let array: ArrayRef = array.clone();
                array
            }
            _ => panic!("Expected Struct scalar value"),
        }
    }

    /// Reads the first (and only) list of a list column as raw `f64` values.
    fn float_list(column: &ArrayRef) -> Vec<f64> {
        let list = as_list_array(column).unwrap().value(0);
        let values = as_primitive_array::<Float64Type>(&list).unwrap();
        (0..values.len()).map(|idx| values.value(idx)).collect()
    }

    #[test]
    fn test_geo_path_basic() {
        let mut accumulator = accumulator_with(
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![100, 200, 300],
        );

        let struct_array = evaluate_struct(&mut accumulator);

        let fields = struct_array.fields().clone();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].name(), LATITUDE_FIELD);
        assert_eq!(fields[1].name(), LONGITUDE_FIELD);

        let columns = struct_array.columns();
        assert_eq!(columns.len(), 2);

        // Input is already in timestamp order, so it comes back unchanged.
        assert_eq!(float_list(&columns[0]), vec![1.0, 2.0, 3.0]);
        assert_eq!(float_list(&columns[1]), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_geo_path_sort_by_timestamp() {
        let mut accumulator = accumulator_with(
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![300, 100, 200],
        );

        let struct_array = evaluate_struct(&mut accumulator);
        let columns = struct_array.columns();

        // Timestamps 100, 200, 300 belong to rows 1, 2, 0 respectively.
        assert_eq!(float_list(&columns[0]), vec![2.0, 3.0, 1.0]);
        assert_eq!(float_list(&columns[1]), vec![5.0, 6.0, 4.0]);
    }

    #[test]
    fn test_geo_path_merge() {
        let mut accumulator1 = accumulator_with(vec![1.0], vec![4.0], vec![100]);
        let mut accumulator2 = accumulator_with(vec![2.0], vec![5.0], vec![200]);

        let state_array1 = state_array(&mut accumulator1);
        let state_array2 = state_array(&mut accumulator2);

        let mut merged = GeoPathAccumulator::new();
        merged.merge_batch(&[state_array1]).unwrap();
        merged.merge_batch(&[state_array2]).unwrap();

        let struct_array = evaluate_struct(&mut merged);
        let columns = struct_array.columns();

        assert_eq!(float_list(&columns[0]), vec![1.0, 2.0]);
        assert_eq!(float_list(&columns[1]), vec![4.0, 5.0]);
    }
}