1use std::sync::Arc;
16
17use common_time::interval::IntervalUnit;
18
19use crate::data_type::DataType;
20use crate::types::{DurationType, TimeType, TimestampType};
21use crate::vectors::constant::ConstantVector;
22use crate::vectors::struct_vector::StructVector;
23use crate::vectors::{
24 BinaryVector, BooleanVector, DateVector, Decimal128Vector, DurationMicrosecondVector,
25 DurationMillisecondVector, DurationNanosecondVector, DurationSecondVector,
26 IntervalDayTimeVector, IntervalMonthDayNanoVector, IntervalYearMonthVector, ListVector,
27 PrimitiveVector, StringVector, TimeMicrosecondVector, TimeMillisecondVector,
28 TimeNanosecondVector, TimeSecondVector, TimestampMicrosecondVector, TimestampMillisecondVector,
29 TimestampNanosecondVector, TimestampSecondVector, Vector,
30};
31use crate::with_match_primitive_type_id;
32
33impl Eq for dyn Vector + '_ {}
34
35impl PartialEq for dyn Vector + '_ {
36 fn eq(&self, other: &dyn Vector) -> bool {
37 equal(self, other)
38 }
39}
40
41impl PartialEq<dyn Vector> for Arc<dyn Vector + '_> {
42 fn eq(&self, other: &dyn Vector) -> bool {
43 equal(&**self, other)
44 }
45}
46
/// Downcasts both `$lhs` and `$rhs` to the concrete `$VectorType` and compares
/// them using that type's `PartialEq` implementation.
///
/// Panics (via `unwrap`) if either side is not actually a `$VectorType`.
macro_rules! is_vector_eq {
    ($VectorType: ident, $lhs: ident, $rhs: ident) => {{
        let (left, right) = (
            $lhs.as_any().downcast_ref::<$VectorType>().unwrap(),
            $rhs.as_any().downcast_ref::<$VectorType>().unwrap(),
        );
        left == right
    }};
}
55
56fn equal(lhs: &dyn Vector, rhs: &dyn Vector) -> bool {
57 if lhs.data_type() != rhs.data_type() || lhs.len() != rhs.len() {
58 return false;
59 }
60
61 if lhs.is_const() || rhs.is_const() {
62 return equal(
65 &**lhs
66 .as_any()
67 .downcast_ref::<ConstantVector>()
68 .unwrap()
69 .inner(),
70 &**lhs
71 .as_any()
72 .downcast_ref::<ConstantVector>()
73 .unwrap()
74 .inner(),
75 );
76 }
77
78 use crate::data_type::ConcreteDataType::*;
79
80 let lhs_type = lhs.data_type();
81 match lhs.data_type() {
82 Null(_) => true,
83 Boolean(_) => is_vector_eq!(BooleanVector, lhs, rhs),
84 Binary(_) | Json(_) | Vector(_) => is_vector_eq!(BinaryVector, lhs, rhs),
85 String(_) => is_vector_eq!(StringVector, lhs, rhs),
86 Date(_) => is_vector_eq!(DateVector, lhs, rhs),
87 Timestamp(t) => match t {
88 TimestampType::Second(_) => {
89 is_vector_eq!(TimestampSecondVector, lhs, rhs)
90 }
91 TimestampType::Millisecond(_) => {
92 is_vector_eq!(TimestampMillisecondVector, lhs, rhs)
93 }
94 TimestampType::Microsecond(_) => {
95 is_vector_eq!(TimestampMicrosecondVector, lhs, rhs)
96 }
97 TimestampType::Nanosecond(_) => {
98 is_vector_eq!(TimestampNanosecondVector, lhs, rhs)
99 }
100 },
101 Interval(v) => match v.unit() {
102 IntervalUnit::YearMonth => {
103 is_vector_eq!(IntervalYearMonthVector, lhs, rhs)
104 }
105 IntervalUnit::DayTime => {
106 is_vector_eq!(IntervalDayTimeVector, lhs, rhs)
107 }
108 IntervalUnit::MonthDayNano => {
109 is_vector_eq!(IntervalMonthDayNanoVector, lhs, rhs)
110 }
111 },
112 List(_) => is_vector_eq!(ListVector, lhs, rhs),
113 Struct(_) => is_vector_eq!(StructVector, lhs, rhs),
114 UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) | Int32(_) | Int64(_)
115 | Float32(_) | Float64(_) | Dictionary(_) => {
116 with_match_primitive_type_id!(lhs_type.logical_type_id(), |$T| {
117 let lhs = lhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
118 let rhs = rhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
119
120 lhs == rhs
121 },
122 {
123 unreachable!("should not compare {} with {}", lhs.vector_type_name(), rhs.vector_type_name())
124 })
125 }
126
127 Time(t) => match t {
128 TimeType::Second(_) => {
129 is_vector_eq!(TimeSecondVector, lhs, rhs)
130 }
131 TimeType::Millisecond(_) => {
132 is_vector_eq!(TimeMillisecondVector, lhs, rhs)
133 }
134 TimeType::Microsecond(_) => {
135 is_vector_eq!(TimeMicrosecondVector, lhs, rhs)
136 }
137 TimeType::Nanosecond(_) => {
138 is_vector_eq!(TimeNanosecondVector, lhs, rhs)
139 }
140 },
141 Duration(d) => match d {
142 DurationType::Second(_) => {
143 is_vector_eq!(DurationSecondVector, lhs, rhs)
144 }
145 DurationType::Millisecond(_) => {
146 is_vector_eq!(DurationMillisecondVector, lhs, rhs)
147 }
148 DurationType::Microsecond(_) => {
149 is_vector_eq!(DurationMicrosecondVector, lhs, rhs)
150 }
151 DurationType::Nanosecond(_) => {
152 is_vector_eq!(DurationNanosecondVector, lhs, rhs)
153 }
154 },
155 Decimal128(_) => {
156 is_vector_eq!(Decimal128Vector, lhs, rhs)
157 }
158 }
159}
160
#[cfg(test)]
mod tests {
    use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};

    use super::*;
    use crate::vectors::{
        list, DurationMicrosecondVector, DurationMillisecondVector, DurationNanosecondVector,
        DurationSecondVector, Float32Vector, Float64Vector, Int16Vector, Int32Vector, Int64Vector,
        Int8Vector, NullVector, UInt16Vector, UInt32Vector, UInt64Vector, UInt8Vector, VectorRef,
    };

    /// Checks that `vector` equals a second handle to itself, through both the
    /// `Arc<dyn Vector>` comparison and the borrowed `dyn Vector` comparison.
    fn assert_vector_ref_eq(vector: VectorRef) {
        let same = Arc::clone(&vector);
        assert_eq!(vector, same);
        assert_dyn_vector_eq(vector.as_ref(), same.as_ref());
    }

    /// Checks equality of two borrowed vector trait objects.
    fn assert_dyn_vector_eq(lhs: &dyn Vector, rhs: &dyn Vector) {
        assert_eq!(lhs, rhs);
    }

    /// Checks that two vectors compare unequal.
    fn assert_vector_ref_ne(lhs: VectorRef, rhs: VectorRef) {
        assert_ne!(lhs, rhs);
    }

    #[test]
    fn test_vector_eq() {
        // Every vector listed here must compare equal to itself.
        let vectors: Vec<VectorRef> = vec![
            Arc::new(BinaryVector::from(vec![
                Some(b"hello".to_vec()),
                Some(b"world".to_vec()),
            ])),
            Arc::new(BooleanVector::from(vec![true, false])),
            Arc::new(ConstantVector::new(
                Arc::new(BooleanVector::from(vec![true])),
                5,
            )),
            Arc::new(BooleanVector::from(vec![true, false])),
            Arc::new(DateVector::from(vec![Some(100), Some(120)])),
            Arc::new(TimestampSecondVector::from_values([100, 120])),
            Arc::new(TimestampMillisecondVector::from_values([100, 120])),
            Arc::new(TimestampMicrosecondVector::from_values([100, 120])),
            Arc::new(TimestampNanosecondVector::from_values([100, 120])),
            Arc::new(list::tests::new_list_vector(&[
                Some(vec![Some(1), Some(2)]),
                None,
                Some(vec![Some(3), Some(4)]),
            ])),
            Arc::new(NullVector::new(4)),
            Arc::new(StringVector::from(vec![Some("hello"), Some("world")])),
            Arc::new(Int8Vector::from_slice([1, 2, 3, 4])),
            Arc::new(UInt8Vector::from_slice([1, 2, 3, 4])),
            Arc::new(Int16Vector::from_slice([1, 2, 3, 4])),
            Arc::new(UInt16Vector::from_slice([1, 2, 3, 4])),
            Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
            Arc::new(UInt32Vector::from_slice([1, 2, 3, 4])),
            Arc::new(Int64Vector::from_slice([1, 2, 3, 4])),
            Arc::new(UInt64Vector::from_slice([1, 2, 3, 4])),
            Arc::new(Float32Vector::from_slice([1.0, 2.0, 3.0, 4.0])),
            Arc::new(Float64Vector::from_slice([1.0, 2.0, 3.0, 4.0])),
            Arc::new(TimeSecondVector::from_values([100, 120])),
            Arc::new(TimeMillisecondVector::from_values([100, 120])),
            Arc::new(TimeMicrosecondVector::from_values([100, 120])),
            Arc::new(TimeNanosecondVector::from_values([100, 120])),
            Arc::new(IntervalYearMonthVector::from_values([
                1000, 2000, 3000, 4000,
            ])),
            Arc::new(IntervalDayTimeVector::from_values([
                IntervalDayTime::new(1, 1000),
                IntervalDayTime::new(1, 2000),
                IntervalDayTime::new(1, 3000),
                IntervalDayTime::new(1, 4000),
            ])),
            Arc::new(IntervalMonthDayNanoVector::from_values([
                IntervalMonthDayNano::new(1, 1, 1000),
                IntervalMonthDayNano::new(1, 1, 2000),
                IntervalMonthDayNano::new(1, 1, 3000),
                IntervalMonthDayNano::new(1, 1, 4000),
            ])),
            Arc::new(DurationSecondVector::from_values([300, 310])),
            Arc::new(DurationMillisecondVector::from_values([300, 310])),
            Arc::new(DurationMicrosecondVector::from_values([300, 310])),
            Arc::new(DurationNanosecondVector::from_values([300, 310])),
            Arc::new(Decimal128Vector::from_values(vec![1i128, 2i128, 3i128])),
        ];

        for vector in vectors {
            assert_vector_ref_eq(vector);
        }
    }

    #[test]
    fn test_vector_ne() {
        // Each pair differs in length, data type, or contents.
        let pairs: Vec<(VectorRef, VectorRef)> = vec![
            (
                Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
                Arc::new(Int32Vector::from_slice([1, 2])),
            ),
            (
                Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
                Arc::new(Int8Vector::from_slice([1, 2, 3, 4])),
            ),
            (
                Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
                Arc::new(BooleanVector::from(vec![true, true])),
            ),
            (
                Arc::new(ConstantVector::new(
                    Arc::new(BooleanVector::from(vec![true])),
                    5,
                )),
                Arc::new(ConstantVector::new(
                    Arc::new(BooleanVector::from(vec![true])),
                    4,
                )),
            ),
            (
                Arc::new(ConstantVector::new(
                    Arc::new(BooleanVector::from(vec![true])),
                    5,
                )),
                Arc::new(ConstantVector::new(
                    Arc::new(BooleanVector::from(vec![false])),
                    4,
                )),
            ),
            (
                Arc::new(ConstantVector::new(
                    Arc::new(BooleanVector::from(vec![true])),
                    5,
                )),
                Arc::new(ConstantVector::new(
                    Arc::new(Int32Vector::from_slice(vec![1])),
                    4,
                )),
            ),
            (Arc::new(NullVector::new(5)), Arc::new(NullVector::new(8))),
            (
                Arc::new(TimeMicrosecondVector::from_values([100, 120])),
                Arc::new(TimeMicrosecondVector::from_values([200, 220])),
            ),
            (
                Arc::new(IntervalDayTimeVector::from_values([
                    IntervalDayTime::new(1, 1000),
                    IntervalDayTime::new(1, 2000),
                ])),
                Arc::new(IntervalDayTimeVector::from_values([
                    IntervalDayTime::new(1, 2100),
                    IntervalDayTime::new(1, 1200),
                ])),
            ),
            (
                Arc::new(IntervalMonthDayNanoVector::from_values([
                    IntervalMonthDayNano::new(1, 1, 1000),
                    IntervalMonthDayNano::new(1, 1, 2000),
                ])),
                Arc::new(IntervalMonthDayNanoVector::from_values([
                    IntervalMonthDayNano::new(1, 1, 2100),
                    IntervalMonthDayNano::new(1, 1, 1200),
                ])),
            ),
            (
                Arc::new(IntervalYearMonthVector::from_values([1000, 2000])),
                Arc::new(IntervalYearMonthVector::from_values([2100, 1200])),
            ),
            (
                Arc::new(DurationSecondVector::from_values([300, 310])),
                Arc::new(DurationSecondVector::from_values([300, 320])),
            ),
            (
                Arc::new(Decimal128Vector::from_values([300i128, 310i128])),
                Arc::new(Decimal128Vector::from_values([300i128, 320i128])),
            ),
        ];

        for (lhs, rhs) in pairs {
            assert_vector_ref_ne(lhs, rhs);
        }
    }
}