common_function/aggrs/geo/
encoding.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_error::ext::{BoxedError, PlainError};
18use common_error::status_code::StatusCode;
19use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
20use common_query::error::{self, InvalidInputStateSnafu, Result};
21use common_query::logical_plan::accumulator::AggrFuncTypeStore;
22use common_query::logical_plan::{
23    create_aggregate_function, Accumulator, AggregateFunctionCreator,
24};
25use common_query::prelude::AccumulatorCreatorFunction;
26use common_time::Timestamp;
27use datafusion_expr::AggregateUDF;
28use datatypes::prelude::ConcreteDataType;
29use datatypes::value::{ListValue, Value};
30use datatypes::vectors::VectorRef;
31use snafu::{ensure, ResultExt};
32
33use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n};
34
/// Accumulator of lat, lng, timestamp tuples
///
/// Every input row appends one entry to each of the three buffers, so the
/// buffers are meant to stay index-aligned: entry `i` of `lat`, `lng` and
/// `timestamp` together describe one point of the path.
#[derive(Debug)]
pub struct JsonPathAccumulator {
    /// Data type of the timestamp input; used as the item type of the
    /// timestamp list produced by `state()`.
    timestamp_type: ConcreteDataType,
    /// Buffered latitudes; an entry is `None` when `as_f64_lossy` yields no
    /// value for the input.
    lat: Vec<Option<f64>>,
    /// Buffered longitudes; an entry is `None` when `as_f64_lossy` yields no
    /// value for the input.
    lng: Vec<Option<f64>>,
    /// Buffered timestamps; an entry is `None` when `as_timestamp` yields no
    /// value for the input.
    timestamp: Vec<Option<Timestamp>>,
}
43
44impl JsonPathAccumulator {
45    fn new(timestamp_type: ConcreteDataType) -> Self {
46        Self {
47            lat: Vec::default(),
48            lng: Vec::default(),
49            timestamp: Vec::default(),
50            timestamp_type,
51        }
52    }
53
54    /// Create a new `AggregateUDF` for the `json_encode_path` aggregate function.
55    pub fn uadf_impl() -> AggregateUDF {
56        create_aggregate_function(
57            "json_encode_path".to_string(),
58            3,
59            Arc::new(JsonPathEncodeFunctionCreator::default()),
60        )
61        .into()
62    }
63}
64
65impl Accumulator for JsonPathAccumulator {
66    fn state(&self) -> Result<Vec<Value>> {
67        Ok(vec![
68            Value::List(ListValue::new(
69                self.lat.iter().map(|i| Value::from(*i)).collect(),
70                ConcreteDataType::float64_datatype(),
71            )),
72            Value::List(ListValue::new(
73                self.lng.iter().map(|i| Value::from(*i)).collect(),
74                ConcreteDataType::float64_datatype(),
75            )),
76            Value::List(ListValue::new(
77                self.timestamp.iter().map(|i| Value::from(*i)).collect(),
78                self.timestamp_type.clone(),
79            )),
80        ])
81    }
82
83    fn update_batch(&mut self, columns: &[VectorRef]) -> Result<()> {
84        // update batch as in datafusion just provides the accumulator original
85        //  input.
86        //
87        // columns is vec of [`lat`, `lng`, `timestamp`]
88        // where
89        // - `lat` is a vector of `Value::Float64` or similar type. Each item in
90        //  the vector is a row in given dataset.
91        // - so on so forth for `lng` and `timestamp`
92        ensure_columns_n!(columns, 3);
93
94        let lat = &columns[0];
95        let lng = &columns[1];
96        let ts = &columns[2];
97
98        let size = lat.len();
99
100        for idx in 0..size {
101            self.lat.push(lat.get(idx).as_f64_lossy());
102            self.lng.push(lng.get(idx).as_f64_lossy());
103            self.timestamp.push(ts.get(idx).as_timestamp());
104        }
105
106        Ok(())
107    }
108
109    fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> {
110        // merge batch as in datafusion gives state accumulated from the data
111        //  returned from child accumulators' state() call
112        // In our particular implementation, the data structure is like
113        //
114        // states is vec of [`lat`, `lng`, `timestamp`]
115        // where
116        // - `lat` is a vector of `Value::List`. Each item in the list is all
117        //  coordinates from a child accumulator.
118        // - so on so forth for `lng` and `timestamp`
119
120        ensure_columns_n!(states, 3);
121
122        let lat_lists = &states[0];
123        let lng_lists = &states[1];
124        let ts_lists = &states[2];
125
126        let len = lat_lists.len();
127
128        for idx in 0..len {
129            if let Some(lat_list) = lat_lists
130                .get(idx)
131                .as_list()
132                .map_err(BoxedError::new)
133                .context(error::ExecuteSnafu)?
134            {
135                for v in lat_list.items() {
136                    self.lat.push(v.as_f64_lossy());
137                }
138            }
139
140            if let Some(lng_list) = lng_lists
141                .get(idx)
142                .as_list()
143                .map_err(BoxedError::new)
144                .context(error::ExecuteSnafu)?
145            {
146                for v in lng_list.items() {
147                    self.lng.push(v.as_f64_lossy());
148                }
149            }
150
151            if let Some(ts_list) = ts_lists
152                .get(idx)
153                .as_list()
154                .map_err(BoxedError::new)
155                .context(error::ExecuteSnafu)?
156            {
157                for v in ts_list.items() {
158                    self.timestamp.push(v.as_timestamp());
159                }
160            }
161        }
162
163        Ok(())
164    }
165
166    fn evaluate(&self) -> Result<Value> {
167        let mut work_vec: Vec<(&Option<f64>, &Option<f64>, &Option<Timestamp>)> = self
168            .lat
169            .iter()
170            .zip(self.lng.iter())
171            .zip(self.timestamp.iter())
172            .map(|((a, b), c)| (a, b, c))
173            .collect();
174
175        // sort by timestamp, we treat null timestamp as 0
176        work_vec.sort_unstable_by_key(|tuple| tuple.2.unwrap_or_else(|| Timestamp::new_second(0)));
177
178        let result = serde_json::to_string(
179            &work_vec
180                .into_iter()
181                // note that we transform to lng,lat for geojson compatibility
182                .map(|(lat, lng, _)| vec![lng, lat])
183                .collect::<Vec<Vec<&Option<f64>>>>(),
184        )
185        .map_err(|e| {
186            BoxedError::new(PlainError::new(
187                format!("Serialization failure: {}", e),
188                StatusCode::EngineExecuteQuery,
189            ))
190        })
191        .context(error::ExecuteSnafu)?;
192
193        Ok(Value::String(result.into()))
194    }
195}
196
/// Creator of the `json_encode_path` aggregate function.
///
/// The function accepts rows of latitude, longitude and timestamp, sorts them
/// by timestamp and encodes them into a GeoJSON-like path: a JSON array of
/// `[lng, lat]` pairs.
///
/// Example:
///
/// ```sql
/// SELECT json_encode_path(lat, lon, timestamp) FROM table [group by ...];
/// ```
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct JsonPathEncodeFunctionCreator {}
209
210impl AggregateFunctionCreator for JsonPathEncodeFunctionCreator {
211    fn creator(&self) -> AccumulatorCreatorFunction {
212        let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
213            let ts_type = types[2].clone();
214            Ok(Box::new(JsonPathAccumulator::new(ts_type)))
215        });
216
217        creator
218    }
219
220    fn output_type(&self) -> Result<ConcreteDataType> {
221        Ok(ConcreteDataType::string_datatype())
222    }
223
224    fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
225        let input_types = self.input_types()?;
226        ensure!(input_types.len() == 3, InvalidInputStateSnafu);
227
228        let timestamp_type = input_types[2].clone();
229
230        Ok(vec![
231            ConcreteDataType::list_datatype(ConcreteDataType::float64_datatype()),
232            ConcreteDataType::list_datatype(ConcreteDataType::float64_datatype()),
233            ConcreteDataType::list_datatype(timestamp_type),
234        ])
235    }
236}