1use std::sync::Arc;
16
17use common_time::interval::IntervalUnit;
18
19use crate::data_type::DataType;
20use crate::types::{DurationType, TimeType, TimestampType};
21use crate::vectors::constant::ConstantVector;
22use crate::vectors::{
23 BinaryVector, BooleanVector, DateVector, Decimal128Vector, DurationMicrosecondVector,
24 DurationMillisecondVector, DurationNanosecondVector, DurationSecondVector,
25 IntervalDayTimeVector, IntervalMonthDayNanoVector, IntervalYearMonthVector, ListVector,
26 PrimitiveVector, StringVector, TimeMicrosecondVector, TimeMillisecondVector,
27 TimeNanosecondVector, TimeSecondVector, TimestampMicrosecondVector, TimestampMillisecondVector,
28 TimestampNanosecondVector, TimestampSecondVector, Vector,
29};
30use crate::with_match_primitive_type_id;
31
32impl Eq for dyn Vector + '_ {}
33
34impl PartialEq for dyn Vector + '_ {
35 fn eq(&self, other: &dyn Vector) -> bool {
36 equal(self, other)
37 }
38}
39
40impl PartialEq<dyn Vector> for Arc<dyn Vector + '_> {
41 fn eq(&self, other: &dyn Vector) -> bool {
42 equal(&**self, other)
43 }
44}
45
/// Downcasts both operands to the concrete vector type `$VectorType` and
/// compares them with that type's `PartialEq`.
///
/// Panics (via `unwrap`) if either operand is not a `$VectorType`; callers are
/// expected to pick `$VectorType` from the operands' shared data type.
macro_rules! is_vector_eq {
    ($VectorType: ident, $lhs: ident, $rhs: ident) => {{
        // Left operand is downcast first, then right, so a failing downcast
        // panics in the same order as before.
        let (left, right) = (
            $lhs.as_any().downcast_ref::<$VectorType>().unwrap(),
            $rhs.as_any().downcast_ref::<$VectorType>().unwrap(),
        );
        left == right
    }};
}
54
55fn equal(lhs: &dyn Vector, rhs: &dyn Vector) -> bool {
56 if lhs.data_type() != rhs.data_type() || lhs.len() != rhs.len() {
57 return false;
58 }
59
60 if lhs.is_const() || rhs.is_const() {
61 return equal(
64 &**lhs
65 .as_any()
66 .downcast_ref::<ConstantVector>()
67 .unwrap()
68 .inner(),
69 &**lhs
70 .as_any()
71 .downcast_ref::<ConstantVector>()
72 .unwrap()
73 .inner(),
74 );
75 }
76
77 use crate::data_type::ConcreteDataType::*;
78
79 let lhs_type = lhs.data_type();
80 match lhs.data_type() {
81 Null(_) => true,
82 Boolean(_) => is_vector_eq!(BooleanVector, lhs, rhs),
83 Binary(_) | Json(_) | Vector(_) => is_vector_eq!(BinaryVector, lhs, rhs),
84 String(_) => is_vector_eq!(StringVector, lhs, rhs),
85 Date(_) => is_vector_eq!(DateVector, lhs, rhs),
86 Timestamp(t) => match t {
87 TimestampType::Second(_) => {
88 is_vector_eq!(TimestampSecondVector, lhs, rhs)
89 }
90 TimestampType::Millisecond(_) => {
91 is_vector_eq!(TimestampMillisecondVector, lhs, rhs)
92 }
93 TimestampType::Microsecond(_) => {
94 is_vector_eq!(TimestampMicrosecondVector, lhs, rhs)
95 }
96 TimestampType::Nanosecond(_) => {
97 is_vector_eq!(TimestampNanosecondVector, lhs, rhs)
98 }
99 },
100 Interval(v) => match v.unit() {
101 IntervalUnit::YearMonth => {
102 is_vector_eq!(IntervalYearMonthVector, lhs, rhs)
103 }
104 IntervalUnit::DayTime => {
105 is_vector_eq!(IntervalDayTimeVector, lhs, rhs)
106 }
107 IntervalUnit::MonthDayNano => {
108 is_vector_eq!(IntervalMonthDayNanoVector, lhs, rhs)
109 }
110 },
111 List(_) => is_vector_eq!(ListVector, lhs, rhs),
112 UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) | Int32(_) | Int64(_)
113 | Float32(_) | Float64(_) | Dictionary(_) => {
114 with_match_primitive_type_id!(lhs_type.logical_type_id(), |$T| {
115 let lhs = lhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
116 let rhs = rhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
117
118 lhs == rhs
119 },
120 {
121 unreachable!("should not compare {} with {}", lhs.vector_type_name(), rhs.vector_type_name())
122 })
123 }
124
125 Time(t) => match t {
126 TimeType::Second(_) => {
127 is_vector_eq!(TimeSecondVector, lhs, rhs)
128 }
129 TimeType::Millisecond(_) => {
130 is_vector_eq!(TimeMillisecondVector, lhs, rhs)
131 }
132 TimeType::Microsecond(_) => {
133 is_vector_eq!(TimeMicrosecondVector, lhs, rhs)
134 }
135 TimeType::Nanosecond(_) => {
136 is_vector_eq!(TimeNanosecondVector, lhs, rhs)
137 }
138 },
139 Duration(d) => match d {
140 DurationType::Second(_) => {
141 is_vector_eq!(DurationSecondVector, lhs, rhs)
142 }
143 DurationType::Millisecond(_) => {
144 is_vector_eq!(DurationMillisecondVector, lhs, rhs)
145 }
146 DurationType::Microsecond(_) => {
147 is_vector_eq!(DurationMicrosecondVector, lhs, rhs)
148 }
149 DurationType::Nanosecond(_) => {
150 is_vector_eq!(DurationNanosecondVector, lhs, rhs)
151 }
152 },
153 Decimal128(_) => {
154 is_vector_eq!(Decimal128Vector, lhs, rhs)
155 }
156 }
157}
158
#[cfg(test)]
mod tests {
    use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano};

    use super::*;
    use crate::vectors::{
        list, DurationMicrosecondVector, DurationMillisecondVector, DurationNanosecondVector,
        DurationSecondVector, Float32Vector, Float64Vector, Int16Vector, Int32Vector, Int64Vector,
        Int8Vector, NullVector, UInt16Vector, UInt32Vector, UInt64Vector, UInt8Vector, VectorRef,
    };

    /// Asserts that `vector` equals a clone of itself.
    ///
    /// Since `dyn Vector: Eq`, comparing two `Arc`s to the same allocation may
    /// short-circuit on pointer identity, so the dereferenced trait objects are
    /// also compared to make sure the structural `equal` path runs.
    fn assert_vector_ref_eq(vector: VectorRef) {
        let rhs = vector.clone();
        assert_eq!(vector, rhs);
        assert_dyn_vector_eq(&*vector, &*rhs);
    }

    /// Asserts equality through `PartialEq for dyn Vector`.
    fn assert_dyn_vector_eq(lhs: &dyn Vector, rhs: &dyn Vector) {
        assert_eq!(lhs, rhs);
    }

    /// Asserts that two distinct vector allocations compare unequal.
    fn assert_vector_ref_ne(lhs: VectorRef, rhs: VectorRef) {
        assert_ne!(lhs, rhs);
    }

    #[test]
    fn test_vector_eq() {
        assert_vector_ref_eq(Arc::new(BinaryVector::from(vec![
            Some(b"hello".to_vec()),
            Some(b"world".to_vec()),
        ])));
        assert_vector_ref_eq(Arc::new(BooleanVector::from(vec![true, false])));
        assert_vector_ref_eq(Arc::new(ConstantVector::new(
            Arc::new(BooleanVector::from(vec![true])),
            5,
        )));
        assert_vector_ref_eq(Arc::new(DateVector::from(vec![Some(100), Some(120)])));
        assert_vector_ref_eq(Arc::new(TimestampSecondVector::from_values([100, 120])));
        assert_vector_ref_eq(Arc::new(TimestampMillisecondVector::from_values([
            100, 120,
        ])));
        assert_vector_ref_eq(Arc::new(TimestampMicrosecondVector::from_values([
            100, 120,
        ])));
        assert_vector_ref_eq(Arc::new(TimestampNanosecondVector::from_values([100, 120])));

        let list_vector = list::tests::new_list_vector(&[
            Some(vec![Some(1), Some(2)]),
            None,
            Some(vec![Some(3), Some(4)]),
        ]);
        assert_vector_ref_eq(Arc::new(list_vector));

        assert_vector_ref_eq(Arc::new(NullVector::new(4)));
        assert_vector_ref_eq(Arc::new(StringVector::from(vec![
            Some("hello"),
            Some("world"),
        ])));

        assert_vector_ref_eq(Arc::new(Int8Vector::from_slice([1, 2, 3, 4])));
        assert_vector_ref_eq(Arc::new(UInt8Vector::from_slice([1, 2, 3, 4])));
        assert_vector_ref_eq(Arc::new(Int16Vector::from_slice([1, 2, 3, 4])));
        assert_vector_ref_eq(Arc::new(UInt16Vector::from_slice([1, 2, 3, 4])));
        assert_vector_ref_eq(Arc::new(Int32Vector::from_slice([1, 2, 3, 4])));
        assert_vector_ref_eq(Arc::new(UInt32Vector::from_slice([1, 2, 3, 4])));
        assert_vector_ref_eq(Arc::new(Int64Vector::from_slice([1, 2, 3, 4])));
        assert_vector_ref_eq(Arc::new(UInt64Vector::from_slice([1, 2, 3, 4])));
        assert_vector_ref_eq(Arc::new(Float32Vector::from_slice([1.0, 2.0, 3.0, 4.0])));
        assert_vector_ref_eq(Arc::new(Float64Vector::from_slice([1.0, 2.0, 3.0, 4.0])));

        assert_vector_ref_eq(Arc::new(TimeSecondVector::from_values([100, 120])));
        assert_vector_ref_eq(Arc::new(TimeMillisecondVector::from_values([100, 120])));
        assert_vector_ref_eq(Arc::new(TimeMicrosecondVector::from_values([100, 120])));
        assert_vector_ref_eq(Arc::new(TimeNanosecondVector::from_values([100, 120])));

        assert_vector_ref_eq(Arc::new(IntervalYearMonthVector::from_values([
            1000, 2000, 3000, 4000,
        ])));
        assert_vector_ref_eq(Arc::new(IntervalDayTimeVector::from_values([
            IntervalDayTime::new(1, 1000),
            IntervalDayTime::new(1, 2000),
            IntervalDayTime::new(1, 3000),
            IntervalDayTime::new(1, 4000),
        ])));
        assert_vector_ref_eq(Arc::new(IntervalMonthDayNanoVector::from_values([
            IntervalMonthDayNano::new(1, 1, 1000),
            IntervalMonthDayNano::new(1, 1, 2000),
            IntervalMonthDayNano::new(1, 1, 3000),
            IntervalMonthDayNano::new(1, 1, 4000),
        ])));
        assert_vector_ref_eq(Arc::new(DurationSecondVector::from_values([300, 310])));
        assert_vector_ref_eq(Arc::new(DurationMillisecondVector::from_values([300, 310])));
        assert_vector_ref_eq(Arc::new(DurationMicrosecondVector::from_values([300, 310])));
        assert_vector_ref_eq(Arc::new(DurationNanosecondVector::from_values([300, 310])));
        assert_vector_ref_eq(Arc::new(Decimal128Vector::from_values(vec![
            1i128, 2i128, 3i128,
        ])));
    }

    #[test]
    fn test_vector_ne() {
        // Different lengths.
        assert_vector_ref_ne(
            Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
            Arc::new(Int32Vector::from_slice([1, 2])),
        );
        // Different data types.
        assert_vector_ref_ne(
            Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
            Arc::new(Int8Vector::from_slice([1, 2, 3, 4])),
        );
        assert_vector_ref_ne(
            Arc::new(Int32Vector::from_slice([1, 2, 3, 4])),
            Arc::new(BooleanVector::from(vec![true, true])),
        );
        assert_vector_ref_ne(
            Arc::new(ConstantVector::new(
                Arc::new(BooleanVector::from(vec![true])),
                5,
            )),
            Arc::new(ConstantVector::new(
                Arc::new(BooleanVector::from(vec![true])),
                4,
            )),
        );
        assert_vector_ref_ne(
            Arc::new(ConstantVector::new(
                Arc::new(BooleanVector::from(vec![true])),
                5,
            )),
            Arc::new(ConstantVector::new(
                Arc::new(BooleanVector::from(vec![false])),
                4,
            )),
        );
        assert_vector_ref_ne(
            Arc::new(ConstantVector::new(
                Arc::new(BooleanVector::from(vec![true])),
                5,
            )),
            Arc::new(ConstantVector::new(
                Arc::new(Int32Vector::from_slice(vec![1])),
                4,
            )),
        );
        assert_vector_ref_ne(Arc::new(NullVector::new(5)), Arc::new(NullVector::new(8)));

        // Same type and length, different values.
        assert_vector_ref_ne(
            Arc::new(TimeMicrosecondVector::from_values([100, 120])),
            Arc::new(TimeMicrosecondVector::from_values([200, 220])),
        );

        assert_vector_ref_ne(
            Arc::new(IntervalDayTimeVector::from_values([
                IntervalDayTime::new(1, 1000),
                IntervalDayTime::new(1, 2000),
            ])),
            Arc::new(IntervalDayTimeVector::from_values([
                IntervalDayTime::new(1, 2100),
                IntervalDayTime::new(1, 1200),
            ])),
        );
        assert_vector_ref_ne(
            Arc::new(IntervalMonthDayNanoVector::from_values([
                IntervalMonthDayNano::new(1, 1, 1000),
                IntervalMonthDayNano::new(1, 1, 2000),
            ])),
            Arc::new(IntervalMonthDayNanoVector::from_values([
                IntervalMonthDayNano::new(1, 1, 2100),
                IntervalMonthDayNano::new(1, 1, 1200),
            ])),
        );
        assert_vector_ref_ne(
            Arc::new(IntervalYearMonthVector::from_values([1000, 2000])),
            Arc::new(IntervalYearMonthVector::from_values([2100, 1200])),
        );

        assert_vector_ref_ne(
            Arc::new(DurationSecondVector::from_values([300, 310])),
            Arc::new(DurationSecondVector::from_values([300, 320])),
        );

        assert_vector_ref_ne(
            Arc::new(Decimal128Vector::from_values([300i128, 310i128])),
            Arc::new(Decimal128Vector::from_values([300i128, 320i128])),
        );
    }
}