datatypes/vectors/
eq.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_time::interval::IntervalUnit;
18
19use crate::data_type::DataType;
20use crate::types::{DurationType, TimeType, TimestampType};
21use crate::vectors::constant::ConstantVector;
22use crate::vectors::struct_vector::StructVector;
23use crate::vectors::{
24    BinaryVector, BooleanVector, DateVector, Decimal128Vector, DurationMicrosecondVector,
25    DurationMillisecondVector, DurationNanosecondVector, DurationSecondVector,
26    IntervalDayTimeVector, IntervalMonthDayNanoVector, IntervalYearMonthVector, ListVector,
27    PrimitiveVector, StringVector, TimeMicrosecondVector, TimeMillisecondVector,
28    TimeNanosecondVector, TimeSecondVector, TimestampMicrosecondVector, TimestampMillisecondVector,
29    TimestampNanosecondVector, TimestampSecondVector, Vector,
30};
31use crate::with_match_primitive_type_id;
32
33impl Eq for dyn Vector + '_ {}
34
35impl PartialEq for dyn Vector + '_ {
36    fn eq(&self, other: &dyn Vector) -> bool {
37        equal(self, other)
38    }
39}
40
41impl PartialEq<dyn Vector> for Arc<dyn Vector + '_> {
42    fn eq(&self, other: &dyn Vector) -> bool {
43        equal(&**self, other)
44    }
45}
46
/// Downcasts both `$lhs` and `$rhs` to the concrete `$VectorType` and
/// compares them with that type's own `PartialEq`.
///
/// Panics (via `unwrap`) if either operand is not a `$VectorType`, so the
/// caller must have already established that both sides share that type.
macro_rules! is_vector_eq {
    ($VectorType: ident, $lhs: ident, $rhs: ident) => {{
        let (typed_lhs, typed_rhs) = (
            $lhs.as_any().downcast_ref::<$VectorType>().unwrap(),
            $rhs.as_any().downcast_ref::<$VectorType>().unwrap(),
        );

        typed_lhs == typed_rhs
    }};
}
55
56fn equal(lhs: &dyn Vector, rhs: &dyn Vector) -> bool {
57    if lhs.data_type() != rhs.data_type() || lhs.len() != rhs.len() {
58        return false;
59    }
60
61    if lhs.is_const() || rhs.is_const() {
62        // Length has been checked before, so we only need to compare inner
63        // vector here.
64        return equal(
65            &**lhs
66                .as_any()
67                .downcast_ref::<ConstantVector>()
68                .unwrap()
69                .inner(),
70            &**lhs
71                .as_any()
72                .downcast_ref::<ConstantVector>()
73                .unwrap()
74                .inner(),
75        );
76    }
77
78    use crate::data_type::ConcreteDataType::*;
79
80    let lhs_type = lhs.data_type();
81    match lhs.data_type() {
82        Null(_) => true,
83        Boolean(_) => is_vector_eq!(BooleanVector, lhs, rhs),
84        Binary(_) | Json(_) | Vector(_) => is_vector_eq!(BinaryVector, lhs, rhs),
85        String(_) => is_vector_eq!(StringVector, lhs, rhs),
86        Date(_) => is_vector_eq!(DateVector, lhs, rhs),
87        Timestamp(t) => match t {
88            TimestampType::Second(_) => {
89                is_vector_eq!(TimestampSecondVector, lhs, rhs)
90            }
91            TimestampType::Millisecond(_) => {
92                is_vector_eq!(TimestampMillisecondVector, lhs, rhs)
93            }
94            TimestampType::Microsecond(_) => {
95                is_vector_eq!(TimestampMicrosecondVector, lhs, rhs)
96            }
97            TimestampType::Nanosecond(_) => {
98                is_vector_eq!(TimestampNanosecondVector, lhs, rhs)
99            }
100        },
101        Interval(v) => match v.unit() {
102            IntervalUnit::YearMonth => {
103                is_vector_eq!(IntervalYearMonthVector, lhs, rhs)
104            }
105            IntervalUnit::DayTime => {
106                is_vector_eq!(IntervalDayTimeVector, lhs, rhs)
107            }
108            IntervalUnit::MonthDayNano => {
109                is_vector_eq!(IntervalMonthDayNanoVector, lhs, rhs)
110            }
111        },
112        List(_) => is_vector_eq!(ListVector, lhs, rhs),
113        Struct(_) => is_vector_eq!(StructVector, lhs, rhs),
114        UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) | Int32(_) | Int64(_)
115        | Float32(_) | Float64(_) | Dictionary(_) => {
116            with_match_primitive_type_id!(lhs_type.logical_type_id(), |$T| {
117                let lhs = lhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
118                let rhs = rhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
119
120                lhs == rhs
121            },
122            {
123                unreachable!("should not compare {} with {}", lhs.vector_type_name(), rhs.vector_type_name())
124            })
125        }
126
127        Time(t) => match t {
128            TimeType::Second(_) => {
129                is_vector_eq!(TimeSecondVector, lhs, rhs)
130            }
131            TimeType::Millisecond(_) => {
132                is_vector_eq!(TimeMillisecondVector, lhs, rhs)
133            }
134            TimeType::Microsecond(_) => {
135                is_vector_eq!(TimeMicrosecondVector, lhs, rhs)
136            }
137            TimeType::Nanosecond(_) => {
138                is_vector_eq!(TimeNanosecondVector, lhs, rhs)
139            }
140        },
141        Duration(d) => match d {
142            DurationType::Second(_) => {
143                is_vector_eq!(DurationSecondVector, lhs, rhs)
144            }
145            DurationType::Millisecond(_) => {
146                is_vector_eq!(DurationMillisecondVector, lhs, rhs)
147            }
148            DurationType::Microsecond(_) => {
149                is_vector_eq!(DurationMicrosecondVector, lhs, rhs)
150            }
151            DurationType::Nanosecond(_) => {
152                is_vector_eq!(DurationNanosecondVector, lhs, rhs)
153            }
154        },
155        Decimal128(_) => {
156            is_vector_eq!(Decimal128Vector, lhs, rhs)
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};

    use super::*;
    use crate::vectors::{
        list, DurationMicrosecondVector, DurationMillisecondVector, DurationNanosecondVector,
        DurationSecondVector, Float32Vector, Float64Vector, Int16Vector, Int32Vector, Int64Vector,
        Int8Vector, NullVector, UInt16Vector, UInt32Vector, UInt64Vector, UInt8Vector, VectorRef,
    };

    /// Erases a concrete vector into a shared `VectorRef`.
    fn to_ref(vector: impl Vector + 'static) -> VectorRef {
        Arc::new(vector)
    }

    /// Builds a constant boolean vector of `len` rows, all equal to `value`.
    fn const_bool(value: bool, len: usize) -> VectorRef {
        to_ref(ConstantVector::new(
            Arc::new(BooleanVector::from(vec![value])),
            len,
        ))
    }

    /// Asserts that `vector` equals a clone of itself, both as a `VectorRef`
    /// and as a plain `&dyn Vector`.
    fn assert_equals_itself(vector: VectorRef) {
        let same = vector.clone();
        assert_eq!(vector, same);

        let (lhs, rhs): (&dyn Vector, &dyn Vector) = (&*vector, &*same);
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn test_vector_eq() {
        let list_vector = list::tests::new_list_vector(&[
            Some(vec![Some(1), Some(2)]),
            None,
            Some(vec![Some(3), Some(4)]),
        ]);

        let vectors = vec![
            to_ref(BinaryVector::from(vec![
                Some(b"hello".to_vec()),
                Some(b"world".to_vec()),
            ])),
            to_ref(BooleanVector::from(vec![true, false])),
            const_bool(true, 5),
            to_ref(BooleanVector::from(vec![true, false])),
            to_ref(DateVector::from(vec![Some(100), Some(120)])),
            to_ref(TimestampSecondVector::from_values([100, 120])),
            to_ref(TimestampMillisecondVector::from_values([100, 120])),
            to_ref(TimestampMicrosecondVector::from_values([100, 120])),
            to_ref(TimestampNanosecondVector::from_values([100, 120])),
            to_ref(list_vector),
            to_ref(NullVector::new(4)),
            to_ref(StringVector::from(vec![Some("hello"), Some("world")])),
            to_ref(Int8Vector::from_slice([1, 2, 3, 4])),
            to_ref(UInt8Vector::from_slice([1, 2, 3, 4])),
            to_ref(Int16Vector::from_slice([1, 2, 3, 4])),
            to_ref(UInt16Vector::from_slice([1, 2, 3, 4])),
            to_ref(Int32Vector::from_slice([1, 2, 3, 4])),
            to_ref(UInt32Vector::from_slice([1, 2, 3, 4])),
            to_ref(Int64Vector::from_slice([1, 2, 3, 4])),
            to_ref(UInt64Vector::from_slice([1, 2, 3, 4])),
            to_ref(Float32Vector::from_slice([1.0, 2.0, 3.0, 4.0])),
            to_ref(Float64Vector::from_slice([1.0, 2.0, 3.0, 4.0])),
            to_ref(TimeSecondVector::from_values([100, 120])),
            to_ref(TimeMillisecondVector::from_values([100, 120])),
            to_ref(TimeMicrosecondVector::from_values([100, 120])),
            to_ref(TimeNanosecondVector::from_values([100, 120])),
            to_ref(IntervalYearMonthVector::from_values([
                1000, 2000, 3000, 4000,
            ])),
            to_ref(IntervalDayTimeVector::from_values([
                IntervalDayTime::new(1, 1000),
                IntervalDayTime::new(1, 2000),
                IntervalDayTime::new(1, 3000),
                IntervalDayTime::new(1, 4000),
            ])),
            to_ref(IntervalMonthDayNanoVector::from_values([
                IntervalMonthDayNano::new(1, 1, 1000),
                IntervalMonthDayNano::new(1, 1, 2000),
                IntervalMonthDayNano::new(1, 1, 3000),
                IntervalMonthDayNano::new(1, 1, 4000),
            ])),
            to_ref(DurationSecondVector::from_values([300, 310])),
            to_ref(DurationMillisecondVector::from_values([300, 310])),
            to_ref(DurationMicrosecondVector::from_values([300, 310])),
            to_ref(DurationNanosecondVector::from_values([300, 310])),
            to_ref(Decimal128Vector::from_values(vec![1i128, 2i128, 3i128])),
        ];

        for vector in vectors {
            assert_equals_itself(vector);
        }
    }

    #[test]
    fn test_vector_ne() {
        let cases: Vec<(VectorRef, VectorRef)> = vec![
            (
                to_ref(Int32Vector::from_slice([1, 2, 3, 4])),
                to_ref(Int32Vector::from_slice([1, 2])),
            ),
            (
                to_ref(Int32Vector::from_slice([1, 2, 3, 4])),
                to_ref(Int8Vector::from_slice([1, 2, 3, 4])),
            ),
            (
                to_ref(Int32Vector::from_slice([1, 2, 3, 4])),
                to_ref(BooleanVector::from(vec![true, true])),
            ),
            (const_bool(true, 5), const_bool(true, 4)),
            (const_bool(true, 5), const_bool(false, 4)),
            (
                const_bool(true, 5),
                to_ref(ConstantVector::new(
                    Arc::new(Int32Vector::from_slice(vec![1])),
                    4,
                )),
            ),
            (to_ref(NullVector::new(5)), to_ref(NullVector::new(8))),
            (
                to_ref(TimeMicrosecondVector::from_values([100, 120])),
                to_ref(TimeMicrosecondVector::from_values([200, 220])),
            ),
            (
                to_ref(IntervalDayTimeVector::from_values([
                    IntervalDayTime::new(1, 1000),
                    IntervalDayTime::new(1, 2000),
                ])),
                to_ref(IntervalDayTimeVector::from_values([
                    IntervalDayTime::new(1, 2100),
                    IntervalDayTime::new(1, 1200),
                ])),
            ),
            (
                to_ref(IntervalMonthDayNanoVector::from_values([
                    IntervalMonthDayNano::new(1, 1, 1000),
                    IntervalMonthDayNano::new(1, 1, 2000),
                ])),
                to_ref(IntervalMonthDayNanoVector::from_values([
                    IntervalMonthDayNano::new(1, 1, 2100),
                    IntervalMonthDayNano::new(1, 1, 1200),
                ])),
            ),
            (
                to_ref(IntervalYearMonthVector::from_values([1000, 2000])),
                to_ref(IntervalYearMonthVector::from_values([2100, 1200])),
            ),
            (
                to_ref(DurationSecondVector::from_values([300, 310])),
                to_ref(DurationSecondVector::from_values([300, 320])),
            ),
            (
                to_ref(Decimal128Vector::from_values([300i128, 310i128])),
                to_ref(Decimal128Vector::from_values([300i128, 320i128])),
            ),
        ];

        for (lhs, rhs) in cases {
            assert_ne!(lhs, rhs);
        }
    }
}