// common_function/aggrs/geo/encoding.rs

use std::sync::Arc;

use arrow::array::AsArray;
use datafusion::arrow::array::{Array, ArrayRef};
use datafusion::common::cast::as_primitive_array;
use datafusion::error::{DataFusionError, Result as DfResult};
use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility};
use datafusion::prelude::create_udaf;
use datafusion_common::cast::{as_list_array, as_struct_array};
use datafusion_common::ScalarValue;
use datatypes::arrow::array::{Float64Array, Int64Array, ListArray, StructArray};
use datatypes::arrow::datatypes::{
    DataType, Field, Float64Type, Int64Type, TimeUnit, TimestampNanosecondType,
};
use datatypes::compute::{self, sort_to_indices};

pub const JSON_ENCODE_PATH_NAME: &str = "json_encode_path";

const LATITUDE_FIELD: &str = "lat";
const LONGITUDE_FIELD: &str = "lng";
const TIMESTAMP_FIELD: &str = "timestamp";
const DEFAULT_LIST_FIELD_NAME: &str = "item";

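/// Accumulates `(lat, lng, timestamp)` points and encodes them, ordered by
/// timestamp, as a JSON array of `[longitude, latitude]` pairs.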
#[derive(Debug, Default)]
pub struct JsonEncodePathAccumulator {
    lat: Vec<Option<f64>>,
    lng: Vec<Option<f64>>,
    timestamp: Vec<Option<i64>>,
}

impl JsonEncodePathAccumulator {
    pub fn new() -> Self {
        Self::default()
    }

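    /// Builds the `json_encode_path` aggregate UDF backed by this accumulator:
    /// it takes latitude, longitude and a nanosecond timestamp, and returns the
    /// encoded path as a UTF-8 JSON string.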
    pub fn uadf_impl() -> AggregateUDF {
        create_udaf(
            JSON_ENCODE_PATH_NAME,
            // Input columns: latitude, longitude, and a nanosecond timestamp.
            vec![
                DataType::Float64,
                DataType::Float64,
                DataType::Timestamp(TimeUnit::Nanosecond, None),
            ],
            // The aggregate evaluates to a JSON-encoded string.
            Arc::new(DataType::Utf8),
            Volatility::Immutable,
            // Accumulator factory.
            Arc::new(|_| Ok(Box::new(Self::new()))),
            // Intermediate state: a struct with one list per input column.
            Arc::new(vec![DataType::Struct(
                vec![
                    Field::new(
                        LATITUDE_FIELD,
                        DataType::List(Arc::new(Field::new(
                            DEFAULT_LIST_FIELD_NAME,
                            DataType::Float64,
                            true,
                        ))),
                        false,
                    ),
                    Field::new(
                        LONGITUDE_FIELD,
                        DataType::List(Arc::new(Field::new(
                            DEFAULT_LIST_FIELD_NAME,
                            DataType::Float64,
                            true,
                        ))),
                        false,
                    ),
                    Field::new(
                        TIMESTAMP_FIELD,
                        DataType::List(Arc::new(Field::new(
                            DEFAULT_LIST_FIELD_NAME,
                            DataType::Int64,
                            true,
                        ))),
                        false,
                    ),
                ]
                .into(),
            )]),
        )
    }
}

impl DfAccumulator for JsonEncodePathAccumulator {
    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
        if values.len() != 3 {
            return Err(DataFusionError::Internal(format!(
                "Expected 3 columns for json_encode_path, got {}",
                values.len()
            )));
        }

        let lat_array = as_primitive_array::<Float64Type>(&values[0])?;
        let lng_array = as_primitive_array::<Float64Type>(&values[1])?;
        let ts_array = as_primitive_array::<TimestampNanosecondType>(&values[2])?;

        let size = lat_array.len();
        self.lat.reserve(size);
        self.lng.reserve(size);
        self.timestamp.reserve(size);

        // Keep nulls as `None` so the three columns stay aligned row by row.
        for idx in 0..size {
            self.lat.push(if lat_array.is_null(idx) {
                None
            } else {
                Some(lat_array.value(idx))
            });

            self.lng.push(if lng_array.is_null(idx) {
                None
            } else {
                Some(lng_array.value(idx))
            });

            self.timestamp.push(if ts_array.is_null(idx) {
                None
            } else {
                Some(ts_array.value(idx))
            });
        }

        Ok(())
    }

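    // Evaluation sorts the collected points by timestamp before encoding, so
    // the path is emitted in time order regardless of input order.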
    fn evaluate(&mut self) -> DfResult<ScalarValue> {
        let unordered_lng_array = Float64Array::from(self.lng.clone());
        let unordered_lat_array = Float64Array::from(self.lat.clone());
        let ts_array = Int64Array::from(self.timestamp.clone());

        // Reorder both coordinate arrays by ascending timestamp.
        let ordered_indices = sort_to_indices(&ts_array, None, None)?;
        let lat_array = compute::take(&unordered_lat_array, &ordered_indices, None)?;
        let lng_array = compute::take(&unordered_lng_array, &ordered_indices, None)?;

        let len = ts_array.len();
        let lat_array = lat_array.as_primitive::<Float64Type>();
        let lng_array = lng_array.as_primitive::<Float64Type>();

        // Each point is encoded longitude-first, i.e. `[lng, lat]`.
        let mut coords = Vec::with_capacity(len);
        for i in 0..len {
            let lng = lng_array.value(i);
            let lat = lat_array.value(i);
            coords.push(vec![lng, lat]);
        }

        let result = serde_json::to_string(&coords)
            .map_err(|e| DataFusionError::Execution(format!("Failed to encode JSON: {}", e)))?;

        Ok(ScalarValue::Utf8(Some(result)))
    }

    fn size(&self) -> usize {
        let mut total_size = std::mem::size_of::<Self>();

        total_size += self.lat.capacity() * std::mem::size_of::<Option<f64>>();
        total_size += self.lng.capacity() * std::mem::size_of::<Option<f64>>();
        total_size += self.timestamp.capacity() * std::mem::size_of::<Option<i64>>();

        total_size
    }

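    // The accumulator state is serialized as a single struct scalar holding the
    // three column vectors, so partial accumulators can be merged later.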
    fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
        let lat_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
            Some(self.lat.clone()),
        ]));
        let lng_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
            Some(self.lng.clone()),
        ]));
        let ts_array = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
            Some(self.timestamp.clone()),
        ]));

        let state_struct = StructArray::new(
            vec![
                Field::new(
                    LATITUDE_FIELD,
                    DataType::List(Arc::new(Field::new(
                        DEFAULT_LIST_FIELD_NAME,
                        DataType::Float64,
                        true,
                    ))),
                    false,
                ),
                Field::new(
                    LONGITUDE_FIELD,
                    DataType::List(Arc::new(Field::new(
                        DEFAULT_LIST_FIELD_NAME,
                        DataType::Float64,
                        true,
                    ))),
                    false,
                ),
                Field::new(
                    TIMESTAMP_FIELD,
                    DataType::List(Arc::new(Field::new(
                        DEFAULT_LIST_FIELD_NAME,
                        DataType::Int64,
                        true,
                    ))),
                    false,
                ),
            ]
            .into(),
            vec![lat_array, lng_array, ts_array],
            None,
        );

        Ok(vec![ScalarValue::Struct(Arc::new(state_struct))])
    }

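    // Merging appends the vectors from another accumulator's serialized state;
    // ordering is deferred to `evaluate`, which sorts by timestamp.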
    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
        if states.len() != 1 {
            return Err(DataFusionError::Internal(format!(
                "Expected 1 state for json_encode_path, got {}",
                states.len()
            )));
        }

        for state in states {
            let state = as_struct_array(state)?;
            let lat_list = as_list_array(state.column(0))?.value(0);
            let lat_array = as_primitive_array::<Float64Type>(&lat_list)?;
            let lng_list = as_list_array(state.column(1))?.value(0);
            let lng_array = as_primitive_array::<Float64Type>(&lng_list)?;
            let ts_list = as_list_array(state.column(2))?.value(0);
            let ts_array = as_primitive_array::<Int64Type>(&ts_list)?;

            self.lat.extend(lat_array);
            self.lng.extend(lng_array);
            self.timestamp.extend(ts_array);
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Float64Array, TimestampNanosecondArray};
    use datafusion::scalar::ScalarValue;

    use super::*;

    #[test]
    fn test_json_encode_path_basic() {
        let mut accumulator = JsonEncodePathAccumulator::new();

        let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
        let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0]));
        let ts_array = Arc::new(TimestampNanosecondArray::from(vec![100, 200, 300]));

        accumulator
            .update_batch(&[lat_array, lng_array, ts_array])
            .unwrap();

        let result = accumulator.evaluate().unwrap();
        assert_eq!(
            result,
            ScalarValue::Utf8(Some("[[4.0,1.0],[5.0,2.0],[6.0,3.0]]".to_string()))
        );
    }

    #[test]
    fn test_json_encode_path_sort_by_timestamp() {
        let mut accumulator = JsonEncodePathAccumulator::new();

        let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
        let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0]));
        let ts_array = Arc::new(TimestampNanosecondArray::from(vec![300, 100, 200]));

        accumulator
            .update_batch(&[lat_array, lng_array, ts_array])
            .unwrap();

        let result = accumulator.evaluate().unwrap();
        assert_eq!(
            result,
            ScalarValue::Utf8(Some("[[5.0,2.0],[6.0,3.0],[4.0,1.0]]".to_string()))
        );
    }

    #[test]
    fn test_json_encode_path_merge() {
        let mut accumulator1 = JsonEncodePathAccumulator::new();
        let mut accumulator2 = JsonEncodePathAccumulator::new();

        let lat_array1 = Arc::new(Float64Array::from(vec![1.0]));
        let lng_array1 = Arc::new(Float64Array::from(vec![4.0]));
        let ts_array1 = Arc::new(TimestampNanosecondArray::from(vec![100]));

        let lat_array2 = Arc::new(Float64Array::from(vec![2.0]));
        let lng_array2 = Arc::new(Float64Array::from(vec![5.0]));
        let ts_array2 = Arc::new(TimestampNanosecondArray::from(vec![200]));

        accumulator1
            .update_batch(&[lat_array1, lng_array1, ts_array1])
            .unwrap();
        accumulator2
            .update_batch(&[lat_array2, lng_array2, ts_array2])
            .unwrap();

        let state1 = accumulator1.state().unwrap();
        let state2 = accumulator2.state().unwrap();

        let mut merged = JsonEncodePathAccumulator::new();

        let state_array1 = match &state1[0] {
            ScalarValue::Struct(array) => array.clone(),
            _ => panic!("Expected Struct scalar value"),
        };

        let state_array2 = match &state2[0] {
            ScalarValue::Struct(array) => array.clone(),
            _ => panic!("Expected Struct scalar value"),
        };

        merged.merge_batch(&[state_array1]).unwrap();
        merged.merge_batch(&[state_array2]).unwrap();

        let result = merged.evaluate().unwrap();
        assert_eq!(
            result,
            ScalarValue::Utf8(Some("[[4.0,1.0],[5.0,2.0]]".to_string()))
        );
    }
}