1use std::sync::Arc;
16
17use arrow::datatypes::DataType as ArrowDataType;
18use num_traits::Num;
19use serde::{Deserialize, Serialize};
20
21use crate::data_type::{DataType, DataTypeRef};
22use crate::scalars::ScalarVectorBuilder;
23use crate::type_id::LogicalTypeId;
24use crate::value::Value;
25use crate::vectors::{BooleanVectorBuilder, MutableVector};
26
27#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
28pub struct BooleanType;
29
30impl BooleanType {
31 pub fn arc() -> DataTypeRef {
32 Arc::new(Self)
33 }
34}
35
36impl DataType for BooleanType {
37 fn name(&self) -> String {
38 "Boolean".to_string()
39 }
40
41 fn logical_type_id(&self) -> LogicalTypeId {
42 LogicalTypeId::Boolean
43 }
44
45 fn default_value(&self) -> Value {
46 bool::default().into()
47 }
48
49 fn as_arrow_type(&self) -> ArrowDataType {
50 ArrowDataType::Boolean
51 }
52
53 fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
54 Box::new(BooleanVectorBuilder::with_capacity(capacity))
55 }
56
57 fn try_cast(&self, from: Value) -> Option<Value> {
58 match from {
59 Value::Boolean(v) => Some(Value::Boolean(v)),
60 Value::UInt8(v) => numeric_to_bool(v),
61 Value::UInt16(v) => numeric_to_bool(v),
62 Value::UInt32(v) => numeric_to_bool(v),
63 Value::UInt64(v) => numeric_to_bool(v),
64 Value::Int8(v) => numeric_to_bool(v),
65 Value::Int16(v) => numeric_to_bool(v),
66 Value::Int32(v) => numeric_to_bool(v),
67 Value::Int64(v) => numeric_to_bool(v),
68 Value::Float32(v) => numeric_to_bool(v),
69 Value::Float64(v) => numeric_to_bool(v),
70 Value::String(v) => v.as_utf8().parse::<bool>().ok().map(Value::Boolean),
71 _ => None,
72 }
73 }
74}
75
76pub fn numeric_to_bool<T>(num: T) -> Option<Value>
77where
78 T: Num + Default,
79{
80 if num != T::default() {
81 Some(Value::Boolean(true))
82 } else {
83 Some(Value::Boolean(false))
84 }
85}
86
87pub fn bool_to_numeric<T>(value: bool) -> Option<T>
88where
89 T: Num,
90{
91 if value {
92 Some(T::one())
93 } else {
94 Some(T::zero())
95 }
96}
97
#[cfg(test)]
mod tests {

    use ordered_float::OrderedFloat;

    use super::*;
    use crate::data_type::ConcreteDataType;

    /// Casts `$value` to the boolean type and asserts the result is
    /// `Value::Boolean($expected)`. Panics if the cast returns `None`.
    macro_rules! test_cast_to_bool {
        ($value: expr, $expected: expr) => {
            let val = $value;
            let b = ConcreteDataType::boolean_datatype().try_cast(val).unwrap();
            assert_eq!(b, Value::Boolean($expected));
        };
    }

    /// Casts `Value::Boolean($value)` to `$datatype` and asserts the result equals
    /// `$expected`. Panics if the cast returns `None`.
    macro_rules! test_cast_from_bool {
        ($value: expr, $datatype: expr, $expected: expr) => {
            let val = Value::Boolean($value);
            let b = $datatype.try_cast(val).unwrap();
            assert_eq!(b, $expected);
        };
    }

    #[test]
    fn test_other_type_cast_to_bool() {
        test_cast_to_bool!(Value::UInt8(0), false);
        test_cast_to_bool!(Value::UInt16(0), false);
        test_cast_to_bool!(Value::UInt32(0), false);
        test_cast_to_bool!(Value::UInt64(0), false);
        test_cast_to_bool!(Value::Int8(0), false);
        test_cast_to_bool!(Value::Int16(0), false);
        test_cast_to_bool!(Value::Int32(0), false);
        test_cast_to_bool!(Value::Int64(0), false);
        test_cast_to_bool!(Value::Float32(OrderedFloat(0.0)), false);
        test_cast_to_bool!(Value::Float64(OrderedFloat(0.0)), false);
        test_cast_to_bool!(Value::UInt8(1), true);
        test_cast_to_bool!(Value::UInt16(2), true);
        test_cast_to_bool!(Value::UInt32(3), true);
        test_cast_to_bool!(Value::UInt64(4), true);
        test_cast_to_bool!(Value::Int8(5), true);
        test_cast_to_bool!(Value::Int16(6), true);
        test_cast_to_bool!(Value::Int32(7), true);
        test_cast_to_bool!(Value::Int64(8), true);
        test_cast_to_bool!(Value::Float32(OrderedFloat(1.0)), true);
        test_cast_to_bool!(Value::Float64(OrderedFloat(2.0)), true);

        // Negative numbers are non-zero, so they are `true`.
        test_cast_to_bool!(Value::Int8(-1), true);
        test_cast_to_bool!(Value::Int64(-8), true);
        test_cast_to_bool!(Value::Float64(OrderedFloat(-2.0)), true);
        // Negative zero is still zero.
        test_cast_to_bool!(Value::Float64(OrderedFloat(-0.0)), false);
    }

    #[test]
    fn test_bool_cast_to_bool() {
        // Booleans pass through unchanged.
        test_cast_to_bool!(Value::Boolean(true), true);
        test_cast_to_bool!(Value::Boolean(false), false);
    }

    #[test]
    fn test_string_cast_to_bool() {
        test_cast_to_bool!(Value::String("true".into()), true);
        test_cast_to_bool!(Value::String("false".into()), false);

        // Only the exact literals accepted by `str::parse::<bool>` are recognised.
        let boolean = ConcreteDataType::boolean_datatype();
        assert_eq!(boolean.try_cast(Value::String("yes".into())), None);
        assert_eq!(boolean.try_cast(Value::String("TRUE".into())), None);
        assert_eq!(boolean.try_cast(Value::String("".into())), None);
    }

    #[test]
    fn test_unsupported_type_cast_to_bool() {
        assert_eq!(
            ConcreteDataType::boolean_datatype().try_cast(Value::Null),
            None
        );
    }

    #[test]
    fn test_bool_cast_to_other_type() {
        test_cast_from_bool!(false, ConcreteDataType::uint8_datatype(), Value::UInt8(0));
        test_cast_from_bool!(false, ConcreteDataType::uint16_datatype(), Value::UInt16(0));
        test_cast_from_bool!(false, ConcreteDataType::uint32_datatype(), Value::UInt32(0));
        test_cast_from_bool!(false, ConcreteDataType::uint64_datatype(), Value::UInt64(0));
        test_cast_from_bool!(false, ConcreteDataType::int8_datatype(), Value::Int8(0));
        test_cast_from_bool!(false, ConcreteDataType::int16_datatype(), Value::Int16(0));
        test_cast_from_bool!(false, ConcreteDataType::int32_datatype(), Value::Int32(0));
        test_cast_from_bool!(false, ConcreteDataType::int64_datatype(), Value::Int64(0));
        test_cast_from_bool!(
            false,
            ConcreteDataType::float32_datatype(),
            Value::Float32(OrderedFloat(0.0))
        );
        test_cast_from_bool!(
            false,
            ConcreteDataType::float64_datatype(),
            Value::Float64(OrderedFloat(0.0))
        );
        test_cast_from_bool!(true, ConcreteDataType::uint8_datatype(), Value::UInt8(1));
        test_cast_from_bool!(true, ConcreteDataType::uint16_datatype(), Value::UInt16(1));
        test_cast_from_bool!(true, ConcreteDataType::uint32_datatype(), Value::UInt32(1));
        test_cast_from_bool!(true, ConcreteDataType::uint64_datatype(), Value::UInt64(1));
        test_cast_from_bool!(true, ConcreteDataType::int8_datatype(), Value::Int8(1));
        test_cast_from_bool!(true, ConcreteDataType::int16_datatype(), Value::Int16(1));
        test_cast_from_bool!(true, ConcreteDataType::int32_datatype(), Value::Int32(1));
        test_cast_from_bool!(true, ConcreteDataType::int64_datatype(), Value::Int64(1));
        test_cast_from_bool!(
            true,
            ConcreteDataType::float32_datatype(),
            Value::Float32(OrderedFloat(1.0))
        );
        test_cast_from_bool!(
            true,
            ConcreteDataType::float64_datatype(),
            Value::Float64(OrderedFloat(1.0))
        );
    }
}