datatypes/vectors/
eq.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_time::interval::IntervalUnit;
18
19use crate::data_type::DataType;
20use crate::types::{DurationType, TimeType, TimestampType};
21use crate::vectors::constant::ConstantVector;
22use crate::vectors::{
23    BinaryVector, BooleanVector, DateVector, Decimal128Vector, DurationMicrosecondVector,
24    DurationMillisecondVector, DurationNanosecondVector, DurationSecondVector,
25    IntervalDayTimeVector, IntervalMonthDayNanoVector, IntervalYearMonthVector, ListVector,
26    PrimitiveVector, StringVector, TimeMicrosecondVector, TimeMillisecondVector,
27    TimeNanosecondVector, TimeSecondVector, TimestampMicrosecondVector, TimestampMillisecondVector,
28    TimestampNanosecondVector, TimestampSecondVector, Vector,
29};
30use crate::with_match_primitive_type_id;
31
32impl Eq for dyn Vector + '_ {}
33
34impl PartialEq for dyn Vector + '_ {
35    fn eq(&self, other: &dyn Vector) -> bool {
36        equal(self, other)
37    }
38}
39
40impl PartialEq<dyn Vector> for Arc<dyn Vector + '_> {
41    fn eq(&self, other: &dyn Vector) -> bool {
42        equal(&**self, other)
43    }
44}
45
/// Downcasts both `$lhs` and `$rhs` to the concrete `$VectorType` and compares
/// them with that type's `PartialEq`.
///
/// Panics (via `unwrap`) if either side is not a `$VectorType`; `$lhs` is
/// downcast before `$rhs`.
macro_rules! is_vector_eq {
    ($VectorType: ident, $lhs: ident, $rhs: ident) => {{
        let (left, right) = (
            $lhs.as_any().downcast_ref::<$VectorType>().unwrap(),
            $rhs.as_any().downcast_ref::<$VectorType>().unwrap(),
        );

        left == right
    }};
}
54
55fn equal(lhs: &dyn Vector, rhs: &dyn Vector) -> bool {
56    if lhs.data_type() != rhs.data_type() || lhs.len() != rhs.len() {
57        return false;
58    }
59
60    if lhs.is_const() || rhs.is_const() {
61        // Length has been checked before, so we only need to compare inner
62        // vector here.
63        return equal(
64            &**lhs
65                .as_any()
66                .downcast_ref::<ConstantVector>()
67                .unwrap()
68                .inner(),
69            &**lhs
70                .as_any()
71                .downcast_ref::<ConstantVector>()
72                .unwrap()
73                .inner(),
74        );
75    }
76
77    use crate::data_type::ConcreteDataType::*;
78
79    let lhs_type = lhs.data_type();
80    match lhs.data_type() {
81        Null(_) => true,
82        Boolean(_) => is_vector_eq!(BooleanVector, lhs, rhs),
83        Binary(_) | Json(_) | Vector(_) => is_vector_eq!(BinaryVector, lhs, rhs),
84        String(_) => is_vector_eq!(StringVector, lhs, rhs),
85        Date(_) => is_vector_eq!(DateVector, lhs, rhs),
86        Timestamp(t) => match t {
87            TimestampType::Second(_) => {
88                is_vector_eq!(TimestampSecondVector, lhs, rhs)
89            }
90            TimestampType::Millisecond(_) => {
91                is_vector_eq!(TimestampMillisecondVector, lhs, rhs)
92            }
93            TimestampType::Microsecond(_) => {
94                is_vector_eq!(TimestampMicrosecondVector, lhs, rhs)
95            }
96            TimestampType::Nanosecond(_) => {
97                is_vector_eq!(TimestampNanosecondVector, lhs, rhs)
98            }
99        },
100        Interval(v) => match v.unit() {
101            IntervalUnit::YearMonth => {
102                is_vector_eq!(IntervalYearMonthVector, lhs, rhs)
103            }
104            IntervalUnit::DayTime => {
105                is_vector_eq!(IntervalDayTimeVector, lhs, rhs)
106            }
107            IntervalUnit::MonthDayNano => {
108                is_vector_eq!(IntervalMonthDayNanoVector, lhs, rhs)
109            }
110        },
111        List(_) => is_vector_eq!(ListVector, lhs, rhs),
112        UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) | Int32(_) | Int64(_)
113        | Float32(_) | Float64(_) | Dictionary(_) => {
114            with_match_primitive_type_id!(lhs_type.logical_type_id(), |$T| {
115                let lhs = lhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
116                let rhs = rhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
117
118                lhs == rhs
119            },
120            {
121                unreachable!("should not compare {} with {}", lhs.vector_type_name(), rhs.vector_type_name())
122            })
123        }
124
125        Time(t) => match t {
126            TimeType::Second(_) => {
127                is_vector_eq!(TimeSecondVector, lhs, rhs)
128            }
129            TimeType::Millisecond(_) => {
130                is_vector_eq!(TimeMillisecondVector, lhs, rhs)
131            }
132            TimeType::Microsecond(_) => {
133                is_vector_eq!(TimeMicrosecondVector, lhs, rhs)
134            }
135            TimeType::Nanosecond(_) => {
136                is_vector_eq!(TimeNanosecondVector, lhs, rhs)
137            }
138        },
139        Duration(d) => match d {
140            DurationType::Second(_) => {
141                is_vector_eq!(DurationSecondVector, lhs, rhs)
142            }
143            DurationType::Millisecond(_) => {
144                is_vector_eq!(DurationMillisecondVector, lhs, rhs)
145            }
146            DurationType::Microsecond(_) => {
147                is_vector_eq!(DurationMicrosecondVector, lhs, rhs)
148            }
149            DurationType::Nanosecond(_) => {
150                is_vector_eq!(DurationNanosecondVector, lhs, rhs)
151            }
152        },
153        Decimal128(_) => {
154            is_vector_eq!(Decimal128Vector, lhs, rhs)
155        }
156    }
157}
158
159#[cfg(test)]
160mod tests {
161    use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
162
163    use super::*;
164    use crate::vectors::{
165        list, DurationMicrosecondVector, DurationMillisecondVector, DurationNanosecondVector,
166        DurationSecondVector, Float32Vector, Float64Vector, Int16Vector, Int32Vector, Int64Vector,
167        Int8Vector, NullVector, UInt16Vector, UInt32Vector, UInt64Vector, UInt8Vector, VectorRef,
168    };
169
170    fn assert_vector_ref_eq(vector: VectorRef) {
171        let rhs = vector.clone();
172        assert_eq!(vector, rhs);
173        assert_dyn_vector_eq(&*vector, &*rhs);
174    }
175
176    fn assert_dyn_vector_eq(lhs: &dyn Vector, rhs: &dyn Vector) {
177        assert_eq!(lhs, rhs);
178    }
179
180    fn assert_vector_ref_ne(lhs: VectorRef, rhs: VectorRef) {
181        assert_ne!(lhs, rhs);
182    }
183
184    #[test]
185    fn test_vector_eq() {
186        assert_vector_ref_eq(Arc::new(BinaryVector::from(vec![
187            Some(b"hello".to_vec()),
188            Some(b"world".to_vec()),
189        ])));
190        assert_vector_ref_eq(Arc::new(BooleanVector::from(vec![true, false])));
191        assert_vector_ref_eq(Arc::new(ConstantVector::new(
192            Arc::new(BooleanVector::from(vec![true])),
193            5,
194        )));
195        assert_vector_ref_eq(Arc::new(BooleanVector::from(vec![true, false])));
196        assert_vector_ref_eq(Arc::new(DateVector::from(vec![Some(100), Some(120)])));
197        assert_vector_ref_eq(Arc::new(TimestampSecondVector::from_values([100, 120])));
198        assert_vector_ref_eq(Arc::new(TimestampMillisecondVector::from_values([
199            100, 120,
200        ])));
201        assert_vector_ref_eq(Arc::new(TimestampMicrosecondVector::from_values([
202            100, 120,
203        ])));
204        assert_vector_ref_eq(Arc::new(TimestampNanosecondVector::from_values([100, 120])));
205
206        let list_vector = list::tests::new_list_vector(&[
207            Some(vec![Some(1), Some(2)]),
208            None,
209            Some(vec![Some(3), Some(4)]),
210        ]);
211        assert_vector_ref_eq(Arc::new(list_vector));
212
213        assert_vector_ref_eq(Arc::new(NullVector::new(4)));
214        assert_vector_ref_eq(Arc::new(StringVector::from(vec![
215            Some("hello"),
216            Some("world"),
217        ])));
218
219        assert_vector_ref_eq(Arc::new(Int8Vector::from_slice([1, 2, 3, 4])));
220        assert_vector_ref_eq(Arc::new(UInt8Vector::from_slice([1, 2, 3, 4])));
221        assert_vector_ref_eq(Arc::new(Int16Vector::from_slice([1, 2, 3, 4])));
222        assert_vector_ref_eq(Arc::new(UInt16Vector::from_slice([1, 2, 3, 4])));
223        assert_vector_ref_eq(Arc::new(Int32Vector::from_slice([1, 2, 3, 4])));
224        assert_vector_ref_eq(Arc::new(UInt32Vector::from_slice([1, 2, 3, 4])));
225        assert_vector_ref_eq(Arc::new(Int64Vector::from_slice([1, 2, 3, 4])));
226        assert_vector_ref_eq(Arc::new(UInt64Vector::from_slice([1, 2, 3, 4])));
227        assert_vector_ref_eq(Arc::new(Float32Vector::from_slice([1.0, 2.0, 3.0, 4.0])));
228        assert_vector_ref_eq(Arc::new(Float64Vector::from_slice([1.0, 2.0, 3.0, 4.0])));
229
230        assert_vector_ref_eq(Arc::new(TimeSecondVector::from_values([100, 120])));
231        assert_vector_ref_eq(Arc::new(TimeMillisecondVector::from_values([100, 120])));
232        assert_vector_ref_eq(Arc::new(TimeMicrosecondVector::from_values([100, 120])));
233        assert_vector_ref_eq(Arc::new(TimeNanosecondVector::from_values([100, 120])));
234
235        assert_vector_ref_eq(Arc::new(IntervalYearMonthVector::from_values([
236            1000, 2000, 3000, 4000,
237        ])));
238        assert_vector_ref_eq(Arc::new(IntervalDayTimeVector::from_values([
239            IntervalDayTime::new(1, 1000),
240            IntervalDayTime::new(1, 2000),
241            IntervalDayTime::new(1, 3000),
242            IntervalDayTime::new(1, 4000),
243        ])));
244        assert_vector_ref_eq(Arc::new(IntervalMonthDayNanoVector::from_values([
245            IntervalMonthDayNano::new(1, 1, 1000),
246            IntervalMonthDayNano::new(1, 1, 2000),
247            IntervalMonthDayNano::new(1, 1, 3000),
248            IntervalMonthDayNano::new(1, 1, 4000),
249        ])));
250        assert_vector_ref_eq(Arc::new(DurationSecondVector::from_values([300, 310])));
251        assert_vector_ref_eq(Arc::new(DurationMillisecondVector::from_values([300, 310])));
252        assert_vector_ref_eq(Arc::new(DurationMicrosecondVector::from_values([300, 310])));
253        assert_vector_ref_eq(Arc::new(DurationNanosecondVector::from_values([300, 310])));
254        assert_vector_ref_eq(Arc::new(Decimal128Vector::from_values(vec![
255            1i128, 2i128, 3i128,
256        ])));
257    }
258
259    #[test]
260    fn test_vector_ne() {
261        assert_vector_ref_ne(
262            Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
263            Arc::new(Int32Vector::from_slice([1, 2])),
264        );
265        assert_vector_ref_ne(
266            Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
267            Arc::new(Int8Vector::from_slice([1, 2, 3, 4])),
268        );
269        assert_vector_ref_ne(
270            Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
271            Arc::new(BooleanVector::from(vec![true, true])),
272        );
273        assert_vector_ref_ne(
274            Arc::new(ConstantVector::new(
275                Arc::new(BooleanVector::from(vec![true])),
276                5,
277            )),
278            Arc::new(ConstantVector::new(
279                Arc::new(BooleanVector::from(vec![true])),
280                4,
281            )),
282        );
283        assert_vector_ref_ne(
284            Arc::new(ConstantVector::new(
285                Arc::new(BooleanVector::from(vec![true])),
286                5,
287            )),
288            Arc::new(ConstantVector::new(
289                Arc::new(BooleanVector::from(vec![false])),
290                4,
291            )),
292        );
293        assert_vector_ref_ne(
294            Arc::new(ConstantVector::new(
295                Arc::new(BooleanVector::from(vec![true])),
296                5,
297            )),
298            Arc::new(ConstantVector::new(
299                Arc::new(Int32Vector::from_slice(vec![1])),
300                4,
301            )),
302        );
303        assert_vector_ref_ne(Arc::new(NullVector::new(5)), Arc::new(NullVector::new(8)));
304
305        assert_vector_ref_ne(
306            Arc::new(TimeMicrosecondVector::from_values([100, 120])),
307            Arc::new(TimeMicrosecondVector::from_values([200, 220])),
308        );
309
310        assert_vector_ref_ne(
311            Arc::new(IntervalDayTimeVector::from_values([
312                IntervalDayTime::new(1, 1000),
313                IntervalDayTime::new(1, 2000),
314            ])),
315            Arc::new(IntervalDayTimeVector::from_values([
316                IntervalDayTime::new(1, 2100),
317                IntervalDayTime::new(1, 1200),
318            ])),
319        );
320        assert_vector_ref_ne(
321            Arc::new(IntervalMonthDayNanoVector::from_values([
322                IntervalMonthDayNano::new(1, 1, 1000),
323                IntervalMonthDayNano::new(1, 1, 2000),
324            ])),
325            Arc::new(IntervalMonthDayNanoVector::from_values([
326                IntervalMonthDayNano::new(1, 1, 2100),
327                IntervalMonthDayNano::new(1, 1, 1200),
328            ])),
329        );
330        assert_vector_ref_ne(
331            Arc::new(IntervalYearMonthVector::from_values([1000, 2000])),
332            Arc::new(IntervalYearMonthVector::from_values([2100, 1200])),
333        );
334
335        assert_vector_ref_ne(
336            Arc::new(DurationSecondVector::from_values([300, 310])),
337            Arc::new(DurationSecondVector::from_values([300, 320])),
338        );
339
340        assert_vector_ref_ne(
341            Arc::new(Decimal128Vector::from_values([300i128, 310i128])),
342            Arc::new(Decimal128Vector::from_values([300i128, 320i128])),
343        );
344    }
345}