datatypes/types/
primitive_type.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cmp::Ordering;
16use std::fmt;
17
18use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType as ArrowDataType};
19use common_time::Date;
20use serde::{Deserialize, Serialize};
21use snafu::OptionExt;
22
23use crate::data_type::{ConcreteDataType, DataType};
24use crate::error::{self, Result};
25use crate::scalars::{Scalar, ScalarRef, ScalarVectorBuilder};
26use crate::type_id::LogicalTypeId;
27use crate::types::boolean_type::bool_to_numeric;
28use crate::types::DateType;
29use crate::value::{Value, ValueRef};
30use crate::vectors::{MutableVector, PrimitiveVector, PrimitiveVectorBuilder, Vector};
31
32// TODO(yingwen): Can we remove `Into<serde_json::Value>`?
33/// Represents the wrapper type that wraps a native type using the `newtype pattern`,
34/// such as [Date](`common_time::Date`) is a wrapper type for the underlying native
35/// type `i32`.
36pub trait WrapperType:
37    Copy
38    + Send
39    + Sync
40    + fmt::Debug
41    + for<'a> Scalar<RefType<'a> = Self>
42    + PartialEq
43    + Into<Value>
44    + Into<ValueRef<'static>>
45    + Serialize
46    + Into<serde_json::Value>
47{
48    /// Logical primitive type that this wrapper type belongs to.
49    type LogicalType: LogicalPrimitiveType<Wrapper = Self, Native = Self::Native>;
50    /// The underlying native type.
51    type Native: ArrowNativeType;
52
53    /// Convert native type into this wrapper type.
54    fn from_native(value: Self::Native) -> Self;
55
56    /// Convert this wrapper type into native type.
57    fn into_native(self) -> Self::Native;
58}
59
60/// Trait bridging the logical primitive type with [ArrowPrimitiveType].
61pub trait LogicalPrimitiveType: 'static + Sized {
62    /// Arrow primitive type of this logical type.
63    type ArrowPrimitive: ArrowPrimitiveType<Native = Self::Native>;
64    /// Native (physical) type of this logical type.
65    type Native: ArrowNativeType;
66    /// Wrapper type that the vector returns.
67    type Wrapper: WrapperType<LogicalType = Self, Native = Self::Native>
68        + for<'a> Scalar<VectorType = PrimitiveVector<Self>, RefType<'a> = Self::Wrapper>
69        + for<'a> ScalarRef<'a, ScalarType = Self::Wrapper>;
70    /// Largest type this primitive type can cast to.
71    type LargestType: LogicalPrimitiveType;
72
73    /// Construct the data type struct.
74    fn build_data_type() -> ConcreteDataType;
75
76    /// Return the name of the type.
77    fn type_name() -> &'static str;
78
79    /// Dynamic cast the vector to the concrete vector type.
80    fn cast_vector(vector: &dyn Vector) -> Result<&PrimitiveVector<Self>>;
81
82    /// Cast value ref to the primitive type.
83    fn cast_value_ref(value: ValueRef) -> Result<Option<Self::Wrapper>>;
84}
85
86/// A new type for [WrapperType], complement the `Ord` feature for it.
87///
88/// Wrapping non ordered primitive types like `f32` and `f64` in `OrdPrimitive`
89/// can make them be used in places that require `Ord`. For example, in `Median` UDAFs.
90#[derive(Debug, Clone, Copy, PartialEq)]
91pub struct OrdPrimitive<T: WrapperType>(pub T);
92
93impl<T: WrapperType> OrdPrimitive<T> {
94    pub fn as_primitive(&self) -> T::Native {
95        self.0.into_native()
96    }
97}
98
99impl<T: WrapperType> Eq for OrdPrimitive<T> {}
100
101impl<T: WrapperType> PartialOrd for OrdPrimitive<T> {
102    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
103        Some(self.cmp(other))
104    }
105}
106
107impl<T: WrapperType> Ord for OrdPrimitive<T> {
108    fn cmp(&self, other: &Self) -> Ordering {
109        Into::<Value>::into(self.0).cmp(&Into::<Value>::into(other.0))
110    }
111}
112
113impl<T: WrapperType> From<OrdPrimitive<T>> for Value {
114    fn from(p: OrdPrimitive<T>) -> Self {
115        p.0.into()
116    }
117}
118
119macro_rules! impl_wrapper {
120    ($Type: ident, $LogicalType: ident) => {
121        impl WrapperType for $Type {
122            type LogicalType = $LogicalType;
123            type Native = $Type;
124
125            fn from_native(value: Self::Native) -> Self {
126                value
127            }
128
129            fn into_native(self) -> Self::Native {
130                self
131            }
132        }
133    };
134}
135
136impl_wrapper!(u8, UInt8Type);
137impl_wrapper!(u16, UInt16Type);
138impl_wrapper!(u32, UInt32Type);
139impl_wrapper!(u64, UInt64Type);
140impl_wrapper!(i8, Int8Type);
141impl_wrapper!(i16, Int16Type);
142impl_wrapper!(i32, Int32Type);
143impl_wrapper!(i64, Int64Type);
144impl_wrapper!(f32, Float32Type);
145impl_wrapper!(f64, Float64Type);
146
147impl WrapperType for Date {
148    type LogicalType = DateType;
149    type Native = i32;
150
151    fn from_native(value: i32) -> Self {
152        Date::new(value)
153    }
154
155    fn into_native(self) -> i32 {
156        self.val()
157    }
158}
159
/// Defines the empty logical type struct `$DataType` and implements [LogicalPrimitiveType]
/// for it. The struct is backed by `arrow::datatypes::$DataType`, and `$Native` serves as
/// both its native and wrapper type.
macro_rules! define_logical_primitive_type {
    ($Native: ident, $TypeId: ident, $DataType: ident, $Largest: ident) => {
        // We need to define it as an empty struct `struct DataType {}` instead of a struct-unit
        // `struct DataType;` to ensure the serialized JSON string is compatible with previous
        // implementation.
        #[derive(
            Debug, Clone, Default, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize,
        )]
        pub struct $DataType {}

        impl LogicalPrimitiveType for $DataType {
            type Native = $Native;
            type Wrapper = $Native;
            type ArrowPrimitive = arrow::datatypes::$DataType;
            type LargestType = $Largest;

            fn type_name() -> &'static str {
                stringify!($TypeId)
            }

            fn build_data_type() -> ConcreteDataType {
                ConcreteDataType::$TypeId(Self::default())
            }

            fn cast_vector(vector: &dyn Vector) -> Result<&PrimitiveVector<$DataType>> {
                let downcast = vector.as_any().downcast_ref::<PrimitiveVector<Self>>();
                downcast.with_context(|| error::CastTypeSnafu {
                    msg: format!(
                        "Failed to cast {} to vector of primitive type {}",
                        vector.vector_type_name(),
                        <Self as LogicalPrimitiveType>::type_name()
                    ),
                })
            }

            fn cast_value_ref(value: ValueRef) -> Result<Option<$Native>> {
                // Null maps to `None`; only the variant matching this type is accepted.
                let wrapped: Option<$Native> = match value {
                    ValueRef::$TypeId(inner) => Some(inner.into()),
                    ValueRef::Null => None,
                    unexpected => {
                        return error::CastTypeSnafu {
                            msg: format!(
                                "Failed to cast value {:?} to primitive type {}",
                                unexpected,
                                <Self as LogicalPrimitiveType>::type_name(),
                            ),
                        }
                        .fail();
                    }
                };
                Ok(wrapped)
            }
        }
    };
}
214
/// Defines a numeric logical type through `define_logical_primitive_type!` and implements
/// [DataType] for it.
///
/// `try_cast` accepts:
/// - booleans, converted via `bool_to_numeric`;
/// - strings, parsed as `$Native`;
/// - each listed `$TargetType` variant, converted with `num::cast::cast`.
///
/// Every other value yields `None`.
macro_rules! define_non_timestamp_primitive {
    ( $Native: ident, $TypeId: ident, $DataType: ident, $Largest: ident $(, $TargetType: ident)* ) => {
        define_logical_primitive_type!($Native, $TypeId, $DataType, $Largest);

        impl DataType for $DataType {
            fn name(&self) -> String {
                String::from(stringify!($TypeId))
            }

            fn logical_type_id(&self) -> LogicalTypeId {
                LogicalTypeId::$TypeId
            }

            fn default_value(&self) -> Value {
                Value::from($Native::default())
            }

            fn as_arrow_type(&self) -> ArrowDataType {
                ArrowDataType::$TypeId
            }

            fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
                let builder = PrimitiveVectorBuilder::<Self>::with_capacity(capacity);
                Box::new(builder)
            }

            fn try_cast(&self, from: Value) -> Option<Value> {
                match from {
                    Value::Boolean(flag) => bool_to_numeric(flag).map(Value::$TypeId),
                    Value::String(text) => text.as_utf8().parse::<$Native>().ok().map(Value::from),
                    $(
                        Value::$TargetType(x) => num::cast::cast(x).map(Value::$TypeId),
                    )*
                    _ => None,
                }
            }
        }
    };
}
254
255define_non_timestamp_primitive!(
256    u8, UInt8, UInt8Type, UInt64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
257    Float32, Float64
258);
259define_non_timestamp_primitive!(
260    u16, UInt16, UInt16Type, UInt64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
261    Float32, Float64
262);
263define_non_timestamp_primitive!(
264    u32, UInt32, UInt32Type, UInt64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
265    Float32, Float64
266);
267define_non_timestamp_primitive!(
268    u64, UInt64, UInt64Type, UInt64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
269    Float32, Float64
270);
271define_non_timestamp_primitive!(
272    i8, Int8, Int8Type, Int64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
273    Float32, Float64
274);
275define_non_timestamp_primitive!(
276    i16, Int16, Int16Type, Int64Type, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
277    Float32, Float64
278);
279
280define_non_timestamp_primitive!(
281    f32,
282    Float32,
283    Float32Type,
284    Float64Type,
285    Int8,
286    Int16,
287    Int32,
288    Int64,
289    UInt8,
290    UInt16,
291    UInt32,
292    UInt64,
293    Float32,
294    Float64
295);
296define_non_timestamp_primitive!(
297    f64,
298    Float64,
299    Float64Type,
300    Float64Type,
301    Int8,
302    Int16,
303    Int32,
304    Int64,
305    UInt8,
306    UInt16,
307    UInt32,
308    UInt64,
309    Float32,
310    Float64
311);
312
313// Timestamp primitive:
314define_logical_primitive_type!(i64, Int64, Int64Type, Int64Type);
315
316define_logical_primitive_type!(i32, Int32, Int32Type, Int64Type);
317
318impl DataType for Int64Type {
319    fn name(&self) -> String {
320        "Int64".to_string()
321    }
322
323    fn logical_type_id(&self) -> LogicalTypeId {
324        LogicalTypeId::Int64
325    }
326
327    fn default_value(&self) -> Value {
328        Value::Int64(0)
329    }
330
331    fn as_arrow_type(&self) -> ArrowDataType {
332        ArrowDataType::Int64
333    }
334
335    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
336        Box::new(PrimitiveVectorBuilder::<Int64Type>::with_capacity(capacity))
337    }
338
339    fn try_cast(&self, from: Value) -> Option<Value> {
340        match from {
341            Value::Boolean(v) => bool_to_numeric(v).map(Value::Int64),
342            Value::Int8(v) => num::cast::cast(v).map(Value::Int64),
343            Value::Int16(v) => num::cast::cast(v).map(Value::Int64),
344            Value::Int32(v) => num::cast::cast(v).map(Value::Int64),
345            Value::Int64(v) => Some(Value::Int64(v)),
346            Value::UInt8(v) => num::cast::cast(v).map(Value::Int64),
347            Value::UInt16(v) => num::cast::cast(v).map(Value::Int64),
348            Value::UInt32(v) => num::cast::cast(v).map(Value::Int64),
349            Value::Float32(v) => num::cast::cast(v).map(Value::Int64),
350            Value::Float64(v) => num::cast::cast(v).map(Value::Int64),
351            Value::String(v) => v.as_utf8().parse::<i64>().map(Value::Int64).ok(),
352            Value::Timestamp(v) => Some(Value::Int64(v.value())),
353            Value::Time(v) => Some(Value::Int64(v.value())),
354            // We don't allow casting interval type to int.
355            _ => None,
356        }
357    }
358}
359
360impl DataType for Int32Type {
361    fn name(&self) -> String {
362        "Int32".to_string()
363    }
364
365    fn logical_type_id(&self) -> LogicalTypeId {
366        LogicalTypeId::Int32
367    }
368
369    fn default_value(&self) -> Value {
370        Value::Int32(0)
371    }
372
373    fn as_arrow_type(&self) -> ArrowDataType {
374        ArrowDataType::Int32
375    }
376
377    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
378        Box::new(PrimitiveVectorBuilder::<Int32Type>::with_capacity(capacity))
379    }
380
381    fn try_cast(&self, from: Value) -> Option<Value> {
382        match from {
383            Value::Boolean(v) => bool_to_numeric(v).map(Value::Int32),
384            Value::Int8(v) => num::cast::cast(v).map(Value::Int32),
385            Value::Int16(v) => num::cast::cast(v).map(Value::Int32),
386            Value::Int32(v) => Some(Value::Int32(v)),
387            Value::Int64(v) => num::cast::cast(v).map(Value::Int64),
388            Value::UInt8(v) => num::cast::cast(v).map(Value::Int32),
389            Value::UInt16(v) => num::cast::cast(v).map(Value::Int32),
390            Value::UInt32(v) => num::cast::cast(v).map(Value::UInt32),
391            Value::UInt64(v) => num::cast::cast(v).map(Value::UInt64),
392            Value::Float32(v) => num::cast::cast(v).map(Value::Int32),
393            Value::Float64(v) => num::cast::cast(v).map(Value::Int32),
394            Value::String(v) => v.as_utf8().parse::<i32>().map(Value::Int32).ok(),
395            Value::Date(v) => Some(Value::Int32(v.val())),
396            // We don't allow casting interval type to int.
397            _ => None,
398        }
399    }
400}
401
#[cfg(test)]
mod tests {
    use std::collections::BinaryHeap;

    use ordered_float::OrderedFloat;

    use super::*;

    #[test]
    fn test_ord_primitive() {
        struct Foo<T>
        where
            T: WrapperType,
        {
            heap: BinaryHeap<OrdPrimitive<T>>,
        }

        impl<T> Foo<T>
        where
            T: WrapperType,
        {
            fn push(&mut self, value: T) {
                let value = OrdPrimitive::<T>(value);
                self.heap.push(value);
            }
        }

        macro_rules! test {
            ($Type:ident) => {
                let mut foo = Foo::<$Type> {
                    heap: BinaryHeap::new(),
                };
                foo.push($Type::default());
                assert_eq!($Type::default(), foo.heap.pop().unwrap().as_primitive());
            };
        }

        test!(u8);
        test!(u16);
        test!(u32);
        test!(u64);
        test!(i8);
        test!(i16);
        test!(i32);
        test!(i64);
        test!(f32);
        test!(f64);
    }

    /// Casts each `from` value to the `to` type and asserts the result equals `expected`.
    ///
    /// If a cast yields `None`, it panics and names the expected value, so a failing row in a
    /// table is easy to identify.
    fn assert_casts(cases: Vec<(Value, ConcreteDataType, Value)>) {
        for (from, to, expected) in cases {
            let actual = to
                .try_cast(from)
                .unwrap_or_else(|| panic!("cast expected to yield {:?} returned None", expected));
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_primitive_cast() {
        // Integer widening casts.
        assert_casts(vec![
            (Value::UInt8(123), ConcreteDataType::uint16_datatype(), Value::UInt16(123)),
            (Value::UInt8(123), ConcreteDataType::uint32_datatype(), Value::UInt32(123)),
            (Value::UInt8(123), ConcreteDataType::uint64_datatype(), Value::UInt64(123)),
            (Value::UInt16(1234), ConcreteDataType::uint32_datatype(), Value::UInt32(1234)),
            (Value::UInt16(1234), ConcreteDataType::uint64_datatype(), Value::UInt64(1234)),
            (Value::UInt32(12345), ConcreteDataType::uint64_datatype(), Value::UInt64(12345)),
            (Value::Int8(123), ConcreteDataType::int32_datatype(), Value::Int32(123)),
            (Value::Int8(123), ConcreteDataType::int64_datatype(), Value::Int64(123)),
            (Value::Int16(1234), ConcreteDataType::int32_datatype(), Value::Int32(1234)),
            (Value::Int16(1234), ConcreteDataType::int64_datatype(), Value::Int64(1234)),
            (Value::Int32(12345), ConcreteDataType::int64_datatype(), Value::Int64(12345)),
        ]);
    }

    #[test]
    fn test_float_cast() {
        // Cast to Float32.
        let float32_cases = vec![
            Value::UInt8(12),
            Value::UInt16(12),
            Value::Int8(12),
            Value::Int16(12),
            Value::Int32(12),
        ]
        .into_iter()
        .map(|from| {
            (
                from,
                ConcreteDataType::float32_datatype(),
                Value::Float32(OrderedFloat(12.0)),
            )
        })
        .collect();
        assert_casts(float32_cases);

        // Cast to Float64.
        let float64_cases = vec![
            Value::UInt8(12),
            Value::UInt16(12),
            Value::UInt32(12),
            Value::Int8(12),
            Value::Int16(12),
            Value::Int32(12),
            Value::Int64(12),
        ]
        .into_iter()
        .map(|from| {
            (
                from,
                ConcreteDataType::float64_datatype(),
                Value::Float64(OrderedFloat(12.0)),
            )
        })
        .collect();
        assert_casts(float64_cases);
    }

    #[test]
    fn test_string_cast_to_primitive() {
        assert_casts(vec![
            (Value::String("123".into()), ConcreteDataType::uint8_datatype(), Value::UInt8(123)),
            (Value::String("123".into()), ConcreteDataType::uint16_datatype(), Value::UInt16(123)),
            (Value::String("123".into()), ConcreteDataType::uint32_datatype(), Value::UInt32(123)),
            (Value::String("123".into()), ConcreteDataType::uint64_datatype(), Value::UInt64(123)),
            (Value::String("123".into()), ConcreteDataType::int8_datatype(), Value::Int8(123)),
            (Value::String("123".into()), ConcreteDataType::int16_datatype(), Value::Int16(123)),
            (Value::String("123".into()), ConcreteDataType::int32_datatype(), Value::Int32(123)),
            (Value::String("123".into()), ConcreteDataType::int64_datatype(), Value::Int64(123)),
            (
                Value::String("1.23".into()),
                ConcreteDataType::float32_datatype(),
                Value::Float32(OrderedFloat(1.23)),
            ),
            (
                Value::String("1.23".into()),
                ConcreteDataType::float64_datatype(),
                Value::Float64(OrderedFloat(1.23)),
            ),
        ]);
    }
}