datatypes/types/
boolean_type.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use arrow::datatypes::DataType as ArrowDataType;
18use num_traits::Num;
19use serde::{Deserialize, Serialize};
20
21use crate::data_type::{DataType, DataTypeRef};
22use crate::scalars::ScalarVectorBuilder;
23use crate::type_id::LogicalTypeId;
24use crate::value::Value;
25use crate::vectors::{BooleanVectorBuilder, MutableVector};
26
27#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
28pub struct BooleanType;
29
30impl BooleanType {
31    pub fn arc() -> DataTypeRef {
32        Arc::new(Self)
33    }
34}
35
36impl DataType for BooleanType {
37    fn name(&self) -> String {
38        "Boolean".to_string()
39    }
40
41    fn logical_type_id(&self) -> LogicalTypeId {
42        LogicalTypeId::Boolean
43    }
44
45    fn default_value(&self) -> Value {
46        bool::default().into()
47    }
48
49    fn as_arrow_type(&self) -> ArrowDataType {
50        ArrowDataType::Boolean
51    }
52
53    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
54        Box::new(BooleanVectorBuilder::with_capacity(capacity))
55    }
56
57    fn try_cast(&self, from: Value) -> Option<Value> {
58        match from {
59            Value::Boolean(v) => Some(Value::Boolean(v)),
60            Value::UInt8(v) => numeric_to_bool(v),
61            Value::UInt16(v) => numeric_to_bool(v),
62            Value::UInt32(v) => numeric_to_bool(v),
63            Value::UInt64(v) => numeric_to_bool(v),
64            Value::Int8(v) => numeric_to_bool(v),
65            Value::Int16(v) => numeric_to_bool(v),
66            Value::Int32(v) => numeric_to_bool(v),
67            Value::Int64(v) => numeric_to_bool(v),
68            Value::Float32(v) => numeric_to_bool(v),
69            Value::Float64(v) => numeric_to_bool(v),
70            Value::String(v) => v.as_utf8().parse::<bool>().ok().map(Value::Boolean),
71            _ => None,
72        }
73    }
74}
75
76pub fn numeric_to_bool<T>(num: T) -> Option<Value>
77where
78    T: Num + Default,
79{
80    if num != T::default() {
81        Some(Value::Boolean(true))
82    } else {
83        Some(Value::Boolean(false))
84    }
85}
86
87pub fn bool_to_numeric<T>(value: bool) -> Option<T>
88where
89    T: Num,
90{
91    if value {
92        Some(T::one())
93    } else {
94        Some(T::zero())
95    }
96}
97
#[cfg(test)]
mod tests {

    use ordered_float::OrderedFloat;

    use super::*;
    use crate::data_type::ConcreteDataType;

    /// Asserts that every listed value casts to `Value::Boolean($expected)`
    /// through the boolean data type.
    macro_rules! assert_casts_to_bool {
        ($expected:expr; $($value:expr),+ $(,)?) => {
            $(
                let casted = ConcreteDataType::boolean_datatype()
                    .try_cast($value)
                    .unwrap();
                assert_eq!(casted, Value::Boolean($expected));
            )+
        };
    }

    /// Asserts that `Value::Boolean($value)` casts to the expected value for
    /// every listed `datatype => expected` pair.
    macro_rules! assert_casts_from_bool {
        ($value:expr; $($datatype:expr => $expected:expr),+ $(,)?) => {
            $(
                let casted = $datatype.try_cast(Value::Boolean($value)).unwrap();
                assert_eq!(casted, $expected);
            )+
        };
    }

    #[test]
    fn test_other_type_cast_to_bool() {
        // Zero of every numeric type casts to `false`.
        assert_casts_to_bool!(false;
            Value::UInt8(0),
            Value::UInt16(0),
            Value::UInt32(0),
            Value::UInt64(0),
            Value::Int8(0),
            Value::Int16(0),
            Value::Int32(0),
            Value::Int64(0),
            Value::Float32(OrderedFloat(0.0)),
            Value::Float64(OrderedFloat(0.0)),
        );

        // Any non-zero number casts to `true`.
        assert_casts_to_bool!(true;
            Value::UInt8(1),
            Value::UInt16(2),
            Value::UInt32(3),
            Value::UInt64(4),
            Value::Int8(5),
            Value::Int16(6),
            Value::Int32(7),
            Value::Int64(8),
            Value::Float32(OrderedFloat(1.0)),
            Value::Float64(OrderedFloat(2.0)),
        );
    }

    #[test]
    fn test_bool_cast_to_other_type() {
        // `false` casts to zero of every numeric type.
        assert_casts_from_bool!(false;
            ConcreteDataType::uint8_datatype() => Value::UInt8(0),
            ConcreteDataType::uint16_datatype() => Value::UInt16(0),
            ConcreteDataType::uint32_datatype() => Value::UInt32(0),
            ConcreteDataType::uint64_datatype() => Value::UInt64(0),
            ConcreteDataType::int8_datatype() => Value::Int8(0),
            ConcreteDataType::int16_datatype() => Value::Int16(0),
            ConcreteDataType::int32_datatype() => Value::Int32(0),
            ConcreteDataType::int64_datatype() => Value::Int64(0),
            ConcreteDataType::float32_datatype() => Value::Float32(OrderedFloat(0.0)),
            ConcreteDataType::float64_datatype() => Value::Float64(OrderedFloat(0.0)),
        );

        // `true` casts to one of every numeric type.
        assert_casts_from_bool!(true;
            ConcreteDataType::uint8_datatype() => Value::UInt8(1),
            ConcreteDataType::uint16_datatype() => Value::UInt16(1),
            ConcreteDataType::uint32_datatype() => Value::UInt32(1),
            ConcreteDataType::uint64_datatype() => Value::UInt64(1),
            ConcreteDataType::int8_datatype() => Value::Int8(1),
            ConcreteDataType::int16_datatype() => Value::Int16(1),
            ConcreteDataType::int32_datatype() => Value::Int32(1),
            ConcreteDataType::int64_datatype() => Value::Int64(1),
            ConcreteDataType::float32_datatype() => Value::Float32(OrderedFloat(1.0)),
            ConcreteDataType::float64_datatype() => Value::Float64(OrderedFloat(1.0)),
        );
    }
}