// common_function/aggrs/geo/encoding.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;

use arrow::array::AsArray;
use datafusion::arrow::array::{Array, ArrayRef};
use datafusion::common::cast::as_primitive_array;
use datafusion::error::{DataFusionError, Result as DfResult};
use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility};
use datafusion::prelude::create_udaf;
use datafusion_common::cast::{as_list_array, as_struct_array};
use datafusion_common::ScalarValue;
use datatypes::arrow::array::{Float64Array, Int64Array, ListArray, StructArray};
use datatypes::arrow::datatypes::{
    DataType, Field, Float64Type, Int64Type, TimeUnit, TimestampNanosecondType,
};
use datatypes::compute::{self, sort_to_indices};
30
/// Name under which the `json_encode_path` aggregate UDF is registered.
pub const JSON_ENCODE_PATH_NAME: &str = "json_encode_path";

// Field names of the intermediate aggregation state struct (see `state()`).
const LATITUDE_FIELD: &str = "lat";
const LONGITUDE_FIELD: &str = "lng";
const TIMESTAMP_FIELD: &str = "timestamp";
// Arrow's conventional element-field name for list arrays.
const DEFAULT_LIST_FIELD_NAME: &str = "item";
37
/// Accumulator for the `json_encode_path` aggregate function.
///
/// Collects (lat, lng, timestamp) points across batches and, at
/// evaluation time, renders them ordered by timestamp as a
/// GeoJSON-compatible coordinate list (`[[lng, lat], ...]`).
#[derive(Debug, Default)]
pub struct JsonEncodePathAccumulator {
    // Per-point latitude values; `None` marks a SQL NULL input.
    lat: Vec<Option<f64>>,
    // Per-point longitude values, index-aligned with `lat`.
    lng: Vec<Option<f64>>,
    // Nanosecond timestamps used to order points at evaluation time.
    timestamp: Vec<Option<i64>>,
}
44
45impl JsonEncodePathAccumulator {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    pub fn uadf_impl() -> AggregateUDF {
51        create_udaf(
52            JSON_ENCODE_PATH_NAME,
53            // Input types: lat, lng, timestamp
54            vec![
55                DataType::Float64,
56                DataType::Float64,
57                DataType::Timestamp(TimeUnit::Nanosecond, None),
58            ],
59            // Output type: geojson compatible linestring
60            Arc::new(DataType::Utf8),
61            Volatility::Immutable,
62            // Create the accumulator
63            Arc::new(|_| Ok(Box::new(Self::new()))),
64            // Intermediate state types
65            Arc::new(vec![DataType::Struct(
66                vec![
67                    Field::new(
68                        LATITUDE_FIELD,
69                        DataType::List(Arc::new(Field::new(
70                            DEFAULT_LIST_FIELD_NAME,
71                            DataType::Float64,
72                            true,
73                        ))),
74                        false,
75                    ),
76                    Field::new(
77                        LONGITUDE_FIELD,
78                        DataType::List(Arc::new(Field::new(
79                            DEFAULT_LIST_FIELD_NAME,
80                            DataType::Float64,
81                            true,
82                        ))),
83                        false,
84                    ),
85                    Field::new(
86                        TIMESTAMP_FIELD,
87                        DataType::List(Arc::new(Field::new(
88                            DEFAULT_LIST_FIELD_NAME,
89                            DataType::Int64,
90                            true,
91                        ))),
92                        false,
93                    ),
94                ]
95                .into(),
96            )]),
97        )
98    }
99}
100
101impl DfAccumulator for JsonEncodePathAccumulator {
102    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
103        if values.len() != 3 {
104            return Err(DataFusionError::Internal(format!(
105                "Expected 3 columns for json_encode_path, got {}",
106                values.len()
107            )));
108        }
109
110        let lat_array = as_primitive_array::<Float64Type>(&values[0])?;
111        let lng_array = as_primitive_array::<Float64Type>(&values[1])?;
112        let ts_array = as_primitive_array::<TimestampNanosecondType>(&values[2])?;
113
114        let size = lat_array.len();
115        self.lat.reserve(size);
116        self.lng.reserve(size);
117
118        for idx in 0..size {
119            self.lat.push(if lat_array.is_null(idx) {
120                None
121            } else {
122                Some(lat_array.value(idx))
123            });
124
125            self.lng.push(if lng_array.is_null(idx) {
126                None
127            } else {
128                Some(lng_array.value(idx))
129            });
130
131            self.timestamp.push(if ts_array.is_null(idx) {
132                None
133            } else {
134                Some(ts_array.value(idx))
135            });
136        }
137
138        Ok(())
139    }
140
141    fn evaluate(&mut self) -> DfResult<ScalarValue> {
142        let unordered_lng_array = Float64Array::from(self.lng.clone());
143        let unordered_lat_array = Float64Array::from(self.lat.clone());
144        let ts_array = Int64Array::from(self.timestamp.clone());
145
146        let ordered_indices = sort_to_indices(&ts_array, None, None)?;
147        let lat_array = compute::take(&unordered_lat_array, &ordered_indices, None)?;
148        let lng_array = compute::take(&unordered_lng_array, &ordered_indices, None)?;
149
150        let len = ts_array.len();
151        let lat_array = lat_array.as_primitive::<Float64Type>();
152        let lng_array = lng_array.as_primitive::<Float64Type>();
153
154        let mut coords = Vec::with_capacity(len);
155        for i in 0..len {
156            let lng = lng_array.value(i);
157            let lat = lat_array.value(i);
158            coords.push(vec![lng, lat]);
159        }
160
161        let result = serde_json::to_string(&coords)
162            .map_err(|e| DataFusionError::Execution(format!("Failed to encode json, {}", e)))?;
163
164        Ok(ScalarValue::Utf8(Some(result)))
165    }
166
167    fn size(&self) -> usize {
168        // Base size of JsonEncodePathAccumulator struct fields
169        let mut total_size = std::mem::size_of::<Self>();
170
171        // Size of vectors (approximation)
172        total_size += self.lat.capacity() * std::mem::size_of::<Option<f64>>();
173        total_size += self.lng.capacity() * std::mem::size_of::<Option<f64>>();
174        total_size += self.timestamp.capacity() * std::mem::size_of::<Option<i64>>();
175
176        total_size
177    }
178
179    fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
180        let lat_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
181            Some(self.lat.clone()),
182        ]));
183        let lng_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
184            Some(self.lng.clone()),
185        ]));
186        let ts_array = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
187            Some(self.timestamp.clone()),
188        ]));
189
190        let state_struct = StructArray::new(
191            vec![
192                Field::new(
193                    LATITUDE_FIELD,
194                    DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
195                    false,
196                ),
197                Field::new(
198                    LONGITUDE_FIELD,
199                    DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
200                    false,
201                ),
202                Field::new(
203                    TIMESTAMP_FIELD,
204                    DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
205                    false,
206                ),
207            ]
208            .into(),
209            vec![lat_array, lng_array, ts_array],
210            None,
211        );
212
213        Ok(vec![ScalarValue::Struct(Arc::new(state_struct))])
214    }
215
216    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
217        if states.len() != 1 {
218            return Err(DataFusionError::Internal(format!(
219                "Expected 1 states for json_encode_path, got {}",
220                states.len()
221            )));
222        }
223
224        for state in states {
225            let state = as_struct_array(state)?;
226            let lat_list = as_list_array(state.column(0))?.value(0);
227            let lat_array = as_primitive_array::<Float64Type>(&lat_list)?;
228            let lng_list = as_list_array(state.column(1))?.value(0);
229            let lng_array = as_primitive_array::<Float64Type>(&lng_list)?;
230            let ts_list = as_list_array(state.column(2))?.value(0);
231            let ts_array = as_primitive_array::<Int64Type>(&ts_list)?;
232
233            self.lat.extend(lat_array);
234            self.lng.extend(lng_array);
235            self.timestamp.extend(ts_array);
236        }
237
238        Ok(())
239    }
240}
241
#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Float64Array, TimestampNanosecondArray};
    use datafusion::scalar::ScalarValue;

    use super::*;

    /// Builds an accumulator pre-loaded with a single batch of
    /// (lat, lng, timestamp) points.
    fn loaded_accumulator(
        lat: Vec<f64>,
        lng: Vec<f64>,
        ts: Vec<i64>,
    ) -> JsonEncodePathAccumulator {
        let mut acc = JsonEncodePathAccumulator::new();
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(lat)),
            Arc::new(Float64Array::from(lng)),
            Arc::new(TimestampNanosecondArray::from(ts)),
        ];
        acc.update_batch(&columns).unwrap();
        acc
    }

    /// Serializes an accumulator's state and extracts the struct array.
    fn state_struct(acc: &mut JsonEncodePathAccumulator) -> ArrayRef {
        match acc.state().unwrap().remove(0) {
            ScalarValue::Struct(array) => array,
            _ => panic!("Expected Struct scalar value"),
        }
    }

    #[test]
    fn test_json_encode_path_basic() {
        // Points already in timestamp order serialize as-is,
        // with each coordinate emitted as [lng, lat].
        let mut accumulator =
            loaded_accumulator(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], vec![100, 200, 300]);

        assert_eq!(
            accumulator.evaluate().unwrap(),
            ScalarValue::Utf8(Some("[[4.0,1.0],[5.0,2.0],[6.0,3.0]]".to_string()))
        );
    }

    #[test]
    fn test_json_encode_path_sort_by_timestamp() {
        // Out-of-order timestamps: output must be re-sorted by time.
        let mut accumulator =
            loaded_accumulator(vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0], vec![300, 100, 200]);

        assert_eq!(
            accumulator.evaluate().unwrap(),
            ScalarValue::Utf8(Some("[[5.0,2.0],[6.0,3.0],[4.0,1.0]]".to_string()))
        );
    }

    #[test]
    fn test_json_encode_path_merge() {
        // Two partial accumulators, one point each.
        let mut accumulator1 = loaded_accumulator(vec![1.0], vec![4.0], vec![100]);
        let mut accumulator2 = loaded_accumulator(vec![2.0], vec![5.0], vec![200]);

        // Round-trip both partial states through a fresh accumulator.
        let mut merged = JsonEncodePathAccumulator::new();
        merged
            .merge_batch(&[state_struct(&mut accumulator1)])
            .unwrap();
        merged
            .merge_batch(&[state_struct(&mut accumulator2)])
            .unwrap();

        assert_eq!(
            merged.evaluate().unwrap(),
            ScalarValue::Utf8(Some("[[4.0,1.0],[5.0,2.0]]".to_string()))
        );
    }
}