1use std::array::TryFromSliceError;
16
17use bytes::Bytes;
18use common_decimal::Decimal128;
19use common_time::timestamp::TimeUnit;
20use common_time::{Date, IntervalMonthDayNano, Timestamp};
21use datafusion_common::ScalarValue;
22use datatypes::data_type::ConcreteDataType as CDT;
23use datatypes::value::Value;
24use num_traits::FromBytes;
25use snafu::OptionExt;
26use substrait::variation_const::{
27 DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
28 UNSIGNED_INTEGER_TYPE_VARIATION_REF,
29};
30use substrait_proto::proto;
31use substrait_proto::proto::expression::literal::{LiteralType, PrecisionTimestamp};
32use substrait_proto::proto::expression::Literal;
33use substrait_proto::proto::r#type::Kind;
34
35use crate::error::{Error, NotImplementedSnafu, PlanSnafu, UnexpectedSnafu};
36use crate::transform::substrait_proto;
37
38#[derive(Debug)]
39enum TimestampPrecision {
40 Second = 0,
41 Millisecond = 3,
42 Microsecond = 6,
43 Nanosecond = 9,
44}
45
46impl TryFrom<i32> for TimestampPrecision {
47 type Error = Error;
48
49 fn try_from(prec: i32) -> Result<Self, Self::Error> {
50 match prec {
51 0 => Ok(Self::Second),
52 3 => Ok(Self::Millisecond),
53 6 => Ok(Self::Microsecond),
54 9 => Ok(Self::Nanosecond),
55 _ => not_impl_err!("Unsupported precision: {prec}"),
56 }
57 }
58}
59
60impl TimestampPrecision {
61 fn to_time_unit(&self) -> TimeUnit {
62 match self {
63 Self::Second => TimeUnit::Second,
64 Self::Millisecond => TimeUnit::Millisecond,
65 Self::Microsecond => TimeUnit::Microsecond,
66 Self::Nanosecond => TimeUnit::Nanosecond,
67 }
68 }
69
70 fn to_cdt(&self) -> CDT {
71 match self {
72 Self::Second => CDT::timestamp_second_datatype(),
73 Self::Millisecond => CDT::timestamp_millisecond_datatype(),
74 Self::Microsecond => CDT::timestamp_microsecond_datatype(),
75 Self::Nanosecond => CDT::timestamp_nanosecond_datatype(),
76 }
77 }
78}
79
80pub(crate) fn to_substrait_literal(value: &ScalarValue) -> Result<Literal, Error> {
82 if value.is_null() {
83 return not_impl_err!("Unsupported literal: {value:?}");
84 }
85 let (literal_type, type_variation_reference) = match value {
86 ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF),
87 ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF),
88 ScalarValue::UInt8(Some(n)) => (
89 LiteralType::I8(*n as i32),
90 UNSIGNED_INTEGER_TYPE_VARIATION_REF,
91 ),
92 ScalarValue::Int16(Some(n)) => (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF),
93 ScalarValue::UInt16(Some(n)) => (
94 LiteralType::I16(*n as i32),
95 UNSIGNED_INTEGER_TYPE_VARIATION_REF,
96 ),
97 ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF),
98 ScalarValue::UInt32(Some(n)) => (
99 LiteralType::I32(*n as i32),
100 UNSIGNED_INTEGER_TYPE_VARIATION_REF,
101 ),
102 ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF),
103 ScalarValue::UInt64(Some(n)) => (
104 LiteralType::I64(*n as i64),
105 UNSIGNED_INTEGER_TYPE_VARIATION_REF,
106 ),
107 ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF),
108 ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF),
109 ScalarValue::TimestampSecond(Some(t), _) => (
111 LiteralType::PrecisionTimestamp(PrecisionTimestamp {
112 value: *t,
113 precision: TimestampPrecision::Second as i32,
114 }),
115 DEFAULT_TYPE_VARIATION_REF,
116 ),
117 ScalarValue::TimestampMillisecond(Some(t), _) => (
118 LiteralType::PrecisionTimestamp(PrecisionTimestamp {
119 value: *t,
120 precision: TimestampPrecision::Millisecond as i32,
121 }),
122 DEFAULT_TYPE_VARIATION_REF,
123 ),
124 ScalarValue::TimestampMicrosecond(Some(t), _) => (
125 LiteralType::PrecisionTimestamp(PrecisionTimestamp {
126 value: *t,
127 precision: TimestampPrecision::Microsecond as i32,
128 }),
129 DEFAULT_TYPE_VARIATION_REF,
130 ),
131 ScalarValue::TimestampNanosecond(Some(t), _) => (
132 LiteralType::PrecisionTimestamp(PrecisionTimestamp {
133 value: *t,
134 precision: TimestampPrecision::Nanosecond as i32,
135 }),
136 DEFAULT_TYPE_VARIATION_REF,
137 ),
138 ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF),
139 _ => (
140 not_impl_err!("Unsupported literal: {value:?}")?,
141 DEFAULT_TYPE_VARIATION_REF,
142 ),
143 };
144
145 Ok(Literal {
146 nullable: false,
147 type_variation_reference,
148 literal_type: Some(literal_type),
149 })
150}
151
/// Decodes a Substrait [`Literal`] into a [`Value`] plus the [`ConcreteDataType`](CDT)
/// describing it.
///
/// # Errors
/// Returns a "not implemented" error for unhandled literal kinds or unknown type variation
/// references, and a plan/unexpected error when a decimal or interval payload cannot be
/// decoded.
pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<(Value, CDT), Error> {
    let scalar_value = match &lit.literal_type {
        Some(LiteralType::Boolean(b)) => (Value::from(*b), CDT::boolean_datatype()),
        // Integer literals are carried in i32/i64; the type variation reference selects the
        // signed or unsigned interpretation, and `as` narrows/reinterprets to that width.
        Some(LiteralType::I8(n)) => match lit.type_variation_reference {
            DEFAULT_TYPE_VARIATION_REF => (Value::from(*n as i8), CDT::int8_datatype()),
            UNSIGNED_INTEGER_TYPE_VARIATION_REF => (Value::from(*n as u8), CDT::uint8_datatype()),
            others => not_impl_err!("Unknown type variation reference {others}",)?,
        },
        Some(LiteralType::I16(n)) => match lit.type_variation_reference {
            DEFAULT_TYPE_VARIATION_REF => (Value::from(*n as i16), CDT::int16_datatype()),
            UNSIGNED_INTEGER_TYPE_VARIATION_REF => (Value::from(*n as u16), CDT::uint16_datatype()),
            others => not_impl_err!("Unknown type variation reference {others}",)?,
        },
        Some(LiteralType::I32(n)) => match lit.type_variation_reference {
            DEFAULT_TYPE_VARIATION_REF => (Value::from(*n), CDT::int32_datatype()),
            UNSIGNED_INTEGER_TYPE_VARIATION_REF => (Value::from(*n as u32), CDT::uint32_datatype()),
            others => not_impl_err!("Unknown type variation reference {others}",)?,
        },
        Some(LiteralType::I64(n)) => match lit.type_variation_reference {
            DEFAULT_TYPE_VARIATION_REF => (Value::from(*n), CDT::int64_datatype()),
            UNSIGNED_INTEGER_TYPE_VARIATION_REF => (Value::from(*n as u64), CDT::uint64_datatype()),
            others => not_impl_err!("Unknown type variation reference {others}",)?,
        },
        Some(LiteralType::Fp32(f)) => (Value::from(*f), CDT::float32_datatype()),
        Some(LiteralType::Fp64(f)) => (Value::from(*f), CDT::float64_datatype()),
        // A plain `Timestamp` literal is interpreted as microseconds.
        Some(LiteralType::Timestamp(t)) => (
            Value::from(Timestamp::new_microsecond(*t)),
            CDT::timestamp_microsecond_datatype(),
        ),
        // `PrecisionTimestamp` carries its own digit count; only 0/3/6/9 are accepted.
        Some(LiteralType::PrecisionTimestamp(prec_ts)) => {
            let (prec, val) = (prec_ts.precision, prec_ts.value);
            let prec = TimestampPrecision::try_from(prec)?;
            let unit = prec.to_time_unit();
            let typ = prec.to_cdt();
            (Value::from(Timestamp::new(val, unit)), typ)
        }
        Some(LiteralType::Date(d)) => (Value::from(Date::new(*d)), CDT::date_datatype()),
        Some(LiteralType::String(s)) => (Value::from(s.clone()), CDT::string_datatype()),
        // Fixed-length binary is not distinguished from variable-length binary here.
        Some(LiteralType::Binary(b)) | Some(LiteralType::FixedBinary(b)) => {
            (Value::from(b.clone()), CDT::binary_datatype())
        }
        // The decimal's unscaled value must be exactly 16 bytes and is read as a
        // little-endian i128; precision/scale must fit in u8/i8.
        Some(LiteralType::Decimal(d)) => {
            let value: [u8; 16] = d.value.clone().try_into().map_err(|e| {
                PlanSnafu {
                    reason: format!("Failed to parse decimal value from {e:?}"),
                }
                .build()
            })?;
            let p: u8 = d.precision.try_into().map_err(|e| {
                PlanSnafu {
                    reason: format!("Failed to parse decimal precision: {e}"),
                }
                .build()
            })?;
            let s: i8 = d.scale.try_into().map_err(|e| {
                PlanSnafu {
                    reason: format!("Failed to parse decimal scale: {e}"),
                }
                .build()
            })?;
            let value = i128::from_le_bytes(value);
            (
                Value::from(Decimal128::new(value, p, s)),
                CDT::decimal128_datatype(p, s),
            )
        }
        // A typed null: the value is `Null`, the data type comes from the attached type.
        Some(LiteralType::Null(ntype)) => (Value::Null, from_substrait_type(ntype)?),
        Some(LiteralType::IntervalDayToSecond(interval)) => from_interval_day_sec(interval)?,
        Some(LiteralType::IntervalYearToMonth(interval)) => from_interval_year_month(interval)?,
        // A compound interval is decoded part by part with the two helpers above and folded
        // into a single month/day/nanosecond interval.
        // NOTE(review): the day-to-second part goes through `from_interval_day_sec`, which
        // yields whole milliseconds, so any sub-millisecond precision is dropped here even
        // though the target type could hold nanoseconds.
        Some(LiteralType::IntervalCompound(interval_compound)) => {
            let interval_day_time = &interval_compound
                .interval_day_to_second
                .map(|i| from_interval_day_sec(&i))
                .transpose()?;
            let interval_year_month = &interval_compound
                .interval_year_to_month
                .map(|i| from_interval_year_month(&i))
                .transpose()?;
            let mut compound = IntervalMonthDayNano::new(0, 0, 0);
            if let Some(day_sec) = interval_day_time {
                // `from_interval_day_sec` always returns `IntervalDayTime`; the else branch
                // is a defensive check.
                let Value::IntervalDayTime(day_time) = day_sec.0 else {
                    UnexpectedSnafu {
                        reason: format!("Expect IntervalDayTime, found {:?}", day_sec),
                    }
                    .fail()?
                };
                // The millisecond field already includes the whole seconds; scale it to
                // nanoseconds.
                compound.nanoseconds = compound
                    .nanoseconds
                    .checked_add(day_time.milliseconds as i64 * 1_000_000)
                    .with_context(|| UnexpectedSnafu {
                        reason: format!(
                            "Overflow when converting interval: {:?}",
                            interval_compound
                        ),
                    })?;
                compound.days += day_time.days;
            }

            if let Some(year_month) = interval_year_month {
                // `from_interval_year_month` always returns `IntervalYearMonth`; defensive.
                let Value::IntervalYearMonth(year_month) = year_month.0 else {
                    UnexpectedSnafu {
                        reason: format!("Expect IntervalYearMonth, found {:?}", year_month),
                    }
                    .fail()?
                };
                compound.months += year_month.months;
            }

            (
                Value::IntervalMonthDayNano(compound),
                CDT::interval_month_day_nano_datatype(),
            )
        }
        _ => not_impl_err!("unsupported literal_type: {:?}", &lit.literal_type)?,
    };
    Ok(scalar_value)
}
272
273fn from_interval_day_sec(
274 interval: &proto::expression::literal::IntervalDayToSecond,
275) -> Result<(Value, CDT), Error> {
276 let (days, seconds, subseconds) = (interval.days, interval.seconds, interval.subseconds);
277 let millis = if let Some(prec) = interval.precision_mode {
278 use substrait_proto::proto::expression::literal::interval_day_to_second::PrecisionMode;
279 match prec {
280 PrecisionMode::Precision(e) => {
281 if e >= 3 {
282 subseconds
283 / 10_i64
284 .checked_pow((e - 3) as _)
285 .with_context(|| UnexpectedSnafu {
286 reason: format!(
287 "Overflow when converting interval: {:?}",
288 interval
289 ),
290 })?
291 } else {
292 subseconds
293 * 10_i64
294 .checked_pow((3 - e) as _)
295 .with_context(|| UnexpectedSnafu {
296 reason: format!(
297 "Overflow when converting interval: {:?}",
298 interval
299 ),
300 })?
301 }
302 }
303 PrecisionMode::Microseconds(_) => subseconds / 1000,
304 }
305 } else if subseconds == 0 {
306 0
307 } else {
308 not_impl_err!("unsupported subseconds without precision_mode: {subseconds}")?
309 };
310
311 let value_interval = common_time::IntervalDayTime::new(days, seconds * 1000 + millis as i32);
312
313 Ok((
314 Value::IntervalDayTime(value_interval),
315 CDT::interval_day_time_datatype(),
316 ))
317}
318
319fn from_interval_year_month(
320 interval: &proto::expression::literal::IntervalYearToMonth,
321) -> Result<(Value, CDT), Error> {
322 let value_interval = common_time::IntervalYearMonth::new(interval.years * 12 + interval.months);
323
324 Ok((
325 Value::IntervalYearMonth(value_interval),
326 CDT::interval_year_month_datatype(),
327 ))
328}
329
330fn from_bytes<T: FromBytes>(i: &Bytes) -> Result<T, Error>
331where
332 for<'a> &'a <T as num_traits::FromBytes>::Bytes:
333 std::convert::TryFrom<&'a [u8], Error = TryFromSliceError>,
334{
335 let (int_bytes, _rest) = i.split_at(std::mem::size_of::<T>());
336 let i = T::from_le_bytes(int_bytes.try_into().map_err(|e| {
337 UnexpectedSnafu {
338 reason: format!(
339 "Expect slice to be {} bytes, found {} bytes, error={:?}",
340 std::mem::size_of::<T>(),
341 int_bytes.len(),
342 e
343 ),
344 }
345 .build()
346 })?);
347 Ok(i)
348}
349
350pub fn from_substrait_type(null_type: &substrait_proto::proto::Type) -> Result<CDT, Error> {
352 if let Some(kind) = &null_type.kind {
353 match kind {
354 Kind::Bool(_) => Ok(CDT::boolean_datatype()),
355 Kind::I8(integer) => match integer.type_variation_reference {
356 DEFAULT_TYPE_VARIATION_REF => Ok(CDT::int8_datatype()),
357 UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(CDT::uint8_datatype()),
358 v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
359 },
360 Kind::I16(integer) => match integer.type_variation_reference {
361 DEFAULT_TYPE_VARIATION_REF => Ok(CDT::int16_datatype()),
362 UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(CDT::uint16_datatype()),
363 v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
364 },
365 Kind::I32(integer) => match integer.type_variation_reference {
366 DEFAULT_TYPE_VARIATION_REF => Ok(CDT::int32_datatype()),
367 UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(CDT::uint32_datatype()),
368 v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
369 },
370 Kind::I64(integer) => match integer.type_variation_reference {
371 DEFAULT_TYPE_VARIATION_REF => Ok(CDT::int64_datatype()),
372 UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(CDT::uint64_datatype()),
373 v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
374 },
375 Kind::Fp32(_) => Ok(CDT::float32_datatype()),
376 Kind::Fp64(_) => Ok(CDT::float64_datatype()),
377 Kind::PrecisionTimestamp(ts) => {
378 Ok(TimestampPrecision::try_from(ts.precision)?.to_cdt())
379 }
380 Kind::Date(date) => match date.type_variation_reference {
381 DATE_32_TYPE_VARIATION_REF | DATE_64_TYPE_VARIATION_REF => Ok(CDT::date_datatype()),
382 v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
383 },
384 Kind::Binary(_) => Ok(CDT::binary_datatype()),
385 Kind::String(_) => Ok(CDT::string_datatype()),
386 Kind::Decimal(d) => Ok(CDT::decimal128_datatype(d.precision as u8, d.scale as i8)),
387 _ => not_impl_err!("Unsupported Substrait type: {kind:?}"),
388 }
389 } else {
390 not_impl_err!("Null type without kind is not supported")
391 }
392}
393
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{self, ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};

    /// `SELECT 1` must become a single-row constant plan holding `Int64(1)`.
    #[tokio::test]
    async fn test_literal() {
        let engine = create_test_query_engine();
        let substrait_plan = sql_to_substrait(engine.clone(), "SELECT 1 FROM numbers").await;

        let mut ctx = create_test_ctx();
        let actual = TypedPlan::from_substrait_plan(&mut ctx, &substrait_plan)
            .await
            .unwrap();

        let schema = RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)])
            .into_named(vec![Some("Int64(1)".to_string())]);
        let only_row = repr::Row::new(vec![Value::Int64(1)]);
        let expected = TypedPlan {
            schema,
            plan: Plan::Constant {
                rows: vec![(only_row, repr::Timestamp::MIN, 1)],
            },
        };

        assert_eq!(actual, expected);
    }
}