common_function/aggr/
geo_path.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::arrow::array::{Array, ArrayRef};
18use datafusion::common::cast::as_primitive_array;
19use datafusion::error::{DataFusionError, Result as DfResult};
20use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility};
21use datafusion::prelude::create_udaf;
22use datafusion_common::cast::{as_list_array, as_struct_array};
23use datafusion_common::utils::SingleRowListArrayBuilder;
24use datafusion_common::ScalarValue;
25use datatypes::arrow::array::{Float64Array, Int64Array, ListArray, StructArray};
26use datatypes::arrow::datatypes::{
27    DataType, Field, Float64Type, Int64Type, TimeUnit, TimestampNanosecondType,
28};
29use datatypes::compute::{self, sort_to_indices};
30
/// Name under which the `geo_path` aggregate function is registered.
pub const GEO_PATH_NAME: &str = "geo_path";

// Names used inside the struct types produced by `geo_path` (both the final
// result and the intermediate state).
/// Name of the list element field inside every `List` column.
const DEFAULT_LIST_FIELD_NAME: &str = "item";
/// Struct field holding the list of latitudes.
const LATITUDE_FIELD: &str = "lat";
/// Struct field holding the list of longitudes.
const LONGITUDE_FIELD: &str = "lng";
/// Struct field (intermediate state only) holding the list of timestamps.
const TIMESTAMP_FIELD: &str = "timestamp";
37
/// Accumulator state for the `geo_path` aggregate.
///
/// The three vectors are parallel: index `i` in each one describes the same
/// input row. A `None` entry records a null input value.
#[derive(Default, Debug)]
pub struct GeoPathAccumulator {
    /// Buffered latitudes, in the order rows were received.
    lat: Vec<Option<f64>>,
    /// Buffered longitudes, aligned with `lat`.
    lng: Vec<Option<f64>>,
    /// Buffered raw timestamp values (nanosecond timestamps read as `i64`),
    /// aligned with `lat`.
    timestamp: Vec<Option<i64>>,
}
44
45impl GeoPathAccumulator {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    pub fn udf_impl() -> AggregateUDF {
51        create_udaf(
52            GEO_PATH_NAME,
53            // Input types: lat, lng, timestamp
54            vec![
55                DataType::Float64,
56                DataType::Float64,
57                DataType::Timestamp(TimeUnit::Nanosecond, None),
58            ],
59            // Output type: list of points {[lat], [lng]}
60            Arc::new(DataType::Struct(
61                vec![
62                    Field::new(
63                        LATITUDE_FIELD,
64                        DataType::List(Arc::new(Field::new(
65                            DEFAULT_LIST_FIELD_NAME,
66                            DataType::Float64,
67                            true,
68                        ))),
69                        false,
70                    ),
71                    Field::new(
72                        LONGITUDE_FIELD,
73                        DataType::List(Arc::new(Field::new(
74                            DEFAULT_LIST_FIELD_NAME,
75                            DataType::Float64,
76                            true,
77                        ))),
78                        false,
79                    ),
80                ]
81                .into(),
82            )),
83            Volatility::Immutable,
84            // Create the accumulator
85            Arc::new(|_| Ok(Box::new(GeoPathAccumulator::new()))),
86            // Intermediate state types
87            Arc::new(vec![DataType::Struct(
88                vec![
89                    Field::new(
90                        LATITUDE_FIELD,
91                        DataType::List(Arc::new(Field::new(
92                            DEFAULT_LIST_FIELD_NAME,
93                            DataType::Float64,
94                            true,
95                        ))),
96                        false,
97                    ),
98                    Field::new(
99                        LONGITUDE_FIELD,
100                        DataType::List(Arc::new(Field::new(
101                            DEFAULT_LIST_FIELD_NAME,
102                            DataType::Float64,
103                            true,
104                        ))),
105                        false,
106                    ),
107                    Field::new(
108                        TIMESTAMP_FIELD,
109                        DataType::List(Arc::new(Field::new(
110                            DEFAULT_LIST_FIELD_NAME,
111                            DataType::Int64,
112                            true,
113                        ))),
114                        false,
115                    ),
116                ]
117                .into(),
118            )]),
119        )
120    }
121}
122
123impl DfAccumulator for GeoPathAccumulator {
124    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
125        if values.len() != 3 {
126            return Err(DataFusionError::Internal(format!(
127                "Expected 3 columns for geo_path, got {}",
128                values.len()
129            )));
130        }
131
132        let lat_array = as_primitive_array::<Float64Type>(&values[0])?;
133        let lng_array = as_primitive_array::<Float64Type>(&values[1])?;
134        let ts_array = as_primitive_array::<TimestampNanosecondType>(&values[2])?;
135
136        let size = lat_array.len();
137        self.lat.reserve(size);
138        self.lng.reserve(size);
139
140        for idx in 0..size {
141            self.lat.push(if lat_array.is_null(idx) {
142                None
143            } else {
144                Some(lat_array.value(idx))
145            });
146
147            self.lng.push(if lng_array.is_null(idx) {
148                None
149            } else {
150                Some(lng_array.value(idx))
151            });
152
153            self.timestamp.push(if ts_array.is_null(idx) {
154                None
155            } else {
156                Some(ts_array.value(idx))
157            });
158        }
159
160        Ok(())
161    }
162
163    fn evaluate(&mut self) -> DfResult<ScalarValue> {
164        let unordered_lng_array = Float64Array::from(self.lng.clone());
165        let unordered_lat_array = Float64Array::from(self.lat.clone());
166        let ts_array = Int64Array::from(self.timestamp.clone());
167
168        let ordered_indices = sort_to_indices(&ts_array, None, None)?;
169        let lat_array = compute::take(&unordered_lat_array, &ordered_indices, None)?;
170        let lng_array = compute::take(&unordered_lng_array, &ordered_indices, None)?;
171
172        let lat_list = Arc::new(SingleRowListArrayBuilder::new(lat_array).build_list_array());
173        let lng_list = Arc::new(SingleRowListArrayBuilder::new(lng_array).build_list_array());
174
175        let result = ScalarValue::Struct(Arc::new(StructArray::new(
176            vec![
177                Field::new(
178                    LATITUDE_FIELD,
179                    DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
180                    false,
181                ),
182                Field::new(
183                    LONGITUDE_FIELD,
184                    DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
185                    false,
186                ),
187            ]
188            .into(),
189            vec![lat_list, lng_list],
190            None,
191        )));
192
193        Ok(result)
194    }
195
196    fn size(&self) -> usize {
197        // Base size of GeoPathAccumulator struct fields
198        let mut total_size = std::mem::size_of::<Self>();
199
200        // Size of vectors (approximation)
201        total_size += self.lat.capacity() * std::mem::size_of::<Option<f64>>();
202        total_size += self.lng.capacity() * std::mem::size_of::<Option<f64>>();
203        total_size += self.timestamp.capacity() * std::mem::size_of::<Option<i64>>();
204
205        total_size
206    }
207
208    fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
209        let lat_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
210            Some(self.lat.clone()),
211        ]));
212        let lng_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
213            Some(self.lng.clone()),
214        ]));
215        let ts_array = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
216            Some(self.timestamp.clone()),
217        ]));
218
219        let state_struct = StructArray::new(
220            vec![
221                Field::new(
222                    LATITUDE_FIELD,
223                    DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
224                    false,
225                ),
226                Field::new(
227                    LONGITUDE_FIELD,
228                    DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
229                    false,
230                ),
231                Field::new(
232                    TIMESTAMP_FIELD,
233                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
234                    false,
235                ),
236            ]
237            .into(),
238            vec![lat_array, lng_array, ts_array],
239            None,
240        );
241
242        Ok(vec![ScalarValue::Struct(Arc::new(state_struct))])
243    }
244
245    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
246        if states.len() != 1 {
247            return Err(DataFusionError::Internal(format!(
248                "Expected 1 states for geo_path, got {}",
249                states.len()
250            )));
251        }
252
253        for state in states {
254            let state = as_struct_array(state)?;
255            let lat_list = as_list_array(state.column(0))?.value(0);
256            let lat_array = as_primitive_array::<Float64Type>(&lat_list)?;
257            let lng_list = as_list_array(state.column(1))?.value(0);
258            let lng_array = as_primitive_array::<Float64Type>(&lng_list)?;
259            let ts_list = as_list_array(state.column(2))?.value(0);
260            let ts_array = as_primitive_array::<Int64Type>(&ts_list)?;
261
262            self.lat.extend(lat_array);
263            self.lng.extend(lng_array);
264            self.timestamp.extend(ts_array);
265        }
266
267        Ok(())
268    }
269}
270
271#[cfg(test)]
272mod tests {
273    use datafusion::arrow::array::{Float64Array, TimestampNanosecondArray};
274    use datafusion::scalar::ScalarValue;
275
276    use super::*;
277
278    #[test]
279    fn test_geo_path_basic() {
280        let mut accumulator = GeoPathAccumulator::new();
281
282        // Create test data
283        let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
284        let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0]));
285        let ts_array = Arc::new(TimestampNanosecondArray::from(vec![100, 200, 300]));
286
287        // Update batch
288        accumulator
289            .update_batch(&[lat_array, lng_array, ts_array])
290            .unwrap();
291
292        // Evaluate
293        let result = accumulator.evaluate().unwrap();
294        if let ScalarValue::Struct(struct_array) = result {
295            // Verify structure
296            let fields = struct_array.fields().clone();
297            assert_eq!(fields.len(), 2);
298            assert_eq!(fields[0].name(), LATITUDE_FIELD);
299            assert_eq!(fields[1].name(), LONGITUDE_FIELD);
300
301            // Verify data
302            let columns = struct_array.columns();
303            assert_eq!(columns.len(), 2);
304
305            // Check latitude values
306            let lat_list = as_list_array(&columns[0]).unwrap().value(0);
307            let lat_array = as_primitive_array::<Float64Type>(&lat_list).unwrap();
308            assert_eq!(lat_array.len(), 3);
309            assert_eq!(lat_array.value(0), 1.0);
310            assert_eq!(lat_array.value(1), 2.0);
311            assert_eq!(lat_array.value(2), 3.0);
312
313            // Check longitude values
314            let lng_list = as_list_array(&columns[1]).unwrap().value(0);
315            let lng_array = as_primitive_array::<Float64Type>(&lng_list).unwrap();
316            assert_eq!(lng_array.len(), 3);
317            assert_eq!(lng_array.value(0), 4.0);
318            assert_eq!(lng_array.value(1), 5.0);
319            assert_eq!(lng_array.value(2), 6.0);
320        } else {
321            panic!("Expected Struct scalar value");
322        }
323    }
324
325    #[test]
326    fn test_geo_path_sort_by_timestamp() {
327        let mut accumulator = GeoPathAccumulator::new();
328
329        // Create test data with unordered timestamps
330        let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
331        let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0]));
332        let ts_array = Arc::new(TimestampNanosecondArray::from(vec![300, 100, 200]));
333
334        // Update batch
335        accumulator
336            .update_batch(&[lat_array, lng_array, ts_array])
337            .unwrap();
338
339        // Evaluate
340        let result = accumulator.evaluate().unwrap();
341        if let ScalarValue::Struct(struct_array) = result {
342            // Extract arrays
343            let columns = struct_array.columns();
344
345            // Check latitude values
346            let lat_list = as_list_array(&columns[0]).unwrap().value(0);
347            let lat_array = as_primitive_array::<Float64Type>(&lat_list).unwrap();
348            assert_eq!(lat_array.len(), 3);
349            assert_eq!(lat_array.value(0), 2.0); // timestamp 100
350            assert_eq!(lat_array.value(1), 3.0); // timestamp 200
351            assert_eq!(lat_array.value(2), 1.0); // timestamp 300
352
353            // Check longitude values (should be sorted by timestamp)
354            let lng_list = as_list_array(&columns[1]).unwrap().value(0);
355            let lng_array = as_primitive_array::<Float64Type>(&lng_list).unwrap();
356            assert_eq!(lng_array.len(), 3);
357            assert_eq!(lng_array.value(0), 5.0); // timestamp 100
358            assert_eq!(lng_array.value(1), 6.0); // timestamp 200
359            assert_eq!(lng_array.value(2), 4.0); // timestamp 300
360        } else {
361            panic!("Expected Struct scalar value");
362        }
363    }
364
365    #[test]
366    fn test_geo_path_merge() {
367        let mut accumulator1 = GeoPathAccumulator::new();
368        let mut accumulator2 = GeoPathAccumulator::new();
369
370        // Create test data for first accumulator
371        let lat_array1 = Arc::new(Float64Array::from(vec![1.0]));
372        let lng_array1 = Arc::new(Float64Array::from(vec![4.0]));
373        let ts_array1 = Arc::new(TimestampNanosecondArray::from(vec![100]));
374
375        // Create test data for second accumulator
376        let lat_array2 = Arc::new(Float64Array::from(vec![2.0]));
377        let lng_array2 = Arc::new(Float64Array::from(vec![5.0]));
378        let ts_array2 = Arc::new(TimestampNanosecondArray::from(vec![200]));
379
380        // Update batches
381        accumulator1
382            .update_batch(&[lat_array1, lng_array1, ts_array1])
383            .unwrap();
384        accumulator2
385            .update_batch(&[lat_array2, lng_array2, ts_array2])
386            .unwrap();
387
388        // Get states
389        let state1 = accumulator1.state().unwrap();
390        let state2 = accumulator2.state().unwrap();
391
392        // Create a merged accumulator
393        let mut merged = GeoPathAccumulator::new();
394
395        // Extract the struct arrays from the states
396        let state_array1 = match &state1[0] {
397            ScalarValue::Struct(array) => array.clone(),
398            _ => panic!("Expected Struct scalar value"),
399        };
400
401        let state_array2 = match &state2[0] {
402            ScalarValue::Struct(array) => array.clone(),
403            _ => panic!("Expected Struct scalar value"),
404        };
405
406        // Merge state arrays
407        merged.merge_batch(&[state_array1]).unwrap();
408        merged.merge_batch(&[state_array2]).unwrap();
409
410        // Evaluate merged result
411        let result = merged.evaluate().unwrap();
412        if let ScalarValue::Struct(struct_array) = result {
413            // Extract arrays
414            let columns = struct_array.columns();
415
416            // Check latitude values
417            let lat_list = as_list_array(&columns[0]).unwrap().value(0);
418            let lat_array = as_primitive_array::<Float64Type>(&lat_list).unwrap();
419            assert_eq!(lat_array.len(), 2);
420            assert_eq!(lat_array.value(0), 1.0); // timestamp 100
421            assert_eq!(lat_array.value(1), 2.0); // timestamp 200
422
423            // Check longitude values (should be sorted by timestamp)
424            let lng_list = as_list_array(&columns[1]).unwrap().value(0);
425            let lng_array = as_primitive_array::<Float64Type>(&lng_list).unwrap();
426            assert_eq!(lng_array.len(), 2);
427            assert_eq!(lng_array.value(0), 4.0); // timestamp 100
428            assert_eq!(lng_array.value(1), 5.0); // timestamp 200
429        } else {
430            panic!("Expected Struct scalar value");
431        }
432    }
433}