common_function/scalars/geo/
relation.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
18use datafusion_common::arrow::compute;
19use datafusion_common::arrow::datatypes::DataType;
20use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
21use derive_more::Display;
22use geo::algorithm::contains::Contains;
23use geo::algorithm::intersects::Intersects;
24use geo::algorithm::within::Within;
25use geo_types::Geometry;
26
27use crate::function::{Function, extract_args};
28use crate::scalars::geo::wkt::parse_wkt;
29
30/// Test if spatial relationship: contains
31#[derive(Clone, Debug, Display)]
32#[display("{}", self.name())]
33pub(crate) struct STContains {
34    signature: Signature,
35}
36
37impl Default for STContains {
38    fn default() -> Self {
39        Self {
40            signature: Signature::string(2, Volatility::Stable),
41        }
42    }
43}
44
45impl StFunction for STContains {
46    const NAME: &'static str = "st_contains";
47
48    fn signature(&self) -> &Signature {
49        &self.signature
50    }
51
52    fn invoke(g1: Geometry, g2: Geometry) -> bool {
53        g1.contains(&g2)
54    }
55}
56
57/// Test if spatial relationship: within
58#[derive(Clone, Debug, Display)]
59#[display("{}", self.name())]
60pub(crate) struct STWithin {
61    signature: Signature,
62}
63
64impl Default for STWithin {
65    fn default() -> Self {
66        Self {
67            signature: Signature::string(2, Volatility::Stable),
68        }
69    }
70}
71
72impl StFunction for STWithin {
73    const NAME: &'static str = "st_within";
74
75    fn signature(&self) -> &Signature {
76        &self.signature
77    }
78
79    fn invoke(g1: Geometry, g2: Geometry) -> bool {
80        g1.is_within(&g2)
81    }
82}
83
84/// Test if spatial relationship: within
85#[derive(Clone, Debug, Display)]
86#[display("{}", self.name())]
87pub(crate) struct STIntersects {
88    signature: Signature,
89}
90
91impl Default for STIntersects {
92    fn default() -> Self {
93        Self {
94            signature: Signature::string(2, Volatility::Stable),
95        }
96    }
97}
98
99impl StFunction for STIntersects {
100    const NAME: &'static str = "st_intersects";
101
102    fn signature(&self) -> &Signature {
103        &self.signature
104    }
105
106    fn invoke(g1: Geometry, g2: Geometry) -> bool {
107        g1.intersects(&g2)
108    }
109}
110
111trait StFunction {
112    const NAME: &'static str;
113
114    fn signature(&self) -> &Signature;
115
116    fn invoke(g1: Geometry, g2: Geometry) -> bool;
117}
118
119impl<T: StFunction + Display + Send + Sync> Function for T {
120    fn name(&self) -> &str {
121        T::NAME
122    }
123
124    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
125        Ok(DataType::Boolean)
126    }
127
128    fn signature(&self) -> &Signature {
129        self.signature()
130    }
131
132    fn invoke_with_args(
133        &self,
134        args: ScalarFunctionArgs,
135    ) -> datafusion_common::Result<ColumnarValue> {
136        let [arg0, arg1] = extract_args(self.name(), &args)?;
137
138        let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
139        let wkt_this_vec = arg0.as_string_view();
140        let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
141        let wkt_that_vec = arg1.as_string_view();
142
143        let size = wkt_this_vec.len();
144        let mut builder = BooleanBuilder::with_capacity(size);
145
146        for i in 0..size {
147            let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
148            let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
149
150            let result = match (wkt_this, wkt_that) {
151                (Some(wkt_this), Some(wkt_that)) => {
152                    Some(T::invoke(parse_wkt(wkt_this)?, parse_wkt(wkt_that)?))
153                }
154                _ => None,
155            };
156
157            builder.append_option(result);
158        }
159
160        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
161    }
162}