common_function/scalars/geo/
relation.rs1use std::fmt::Display;
16use std::sync::Arc;
17
18use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
19use datafusion_common::arrow::compute;
20use datafusion_common::arrow::datatypes::DataType;
21use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
22use derive_more::Display;
23use geo::algorithm::contains::Contains;
24use geo::algorithm::intersects::Intersects;
25use geo::algorithm::within::Within;
26use geo_types::Geometry;
27
28use crate::function::{Function, extract_args};
29use crate::scalars::geo::wkt::parse_wkt;
30
31#[derive(Clone, Debug, Display)]
33#[display("{}", self.name())]
34pub(crate) struct STContains {
35 signature: Signature,
36}
37
38impl Default for STContains {
39 fn default() -> Self {
40 Self {
41 signature: Signature::string(2, Volatility::Stable),
42 }
43 }
44}
45
46impl StFunction for STContains {
47 const NAME: &'static str = "st_contains";
48
49 fn signature(&self) -> &Signature {
50 &self.signature
51 }
52
53 fn invoke(g1: Geometry, g2: Geometry) -> bool {
54 g1.contains(&g2)
55 }
56}
57
58#[derive(Clone, Debug, Display)]
60#[display("{}", self.name())]
61pub(crate) struct STWithin {
62 signature: Signature,
63}
64
65impl Default for STWithin {
66 fn default() -> Self {
67 Self {
68 signature: Signature::string(2, Volatility::Stable),
69 }
70 }
71}
72
73impl StFunction for STWithin {
74 const NAME: &'static str = "st_within";
75
76 fn signature(&self) -> &Signature {
77 &self.signature
78 }
79
80 fn invoke(g1: Geometry, g2: Geometry) -> bool {
81 g1.is_within(&g2)
82 }
83}
84
85#[derive(Clone, Debug, Display)]
87#[display("{}", self.name())]
88pub(crate) struct STIntersects {
89 signature: Signature,
90}
91
92impl Default for STIntersects {
93 fn default() -> Self {
94 Self {
95 signature: Signature::string(2, Volatility::Stable),
96 }
97 }
98}
99
100impl StFunction for STIntersects {
101 const NAME: &'static str = "st_intersects";
102
103 fn signature(&self) -> &Signature {
104 &self.signature
105 }
106
107 fn invoke(g1: Geometry, g2: Geometry) -> bool {
108 g1.intersects(&g2)
109 }
110}
111
112trait StFunction {
113 const NAME: &'static str;
114
115 fn signature(&self) -> &Signature;
116
117 fn invoke(g1: Geometry, g2: Geometry) -> bool;
118}
119
120impl<T: StFunction + Display + Send + Sync> Function for T {
121 fn name(&self) -> &str {
122 T::NAME
123 }
124
125 fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
126 Ok(DataType::Boolean)
127 }
128
129 fn signature(&self) -> &Signature {
130 self.signature()
131 }
132
133 fn invoke_with_args(
134 &self,
135 args: ScalarFunctionArgs,
136 ) -> datafusion_common::Result<ColumnarValue> {
137 let [arg0, arg1] = extract_args(self.name(), &args)?;
138
139 let arg0 = compute::cast(&arg0, &DataType::Utf8View)?;
140 let wkt_this_vec = arg0.as_string_view();
141 let arg1 = compute::cast(&arg1, &DataType::Utf8View)?;
142 let wkt_that_vec = arg1.as_string_view();
143
144 let size = wkt_this_vec.len();
145 let mut builder = BooleanBuilder::with_capacity(size);
146
147 for i in 0..size {
148 let wkt_this = wkt_this_vec.is_valid(i).then(|| wkt_this_vec.value(i));
149 let wkt_that = wkt_that_vec.is_valid(i).then(|| wkt_that_vec.value(i));
150
151 let result = match (wkt_this, wkt_that) {
152 (Some(wkt_this), Some(wkt_that)) => {
153 Some(T::invoke(parse_wkt(wkt_this)?, parse_wkt(wkt_that)?))
154 }
155 _ => None,
156 };
157
158 builder.append_option(result);
159 }
160
161 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
162 }
163}