common_function/scalars/geo/
relation.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16use std::sync::Arc;
17
18use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
19use datafusion_common::arrow::compute;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
22use derive_more::Display;
23use geo::algorithm::contains::Contains;
24use geo::algorithm::intersects::Intersects;
25use geo::algorithm::within::Within;
26use geo_types::Geometry;
27
28use crate::function::{Function, extract_args};
29use crate::scalars::geo::wkt::parse_wkt;
30
31/// Test if spatial relationship: contains
32#[derive(Clone, Debug, Display)]
33#[display("{}", self.name())]
34pub(crate) struct STContains {
35    signature: Signature,
36}
37
38impl Default for STContains {
39    fn default() -> Self {
40        Self {
41            signature: Signature::string(2, Volatility::Stable),
42        }
43    }
44}
45
46impl StFunction for STContains {
47    const NAME: &'static str = "st_contains";
48
49    fn signature(&self) -> &Signature {
50        &self.signature
51    }
52
53    fn invoke(g1: Geometry, g2: Geometry) -> bool {
54        g1.contains(&g2)
55    }
56}
57
58/// Test if spatial relationship: within
59#[derive(Clone, Debug, Display)]
60#[display("{}", self.name())]
61pub(crate) struct STWithin {
62    signature: Signature,
63}
64
65impl Default for STWithin {
66    fn default() -> Self {
67        Self {
68            signature: Signature::string(2, Volatility::Stable),
69        }
70    }
71}
72
73impl StFunction for STWithin {
74    const NAME: &'static str = "st_within";
75
76    fn signature(&self) -> &Signature {
77        &self.signature
78    }
79
80    fn invoke(g1: Geometry, g2: Geometry) -> bool {
81        g1.is_within(&g2)
82    }
83}
84
85/// Test if spatial relationship: within
86#[derive(Clone, Debug, Display)]
87#[display("{}", self.name())]
88pub(crate) struct STIntersects {
89    signature: Signature,
90}
91
92impl Default for STIntersects {
93    fn default() -> Self {
94        Self {
95            signature: Signature::string(2, Volatility::Stable),
96        }
97    }
98}
99
100impl StFunction for STIntersects {
101    const NAME: &'static str = "st_intersects";
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn invoke(g1: Geometry, g2: Geometry) -> bool {
108        g1.intersects(&g2)
109    }
110}
111
112trait StFunction {
113    const NAME: &'static str;
114
115    fn signature(&self) -> &Signature;
116
117    fn invoke(g1: Geometry, g2: Geometry) -> bool;
118}
119
120impl<T: StFunction + Display + Send + Sync> Function for T {
121    fn name(&self) -> &str {
122        T::NAME
123    }
124
125    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
126        Ok(DataType::Boolean)
127    }
128
129    fn signature(&self) -> &Signature {
130        self.signature()
131    }
132
133    fn invoke_with_args(
134        &self,
135        args: ScalarFunctionArgs,
136    ) -> datafusion_common::Result<ColumnarValue> {
137        let [arg0, arg1] = extract_args(self.name(), &args)?;
138
139        let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
140        let wkt_this_vec = arg0.as_string_view();
141        let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
142        let wkt_that_vec = arg1.as_string_view();
143
144        let size = wkt_this_vec.len();
145        let mut builder = BooleanBuilder::with_capacity(size);
146
147        for i in 0..size {
148            let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
149            let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
150
151            let result = match (wkt_this, wkt_that) {
152                (Some(wkt_this), Some(wkt_that)) => {
153                    Some(T::invoke(parse_wkt(wkt_this)?, parse_wkt(wkt_that)?))
154                }
155                _ => None,
156            };
157
158            builder.append_option(result);
159        }
160
161        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
162    }
163}