flow/transform/
literal.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::array::TryFromSliceError;
16
17use bytes::Bytes;
18use common_decimal::Decimal128;
19use common_time::timestamp::TimeUnit;
20use common_time::{Date, IntervalMonthDayNano, Timestamp};
21use datafusion_common::ScalarValue;
22use datatypes::data_type::ConcreteDataType as CDT;
23use datatypes::value::Value;
24use num_traits::FromBytes;
25use snafu::OptionExt;
26use substrait::variation_const::{
27    DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
28    UNSIGNED_INTEGER_TYPE_VARIATION_REF,
29};
30use substrait_proto::proto;
31use substrait_proto::proto::expression::literal::{LiteralType, PrecisionTimestamp};
32use substrait_proto::proto::expression::Literal;
33use substrait_proto::proto::r#type::Kind;
34
35use crate::error::{Error, NotImplementedSnafu, PlanSnafu, UnexpectedSnafu};
36use crate::transform::substrait_proto;
37
/// Timestamp precision as carried in a Substrait `PrecisionTimestamp`.
///
/// The discriminant is the number of fractional-second digits, so `precision as i32`
/// yields the value written into / read from the Substrait `precision` field
/// (0 = seconds, 3 = milliseconds, 6 = microseconds, 9 = nanoseconds).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimestampPrecision {
    Second = 0,
    Millisecond = 3,
    Microsecond = 6,
    Nanosecond = 9,
}
45
46impl TryFrom<i32> for TimestampPrecision {
47    type Error = Error;
48
49    fn try_from(prec: i32) -> Result<Self, Self::Error> {
50        match prec {
51            0 => Ok(Self::Second),
52            3 => Ok(Self::Millisecond),
53            6 => Ok(Self::Microsecond),
54            9 => Ok(Self::Nanosecond),
55            _ => not_impl_err!("Unsupported precision: {prec}"),
56        }
57    }
58}
59
60impl TimestampPrecision {
61    fn to_time_unit(&self) -> TimeUnit {
62        match self {
63            Self::Second => TimeUnit::Second,
64            Self::Millisecond => TimeUnit::Millisecond,
65            Self::Microsecond => TimeUnit::Microsecond,
66            Self::Nanosecond => TimeUnit::Nanosecond,
67        }
68    }
69
70    fn to_cdt(&self) -> CDT {
71        match self {
72            Self::Second => CDT::timestamp_second_datatype(),
73            Self::Millisecond => CDT::timestamp_millisecond_datatype(),
74            Self::Microsecond => CDT::timestamp_microsecond_datatype(),
75            Self::Nanosecond => CDT::timestamp_nanosecond_datatype(),
76        }
77    }
78}
79
80/// TODO(discord9): this is copy from datafusion-substrait since the original function is not public, will be replace once is exported
81pub(crate) fn to_substrait_literal(value: &ScalarValue) -> Result<Literal, Error> {
82    if value.is_null() {
83        return not_impl_err!("Unsupported literal: {value:?}");
84    }
85    let (literal_type, type_variation_reference) = match value {
86        ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF),
87        ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF),
88        ScalarValue::UInt8(Some(n)) => (
89            LiteralType::I8(*n as i32),
90            UNSIGNED_INTEGER_TYPE_VARIATION_REF,
91        ),
92        ScalarValue::Int16(Some(n)) => (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF),
93        ScalarValue::UInt16(Some(n)) => (
94            LiteralType::I16(*n as i32),
95            UNSIGNED_INTEGER_TYPE_VARIATION_REF,
96        ),
97        ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF),
98        ScalarValue::UInt32(Some(n)) => (
99            LiteralType::I32(*n as i32),
100            UNSIGNED_INTEGER_TYPE_VARIATION_REF,
101        ),
102        ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF),
103        ScalarValue::UInt64(Some(n)) => (
104            LiteralType::I64(*n as i64),
105            UNSIGNED_INTEGER_TYPE_VARIATION_REF,
106        ),
107        ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF),
108        ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF),
109        // TODO(discord9): deal with timezone
110        ScalarValue::TimestampSecond(Some(t), _) => (
111            LiteralType::PrecisionTimestamp(PrecisionTimestamp {
112                value: *t,
113                precision: TimestampPrecision::Second as i32,
114            }),
115            DEFAULT_TYPE_VARIATION_REF,
116        ),
117        ScalarValue::TimestampMillisecond(Some(t), _) => (
118            LiteralType::PrecisionTimestamp(PrecisionTimestamp {
119                value: *t,
120                precision: TimestampPrecision::Millisecond as i32,
121            }),
122            DEFAULT_TYPE_VARIATION_REF,
123        ),
124        ScalarValue::TimestampMicrosecond(Some(t), _) => (
125            LiteralType::PrecisionTimestamp(PrecisionTimestamp {
126                value: *t,
127                precision: TimestampPrecision::Microsecond as i32,
128            }),
129            DEFAULT_TYPE_VARIATION_REF,
130        ),
131        ScalarValue::TimestampNanosecond(Some(t), _) => (
132            LiteralType::PrecisionTimestamp(PrecisionTimestamp {
133                value: *t,
134                precision: TimestampPrecision::Nanosecond as i32,
135            }),
136            DEFAULT_TYPE_VARIATION_REF,
137        ),
138        ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF),
139        _ => (
140            not_impl_err!("Unsupported literal: {value:?}")?,
141            DEFAULT_TYPE_VARIATION_REF,
142        ),
143    };
144
145    Ok(Literal {
146        nullable: false,
147        type_variation_reference,
148        literal_type: Some(literal_type),
149    })
150}
151
152/// Convert a Substrait literal into a Value and its ConcreteDataType (So that we can know type even if the value is null)
153pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<(Value, CDT), Error> {
154    let scalar_value = match &lit.literal_type {
155        Some(LiteralType::Boolean(b)) => (Value::from(*b), CDT::boolean_datatype()),
156        Some(LiteralType::I8(n)) => match lit.type_variation_reference {
157            DEFAULT_TYPE_VARIATION_REF => (Value::from(*n as i8), CDT::int8_datatype()),
158            UNSIGNED_INTEGER_TYPE_VARIATION_REF => (Value::from(*n as u8), CDT::uint8_datatype()),
159            others => not_impl_err!("Unknown type variation reference {others}",)?,
160        },
161        Some(LiteralType::I16(n)) => match lit.type_variation_reference {
162            DEFAULT_TYPE_VARIATION_REF => (Value::from(*n as i16), CDT::int16_datatype()),
163            UNSIGNED_INTEGER_TYPE_VARIATION_REF => (Value::from(*n as u16), CDT::uint16_datatype()),
164            others => not_impl_err!("Unknown type variation reference {others}",)?,
165        },
166        Some(LiteralType::I32(n)) => match lit.type_variation_reference {
167            DEFAULT_TYPE_VARIATION_REF => (Value::from(*n), CDT::int32_datatype()),
168            UNSIGNED_INTEGER_TYPE_VARIATION_REF => (Value::from(*n as u32), CDT::uint32_datatype()),
169            others => not_impl_err!("Unknown type variation reference {others}",)?,
170        },
171        Some(LiteralType::I64(n)) => match lit.type_variation_reference {
172            DEFAULT_TYPE_VARIATION_REF => (Value::from(*n), CDT::int64_datatype()),
173            UNSIGNED_INTEGER_TYPE_VARIATION_REF => (Value::from(*n as u64), CDT::uint64_datatype()),
174            others => not_impl_err!("Unknown type variation reference {others}",)?,
175        },
176        Some(LiteralType::Fp32(f)) => (Value::from(*f), CDT::float32_datatype()),
177        Some(LiteralType::Fp64(f)) => (Value::from(*f), CDT::float64_datatype()),
178        Some(LiteralType::Timestamp(t)) => (
179            Value::from(Timestamp::new_microsecond(*t)),
180            CDT::timestamp_microsecond_datatype(),
181        ),
182        Some(LiteralType::PrecisionTimestamp(prec_ts)) => {
183            let (prec, val) = (prec_ts.precision, prec_ts.value);
184            let prec = TimestampPrecision::try_from(prec)?;
185            let unit = prec.to_time_unit();
186            let typ = prec.to_cdt();
187            (Value::from(Timestamp::new(val, unit)), typ)
188        }
189        Some(LiteralType::Date(d)) => (Value::from(Date::new(*d)), CDT::date_datatype()),
190        Some(LiteralType::String(s)) => (Value::from(s.clone()), CDT::string_datatype()),
191        Some(LiteralType::Binary(b)) | Some(LiteralType::FixedBinary(b)) => {
192            (Value::from(b.clone()), CDT::binary_datatype())
193        }
194        Some(LiteralType::Decimal(d)) => {
195            let value: [u8; 16] = d.value.clone().try_into().map_err(|e| {
196                PlanSnafu {
197                    reason: format!("Failed to parse decimal value from {e:?}"),
198                }
199                .build()
200            })?;
201            let p: u8 = d.precision.try_into().map_err(|e| {
202                PlanSnafu {
203                    reason: format!("Failed to parse decimal precision: {e}"),
204                }
205                .build()
206            })?;
207            let s: i8 = d.scale.try_into().map_err(|e| {
208                PlanSnafu {
209                    reason: format!("Failed to parse decimal scale: {e}"),
210                }
211                .build()
212            })?;
213            let value = i128::from_le_bytes(value);
214            (
215                Value::from(Decimal128::new(value, p, s)),
216                CDT::decimal128_datatype(p, s),
217            )
218        }
219        Some(LiteralType::Null(ntype)) => (Value::Null, from_substrait_type(ntype)?),
220        Some(LiteralType::IntervalDayToSecond(interval)) => from_interval_day_sec(interval)?,
221        Some(LiteralType::IntervalYearToMonth(interval)) => from_interval_year_month(interval)?,
222        Some(LiteralType::IntervalCompound(interval_compound)) => {
223            let interval_day_time = &interval_compound
224                .interval_day_to_second
225                .map(|i| from_interval_day_sec(&i))
226                .transpose()?;
227            let interval_year_month = &interval_compound
228                .interval_year_to_month
229                .map(|i| from_interval_year_month(&i))
230                .transpose()?;
231            let mut compound = IntervalMonthDayNano::new(0, 0, 0);
232            if let Some(day_sec) = interval_day_time {
233                let Value::IntervalDayTime(day_time) = day_sec.0 else {
234                    UnexpectedSnafu {
235                        reason: format!("Expect IntervalDayTime, found {:?}", day_sec),
236                    }
237                    .fail()?
238                };
239                //  1 day in milliseconds = 24 * 60 * 60 * 1000 = 8.64e7 ms = 8.64e13 ns << 2^63
240                // so overflow is unexpected
241                compound.nanoseconds = compound
242                    .nanoseconds
243                    .checked_add(day_time.milliseconds as i64 * 1_000_000)
244                    .with_context(|| UnexpectedSnafu {
245                        reason: format!(
246                            "Overflow when converting interval: {:?}",
247                            interval_compound
248                        ),
249                    })?;
250                compound.days += day_time.days;
251            }
252
253            if let Some(year_month) = interval_year_month {
254                let Value::IntervalYearMonth(year_month) = year_month.0 else {
255                    UnexpectedSnafu {
256                        reason: format!("Expect IntervalYearMonth, found {:?}", year_month),
257                    }
258                    .fail()?
259                };
260                compound.months += year_month.months;
261            }
262
263            (
264                Value::IntervalMonthDayNano(compound),
265                CDT::interval_month_day_nano_datatype(),
266            )
267        }
268        _ => not_impl_err!("unsupported literal_type: {:?}", &lit.literal_type)?,
269    };
270    Ok(scalar_value)
271}
272
273fn from_interval_day_sec(
274    interval: &proto::expression::literal::IntervalDayToSecond,
275) -> Result<(Value, CDT), Error> {
276    let (days, seconds, subseconds) = (interval.days, interval.seconds, interval.subseconds);
277    let millis = if let Some(prec) = interval.precision_mode {
278        use substrait_proto::proto::expression::literal::interval_day_to_second::PrecisionMode;
279        match prec {
280            PrecisionMode::Precision(e) => {
281                if e >= 3 {
282                    subseconds
283                        / 10_i64
284                            .checked_pow((e - 3) as _)
285                            .with_context(|| UnexpectedSnafu {
286                                reason: format!(
287                                    "Overflow when converting interval: {:?}",
288                                    interval
289                                ),
290                            })?
291                } else {
292                    subseconds
293                        * 10_i64
294                            .checked_pow((3 - e) as _)
295                            .with_context(|| UnexpectedSnafu {
296                                reason: format!(
297                                    "Overflow when converting interval: {:?}",
298                                    interval
299                                ),
300                            })?
301                }
302            }
303            PrecisionMode::Microseconds(_) => subseconds / 1000,
304        }
305    } else if subseconds == 0 {
306        0
307    } else {
308        not_impl_err!("unsupported subseconds without precision_mode: {subseconds}")?
309    };
310
311    let value_interval = common_time::IntervalDayTime::new(days, seconds * 1000 + millis as i32);
312
313    Ok((
314        Value::IntervalDayTime(value_interval),
315        CDT::interval_day_time_datatype(),
316    ))
317}
318
319fn from_interval_year_month(
320    interval: &proto::expression::literal::IntervalYearToMonth,
321) -> Result<(Value, CDT), Error> {
322    let value_interval = common_time::IntervalYearMonth::new(interval.years * 12 + interval.months);
323
324    Ok((
325        Value::IntervalYearMonth(value_interval),
326        CDT::interval_year_month_datatype(),
327    ))
328}
329
330fn from_bytes<T: FromBytes>(i: &Bytes) -> Result<T, Error>
331where
332    for<'a> &'a <T as num_traits::FromBytes>::Bytes:
333        std::convert::TryFrom<&'a [u8], Error = TryFromSliceError>,
334{
335    let (int_bytes, _rest) = i.split_at(std::mem::size_of::<T>());
336    let i = T::from_le_bytes(int_bytes.try_into().map_err(|e| {
337        UnexpectedSnafu {
338            reason: format!(
339                "Expect slice to be {} bytes, found {} bytes, error={:?}",
340                std::mem::size_of::<T>(),
341                int_bytes.len(),
342                e
343            ),
344        }
345        .build()
346    })?);
347    Ok(i)
348}
349
350/// convert a Substrait type into a ConcreteDataType
351pub fn from_substrait_type(null_type: &substrait_proto::proto::Type) -> Result<CDT, Error> {
352    if let Some(kind) = &null_type.kind {
353        match kind {
354            Kind::Bool(_) => Ok(CDT::boolean_datatype()),
355            Kind::I8(integer) => match integer.type_variation_reference {
356                DEFAULT_TYPE_VARIATION_REF => Ok(CDT::int8_datatype()),
357                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(CDT::uint8_datatype()),
358                v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
359            },
360            Kind::I16(integer) => match integer.type_variation_reference {
361                DEFAULT_TYPE_VARIATION_REF => Ok(CDT::int16_datatype()),
362                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(CDT::uint16_datatype()),
363                v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
364            },
365            Kind::I32(integer) => match integer.type_variation_reference {
366                DEFAULT_TYPE_VARIATION_REF => Ok(CDT::int32_datatype()),
367                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(CDT::uint32_datatype()),
368                v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
369            },
370            Kind::I64(integer) => match integer.type_variation_reference {
371                DEFAULT_TYPE_VARIATION_REF => Ok(CDT::int64_datatype()),
372                UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(CDT::uint64_datatype()),
373                v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
374            },
375            Kind::Fp32(_) => Ok(CDT::float32_datatype()),
376            Kind::Fp64(_) => Ok(CDT::float64_datatype()),
377            Kind::PrecisionTimestamp(ts) => {
378                Ok(TimestampPrecision::try_from(ts.precision)?.to_cdt())
379            }
380            Kind::Date(date) => match date.type_variation_reference {
381                DATE_32_TYPE_VARIATION_REF | DATE_64_TYPE_VARIATION_REF => Ok(CDT::date_datatype()),
382                v => not_impl_err!("Unsupported Substrait type variation {v} of type {kind:?}"),
383            },
384            Kind::Binary(_) => Ok(CDT::binary_datatype()),
385            Kind::String(_) => Ok(CDT::string_datatype()),
386            Kind::Decimal(d) => Ok(CDT::decimal128_datatype(d.precision as u8, d.scale as i8)),
387            _ => not_impl_err!("Unsupported Substrait type: {kind:?}"),
388        }
389    } else {
390        not_impl_err!("Null type without kind is not supported")
391    }
392}
393
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::plan::{Plan, TypedPlan};
    use crate::repr::{self, ColumnType, RelationType};
    use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};

    /// A constant projection (`SELECT 1 FROM numbers`) must be decoded from its Substrait plan
    /// into a single-row `Plan::Constant` with one Int64 column named `Int64(1)`.
    #[tokio::test]
    async fn test_literal() {
        let query_engine = create_test_query_engine();
        let substrait_plan = sql_to_substrait(query_engine.clone(), "SELECT 1 FROM numbers").await;

        let mut ctx = create_test_ctx();
        let actual = TypedPlan::from_substrait_plan(&mut ctx, &substrait_plan)
            .await
            .unwrap();

        let only_row = repr::Row::new(vec![Value::Int64(1)]);
        let schema = RelationType::new(vec![ColumnType::new(CDT::int64_datatype(), true)])
            .into_named(vec![Some("Int64(1)".to_string())]);
        let expected = TypedPlan {
            schema,
            plan: Plan::Constant {
                rows: vec![(only_row, repr::Timestamp::MIN, 1)],
            },
        };

        assert_eq!(actual, expected);
    }
}