common_function/scalars/geo/
relation.rs1use std::sync::Arc;
16
17use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
18use datafusion_common::arrow::compute;
19use datafusion_common::arrow::datatypes::DataType;
20use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
21use derive_more::Display;
22use geo::algorithm::contains::Contains;
23use geo::algorithm::intersects::Intersects;
24use geo::algorithm::within::Within;
25use geo_types::Geometry;
26
27use crate::function::{Function, extract_args};
28use crate::scalars::geo::wkt::parse_wkt;
29
30#[derive(Clone, Debug, Display)]
32#[display("{}", self.name())]
33pub(crate) struct STContains {
34 signature: Signature,
35}
36
37impl Default for STContains {
38 fn default() -> Self {
39 Self {
40 signature: Signature::string(2, Volatility::Stable),
41 }
42 }
43}
44
45impl StFunction for STContains {
46 const NAME: &'static str = "st_contains";
47
48 fn signature(&self) -> &Signature {
49 &self.signature
50 }
51
52 fn invoke(g1: Geometry, g2: Geometry) -> bool {
53 g1.contains(&g2)
54 }
55}
56
57#[derive(Clone, Debug, Display)]
59#[display("{}", self.name())]
60pub(crate) struct STWithin {
61 signature: Signature,
62}
63
64impl Default for STWithin {
65 fn default() -> Self {
66 Self {
67 signature: Signature::string(2, Volatility::Stable),
68 }
69 }
70}
71
72impl StFunction for STWithin {
73 const NAME: &'static str = "st_within";
74
75 fn signature(&self) -> &Signature {
76 &self.signature
77 }
78
79 fn invoke(g1: Geometry, g2: Geometry) -> bool {
80 g1.is_within(&g2)
81 }
82}
83
84#[derive(Clone, Debug, Display)]
86#[display("{}", self.name())]
87pub(crate) struct STIntersects {
88 signature: Signature,
89}
90
91impl Default for STIntersects {
92 fn default() -> Self {
93 Self {
94 signature: Signature::string(2, Volatility::Stable),
95 }
96 }
97}
98
99impl StFunction for STIntersects {
100 const NAME: &'static str = "st_intersects";
101
102 fn signature(&self) -> &Signature {
103 &self.signature
104 }
105
106 fn invoke(g1: Geometry, g2: Geometry) -> bool {
107 g1.intersects(&g2)
108 }
109}
110
111trait StFunction {
112 const NAME: &'static str;
113
114 fn signature(&self) -> &Signature;
115
116 fn invoke(g1: Geometry, g2: Geometry) -> bool;
117}
118
119impl<T: StFunction + Display + Send + Sync> Function for T {
120 fn name(&self) -> &str {
121 T::NAME
122 }
123
124 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
125 Ok(DataType::Boolean)
126 }
127
128 fn signature(&self) -> &Signature {
129 self.signature()
130 }
131
132 fn invoke_with_args(
133 &self,
134 args: ScalarFunctionArgs,
135 ) -> datafusion_common::Result<ColumnarValue> {
136 let [arg0, arg1] = extract_args(self.name(), &args)?;
137
138 let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
139 let wkt_this_vec = arg0.as_string_view();
140 let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
141 let wkt_that_vec = arg1.as_string_view();
142
143 let size = wkt_this_vec.len();
144 let mut builder = BooleanBuilder::with_capacity(size);
145
146 for i in 0..size {
147 let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
148 let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
149
150 let result = match (wkt_this, wkt_that) {
151 (Some(wkt_this), Some(wkt_that)) => {
152 Some(T::invoke(parse_wkt(wkt_this)?, parse_wkt(wkt_that)?))
153 }
154 _ => None,
155 };
156
157 builder.append_option(result);
158 }
159
160 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
161 }
162}