1use std::sync::Arc;
16
17use common_time::interval::IntervalUnit;
18
19use crate::data_type::DataType;
20use crate::types::{DurationType, TimeType, TimestampType};
21use crate::vectors::constant::ConstantVector;
22use crate::vectors::struct_vector::StructVector;
23use crate::vectors::{
24 BinaryVector, BooleanVector, DateVector, Decimal128Vector, DurationMicrosecondVector,
25 DurationMillisecondVector, DurationNanosecondVector, DurationSecondVector,
26 IntervalDayTimeVector, IntervalMonthDayNanoVector, IntervalYearMonthVector, ListVector,
27 PrimitiveVector, StringVector, TimeMicrosecondVector, TimeMillisecondVector,
28 TimeNanosecondVector, TimeSecondVector, TimestampMicrosecondVector, TimestampMillisecondVector,
29 TimestampNanosecondVector, TimestampSecondVector, Vector,
30};
31use crate::with_match_primitive_type_id;
32
33impl Eq for dyn Vector + '_ {}
34
35impl PartialEq for dyn Vector + '_ {
36 fn eq(&self, other: &dyn Vector) -> bool {
37 equal(self, other)
38 }
39}
40
41impl PartialEq<dyn Vector> for Arc<dyn Vector + '_> {
42 fn eq(&self, other: &dyn Vector) -> bool {
43 equal(&**self, other)
44 }
45}
46
/// Downcasts both `$lhs` and `$rhs` (via their `as_any()`) to `$VectorType`
/// and evaluates to whether the two downcast values are equal.
///
/// Both downcasts `unwrap`, so the expansion panics if either operand is not
/// actually a `$VectorType`; the caller must choose the type that matches the
/// operands' data type.
macro_rules! is_vector_eq {
    ($VectorType: ident, $lhs: ident, $rhs: ident) => {{
        let downcast_left = $lhs.as_any().downcast_ref::<$VectorType>().unwrap();
        let downcast_right = $rhs.as_any().downcast_ref::<$VectorType>().unwrap();
        downcast_left == downcast_right
    }};
}
55
56fn equal(lhs: &dyn Vector, rhs: &dyn Vector) -> bool {
57 if lhs.data_type() != rhs.data_type() || lhs.len() != rhs.len() {
58 return false;
59 }
60
61 if lhs.is_const() || rhs.is_const() {
62 return equal(
65 &**lhs
66 .as_any()
67 .downcast_ref::<ConstantVector>()
68 .unwrap()
69 .inner(),
70 &**lhs
71 .as_any()
72 .downcast_ref::<ConstantVector>()
73 .unwrap()
74 .inner(),
75 );
76 }
77
78 use crate::data_type::ConcreteDataType::*;
79
80 let lhs_type = lhs.data_type();
81 match lhs.data_type() {
82 Null(_) => true,
83 Boolean(_) => is_vector_eq!(BooleanVector, lhs, rhs),
84 Binary(_) | Json(_) | Vector(_) => is_vector_eq!(BinaryVector, lhs, rhs),
85 String(_) => is_vector_eq!(StringVector, lhs, rhs),
86 Date(_) => is_vector_eq!(DateVector, lhs, rhs),
87 Timestamp(t) => match t {
88 TimestampType::Second(_) => {
89 is_vector_eq!(TimestampSecondVector, lhs, rhs)
90 }
91 TimestampType::Millisecond(_) => {
92 is_vector_eq!(TimestampMillisecondVector, lhs, rhs)
93 }
94 TimestampType::Microsecond(_) => {
95 is_vector_eq!(TimestampMicrosecondVector, lhs, rhs)
96 }
97 TimestampType::Nanosecond(_) => {
98 is_vector_eq!(TimestampNanosecondVector, lhs, rhs)
99 }
100 },
101 Interval(v) => match v.unit() {
102 IntervalUnit::YearMonth => {
103 is_vector_eq!(IntervalYearMonthVector, lhs, rhs)
104 }
105 IntervalUnit::DayTime => {
106 is_vector_eq!(IntervalDayTimeVector, lhs, rhs)
107 }
108 IntervalUnit::MonthDayNano => {
109 is_vector_eq!(IntervalMonthDayNanoVector, lhs, rhs)
110 }
111 },
112 List(_) => is_vector_eq!(ListVector, lhs, rhs),
113 Struct(_) => is_vector_eq!(StructVector, lhs, rhs),
114 UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) | Int32(_) | Int64(_)
115 | Float32(_) | Float64(_) | Dictionary(_) => {
116 with_match_primitive_type_id!(lhs_type.logical_type_id(), |$T| {
117 let lhs = lhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
118 let rhs = rhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
119
120 lhs == rhs
121 },
122 {
123 unreachable!("should not compare {} with {}", lhs.vector_type_name(), rhs.vector_type_name())
124 })
125 }
126
127 Time(t) => match t {
128 TimeType::Second(_) => {
129 is_vector_eq!(TimeSecondVector, lhs, rhs)
130 }
131 TimeType::Millisecond(_) => {
132 is_vector_eq!(TimeMillisecondVector, lhs, rhs)
133 }
134 TimeType::Microsecond(_) => {
135 is_vector_eq!(TimeMicrosecondVector, lhs, rhs)
136 }
137 TimeType::Nanosecond(_) => {
138 is_vector_eq!(TimeNanosecondVector, lhs, rhs)
139 }
140 },
141 Duration(d) => match d {
142 DurationType::Second(_) => {
143 is_vector_eq!(DurationSecondVector, lhs, rhs)
144 }
145 DurationType::Millisecond(_) => {
146 is_vector_eq!(DurationMillisecondVector, lhs, rhs)
147 }
148 DurationType::Microsecond(_) => {
149 is_vector_eq!(DurationMicrosecondVector, lhs, rhs)
150 }
151 DurationType::Nanosecond(_) => {
152 is_vector_eq!(DurationNanosecondVector, lhs, rhs)
153 }
154 },
155 Decimal128(_) => {
156 is_vector_eq!(Decimal128Vector, lhs, rhs)
157 }
158 }
159}
160
161#[cfg(test)]
162mod tests {
163 use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};
164
165 use super::*;
166 use crate::vectors::{
167 DurationMicrosecondVector, DurationMillisecondVector, DurationNanosecondVector,
168 DurationSecondVector, Float32Vector, Float64Vector, Int8Vector, Int16Vector, Int32Vector,
169 Int64Vector, NullVector, UInt8Vector, UInt16Vector, UInt32Vector, UInt64Vector, VectorRef,
170 list,
171 };
172
173 fn assert_vector_ref_eq(vector: VectorRef) {
174 let rhs = vector.clone();
175 assert_eq!(vector, rhs);
176 assert_dyn_vector_eq(&*vector, &*rhs);
177 }
178
179 fn assert_dyn_vector_eq(lhs: &dyn Vector, rhs: &dyn Vector) {
180 assert_eq!(lhs, rhs);
181 }
182
183 fn assert_vector_ref_ne(lhs: VectorRef, rhs: VectorRef) {
184 assert_ne!(lhs, rhs);
185 }
186
187 #[test]
188 fn test_vector_eq() {
189 assert_vector_ref_eq(Arc::new(BinaryVector::from(vec![
190 Some(b"hello".to_vec()),
191 Some(b"world".to_vec()),
192 ])));
193 assert_vector_ref_eq(Arc::new(BooleanVector::from(vec![true, false])));
194 assert_vector_ref_eq(Arc::new(ConstantVector::new(
195 Arc::new(BooleanVector::from(vec![true])),
196 5,
197 )));
198 assert_vector_ref_eq(Arc::new(BooleanVector::from(vec![true, false])));
199 assert_vector_ref_eq(Arc::new(DateVector::from(vec![Some(100), Some(120)])));
200 assert_vector_ref_eq(Arc::new(TimestampSecondVector::from_values([100, 120])));
201 assert_vector_ref_eq(Arc::new(TimestampMillisecondVector::from_values([
202 100, 120,
203 ])));
204 assert_vector_ref_eq(Arc::new(TimestampMicrosecondVector::from_values([
205 100, 120,
206 ])));
207 assert_vector_ref_eq(Arc::new(TimestampNanosecondVector::from_values([100, 120])));
208
209 let list_vector = list::tests::new_list_vector(&[
210 Some(vec![Some(1), Some(2)]),
211 None,
212 Some(vec![Some(3), Some(4)]),
213 ]);
214 assert_vector_ref_eq(Arc::new(list_vector));
215
216 assert_vector_ref_eq(Arc::new(NullVector::new(4)));
217 assert_vector_ref_eq(Arc::new(StringVector::from(vec![
218 Some("hello"),
219 Some("world"),
220 ])));
221
222 assert_vector_ref_eq(Arc::new(Int8Vector::from_slice([1, 2, 3, 4])));
223 assert_vector_ref_eq(Arc::new(UInt8Vector::from_slice([1, 2, 3, 4])));
224 assert_vector_ref_eq(Arc::new(Int16Vector::from_slice([1, 2, 3, 4])));
225 assert_vector_ref_eq(Arc::new(UInt16Vector::from_slice([1, 2, 3, 4])));
226 assert_vector_ref_eq(Arc::new(Int32Vector::from_slice([1, 2, 3, 4])));
227 assert_vector_ref_eq(Arc::new(UInt32Vector::from_slice([1, 2, 3, 4])));
228 assert_vector_ref_eq(Arc::new(Int64Vector::from_slice([1, 2, 3, 4])));
229 assert_vector_ref_eq(Arc::new(UInt64Vector::from_slice([1, 2, 3, 4])));
230 assert_vector_ref_eq(Arc::new(Float32Vector::from_slice([1.0, 2.0, 3.0, 4.0])));
231 assert_vector_ref_eq(Arc::new(Float64Vector::from_slice([1.0, 2.0, 3.0, 4.0])));
232
233 assert_vector_ref_eq(Arc::new(TimeSecondVector::from_values([100, 120])));
234 assert_vector_ref_eq(Arc::new(TimeMillisecondVector::from_values([100, 120])));
235 assert_vector_ref_eq(Arc::new(TimeMicrosecondVector::from_values([100, 120])));
236 assert_vector_ref_eq(Arc::new(TimeNanosecondVector::from_values([100, 120])));
237
238 assert_vector_ref_eq(Arc::new(IntervalYearMonthVector::from_values([
239 1000, 2000, 3000, 4000,
240 ])));
241 assert_vector_ref_eq(Arc::new(IntervalDayTimeVector::from_values([
242 IntervalDayTime::new(1, 1000),
243 IntervalDayTime::new(1, 2000),
244 IntervalDayTime::new(1, 3000),
245 IntervalDayTime::new(1, 4000),
246 ])));
247 assert_vector_ref_eq(Arc::new(IntervalMonthDayNanoVector::from_values([
248 IntervalMonthDayNano::new(1, 1, 1000),
249 IntervalMonthDayNano::new(1, 1, 2000),
250 IntervalMonthDayNano::new(1, 1, 3000),
251 IntervalMonthDayNano::new(1, 1, 4000),
252 ])));
253 assert_vector_ref_eq(Arc::new(DurationSecondVector::from_values([300, 310])));
254 assert_vector_ref_eq(Arc::new(DurationMillisecondVector::from_values([300, 310])));
255 assert_vector_ref_eq(Arc::new(DurationMicrosecondVector::from_values([300, 310])));
256 assert_vector_ref_eq(Arc::new(DurationNanosecondVector::from_values([300, 310])));
257 assert_vector_ref_eq(Arc::new(Decimal128Vector::from_values(vec![
258 1i128, 2i128, 3i128,
259 ])));
260 }
261
262 #[test]
263 fn test_vector_ne() {
264 assert_vector_ref_ne(
265 Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
266 Arc::new(Int32Vector::from_slice([1, 2])),
267 );
268 assert_vector_ref_ne(
269 Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
270 Arc::new(Int8Vector::from_slice([1, 2, 3, 4])),
271 );
272 assert_vector_ref_ne(
273 Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
274 Arc::new(BooleanVector::from(vec![true, true])),
275 );
276 assert_vector_ref_ne(
277 Arc::new(ConstantVector::new(
278 Arc::new(BooleanVector::from(vec![true])),
279 5,
280 )),
281 Arc::new(ConstantVector::new(
282 Arc::new(BooleanVector::from(vec![true])),
283 4,
284 )),
285 );
286 assert_vector_ref_ne(
287 Arc::new(ConstantVector::new(
288 Arc::new(BooleanVector::from(vec![true])),
289 5,
290 )),
291 Arc::new(ConstantVector::new(
292 Arc::new(BooleanVector::from(vec![false])),
293 4,
294 )),
295 );
296 assert_vector_ref_ne(
297 Arc::new(ConstantVector::new(
298 Arc::new(BooleanVector::from(vec![true])),
299 5,
300 )),
301 Arc::new(ConstantVector::new(
302 Arc::new(Int32Vector::from_slice(vec![1])),
303 4,
304 )),
305 );
306 assert_vector_ref_ne(Arc::new(NullVector::new(5)), Arc::new(NullVector::new(8)));
307
308 assert_vector_ref_ne(
309 Arc::new(TimeMicrosecondVector::from_values([100, 120])),
310 Arc::new(TimeMicrosecondVector::from_values([200, 220])),
311 );
312
313 assert_vector_ref_ne(
314 Arc::new(IntervalDayTimeVector::from_values([
315 IntervalDayTime::new(1, 1000),
316 IntervalDayTime::new(1, 2000),
317 ])),
318 Arc::new(IntervalDayTimeVector::from_values([
319 IntervalDayTime::new(1, 2100),
320 IntervalDayTime::new(1, 1200),
321 ])),
322 );
323 assert_vector_ref_ne(
324 Arc::new(IntervalMonthDayNanoVector::from_values([
325 IntervalMonthDayNano::new(1, 1, 1000),
326 IntervalMonthDayNano::new(1, 1, 2000),
327 ])),
328 Arc::new(IntervalMonthDayNanoVector::from_values([
329 IntervalMonthDayNano::new(1, 1, 2100),
330 IntervalMonthDayNano::new(1, 1, 1200),
331 ])),
332 );
333 assert_vector_ref_ne(
334 Arc::new(IntervalYearMonthVector::from_values([1000, 2000])),
335 Arc::new(IntervalYearMonthVector::from_values([2100, 1200])),
336 );
337
338 assert_vector_ref_ne(
339 Arc::new(DurationSecondVector::from_values([300, 310])),
340 Arc::new(DurationSecondVector::from_values([300, 320])),
341 );
342
343 assert_vector_ref_ne(
344 Arc::new(Decimal128Vector::from_values([300i128, 310i128])),
345 Arc::new(Decimal128Vector::from_values([300i128, 320i128])),
346 );
347 }
348}