common_function/scalars/geo/
geohash.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16use std::sync::Arc;
17
18use common_error::ext::{BoxedError, PlainError};
19use common_error::status_code::StatusCode;
20use common_query::error;
21use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder};
22use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type};
23use datafusion::logical_expr::ColumnarValue;
24use datafusion_common::DataFusionError;
25use datafusion_expr::type_coercion::aggregates::INTEGERS;
26use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
27use geohash::Coord;
28use snafu::ResultExt;
29
30use crate::function::{Function, extract_args};
31use crate::scalars::geo::helpers;
32
33fn ensure_resolution_usize(v: u8) -> datafusion_common::Result<usize> {
34    if v == 0 || v > 12 {
35        return Err(DataFusionError::Execution(format!(
36            "Invalid geohash resolution {v}, valid value range: [1, 12]"
37        )));
38    }
39    Ok(v as usize)
40}
41
42/// Function that return geohash string for a given geospatial coordinate.
43#[derive(Clone, Debug)]
44pub(crate) struct GeohashFunction {
45    signature: Signature,
46}
47
48impl Default for GeohashFunction {
49    fn default() -> Self {
50        let mut signatures = Vec::new();
51        for coord_type in &[DataType::Float32, DataType::Float64] {
52            for resolution_type in INTEGERS {
53                signatures.push(TypeSignature::Exact(vec![
54                    // latitude
55                    coord_type.clone(),
56                    // longitude
57                    coord_type.clone(),
58                    // resolution
59                    resolution_type.clone(),
60                ]));
61            }
62        }
63        Self {
64            signature: Signature::one_of(signatures, Volatility::Stable),
65        }
66    }
67}
68
69impl GeohashFunction {
70    const NAME: &'static str = "geohash";
71}
72
73impl Function for GeohashFunction {
74    fn name(&self) -> &str {
75        Self::NAME
76    }
77
78    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
79        Ok(DataType::Utf8)
80    }
81
82    fn signature(&self) -> &Signature {
83        &self.signature
84    }
85
86    fn invoke_with_args(
87        &self,
88        args: ScalarFunctionArgs,
89    ) -> datafusion_common::Result<ColumnarValue> {
90        let [lat_vec, lon_vec, resolutions] = extract_args(self.name(), &args)?;
91
92        let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
93        let lat_vec = lat_vec.as_primitive::<Float64Type>();
94        let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
95        let lon_vec = lon_vec.as_primitive::<Float64Type>();
96        let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
97        let resolutions = resolutions.as_primitive::<UInt8Type>();
98
99        let size = lat_vec.len();
100        let mut builder = StringViewBuilder::with_capacity(size);
101
102        for i in 0..size {
103            let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
104            let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
105            let r = resolutions
106                .is_valid(i)
107                .then(|| ensure_resolution_usize(resolutions.value(i)))
108                .transpose()?;
109
110            let result = match (lat, lon, r) {
111                (Some(lat), Some(lon), Some(r)) => {
112                    let coord = Coord { x: lon, y: lat };
113                    let encoded = geohash::encode(coord, r)
114                        .map_err(|e| {
115                            BoxedError::new(PlainError::new(
116                                format!("Geohash error: {}", e),
117                                StatusCode::EngineExecuteQuery,
118                            ))
119                        })
120                        .context(error::ExecuteSnafu)?;
121                    Some(encoded)
122                }
123                _ => None,
124            };
125
126            builder.append_option(result);
127        }
128
129        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
130    }
131}
132
133impl fmt::Display for GeohashFunction {
134    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
135        write!(f, "{}", Self::NAME)
136    }
137}
138
139/// Function that return geohash string for a given geospatial coordinate.
140#[derive(Clone, Debug)]
141pub(crate) struct GeohashNeighboursFunction {
142    signature: Signature,
143}
144
145impl Default for GeohashNeighboursFunction {
146    fn default() -> Self {
147        let mut signatures = Vec::new();
148        for coord_type in &[DataType::Float32, DataType::Float64] {
149            for resolution_type in INTEGERS {
150                signatures.push(TypeSignature::Exact(vec![
151                    // latitude
152                    coord_type.clone(),
153                    // longitude
154                    coord_type.clone(),
155                    // resolution
156                    resolution_type.clone(),
157                ]));
158            }
159        }
160        Self {
161            signature: Signature::one_of(signatures, Volatility::Stable),
162        }
163    }
164}
165
166impl GeohashNeighboursFunction {
167    const NAME: &'static str = "geohash_neighbours";
168}
169
170impl Function for GeohashNeighboursFunction {
171    fn name(&self) -> &str {
172        GeohashNeighboursFunction::NAME
173    }
174
175    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
176        Ok(DataType::List(Arc::new(Field::new(
177            "item",
178            DataType::Utf8View,
179            false,
180        ))))
181    }
182
183    fn signature(&self) -> &Signature {
184        &self.signature
185    }
186
187    fn invoke_with_args(
188        &self,
189        args: ScalarFunctionArgs,
190    ) -> datafusion_common::Result<ColumnarValue> {
191        let [lat_vec, lon_vec, resolutions] = extract_args(self.name(), &args)?;
192
193        let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
194        let lat_vec = lat_vec.as_primitive::<Float64Type>();
195        let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
196        let lon_vec = lon_vec.as_primitive::<Float64Type>();
197        let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
198        let resolutions = resolutions.as_primitive::<UInt8Type>();
199
200        let size = lat_vec.len();
201        let mut builder = ListBuilder::new(StringViewBuilder::new());
202
203        for i in 0..size {
204            let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
205            let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
206            let r = resolutions
207                .is_valid(i)
208                .then(|| ensure_resolution_usize(resolutions.value(i)))
209                .transpose()?;
210
211            match (lat, lon, r) {
212                (Some(lat), Some(lon), Some(r)) => {
213                    let coord = Coord { x: lon, y: lat };
214                    let encoded = geohash::encode(coord, r)
215                        .map_err(|e| {
216                            BoxedError::new(PlainError::new(
217                                format!("Geohash error: {}", e),
218                                StatusCode::EngineExecuteQuery,
219                            ))
220                        })
221                        .context(error::ExecuteSnafu)?;
222                    let neighbours = geohash::neighbors(&encoded)
223                        .map_err(|e| {
224                            BoxedError::new(PlainError::new(
225                                format!("Geohash error: {}", e),
226                                StatusCode::EngineExecuteQuery,
227                            ))
228                        })
229                        .context(error::ExecuteSnafu)?;
230                    builder.append_value(
231                        [
232                            neighbours.n,
233                            neighbours.nw,
234                            neighbours.w,
235                            neighbours.sw,
236                            neighbours.s,
237                            neighbours.se,
238                            neighbours.e,
239                            neighbours.ne,
240                        ]
241                        .into_iter()
242                        .map(Some),
243                    );
244                }
245                _ => builder.append_null(),
246            };
247        }
248
249        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
250    }
251}
252
253impl fmt::Display for GeohashNeighboursFunction {
254    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
255        write!(f, "{}", GeohashNeighboursFunction::NAME)
256    }
257}