substrait/
extension_serializer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::error::Result;
18use datafusion::execution::registry::SerializerRegistry;
19use datafusion_common::DataFusionError;
20use datafusion_expr::UserDefinedLogicalNode;
21use promql::extension_plan::{
22    Absent, EmptyMetric, HistogramFold, InstantManipulate, RangeManipulate, ScalarCalculate,
23    SeriesDivide, SeriesNormalize, UnionDistinctOn,
24};
25
/// [`SerializerRegistry`] for the PromQL extension logical plans, used when
/// converting DataFusion logical plans to and from Substrait.
///
/// The type carries no state, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct ExtensionSerializer;
28
29impl SerializerRegistry for ExtensionSerializer {
30    /// Serialize this node to a byte array. This serialization should not include
31    /// input plans.
32    fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
33        match node.name() {
34            name if name == InstantManipulate::name() => {
35                let instant_manipulate = node
36                    .as_any()
37                    .downcast_ref::<InstantManipulate>()
38                    .expect("Failed to downcast to InstantManipulate");
39                Ok(instant_manipulate.serialize())
40            }
41            name if name == SeriesNormalize::name() => {
42                let series_normalize = node
43                    .as_any()
44                    .downcast_ref::<SeriesNormalize>()
45                    .expect("Failed to downcast to SeriesNormalize");
46                Ok(series_normalize.serialize())
47            }
48            name if name == RangeManipulate::name() => {
49                let range_manipulate = node
50                    .as_any()
51                    .downcast_ref::<RangeManipulate>()
52                    .expect("Failed to downcast to RangeManipulate");
53                Ok(range_manipulate.serialize())
54            }
55            name if name == ScalarCalculate::name() => {
56                let scalar_calculate = node
57                    .as_any()
58                    .downcast_ref::<ScalarCalculate>()
59                    .expect("Failed to downcast to ScalarCalculate");
60                Ok(scalar_calculate.serialize())
61            }
62            name if name == SeriesDivide::name() => {
63                let series_divide = node
64                    .as_any()
65                    .downcast_ref::<SeriesDivide>()
66                    .expect("Failed to downcast to SeriesDivide");
67                Ok(series_divide.serialize())
68            }
69            name if name == Absent::name() => {
70                let absent = node
71                    .as_any()
72                    .downcast_ref::<Absent>()
73                    .expect("Failed to downcast to Absent");
74                Ok(absent.serialize())
75            }
76            name if name == HistogramFold::name() => {
77                let histogram_fold = node
78                    .as_any()
79                    .downcast_ref::<HistogramFold>()
80                    .expect("Failed to downcast to HistogramFold");
81                Ok(histogram_fold.serialize())
82            }
83            name if name == UnionDistinctOn::name() => {
84                let union_distinct_on = node
85                    .as_any()
86                    .downcast_ref::<UnionDistinctOn>()
87                    .expect("Failed to downcast to UnionDistinctOn");
88                Ok(union_distinct_on.serialize())
89            }
90            name if name == EmptyMetric::name() => Err(DataFusionError::Substrait(
91                "EmptyMetric should not be serialized".to_string(),
92            )),
93            other => Err(DataFusionError::NotImplemented(format!(
94                "Serizlize logical plan for {}",
95                other
96            ))),
97        }
98    }
99
100    /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
101    /// bytes.
102    fn deserialize_logical_plan(
103        &self,
104        name: &str,
105        bytes: &[u8],
106    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
107        match name {
108            name if name == InstantManipulate::name() => {
109                let instant_manipulate = InstantManipulate::deserialize(bytes)?;
110                Ok(Arc::new(instant_manipulate))
111            }
112            name if name == SeriesNormalize::name() => {
113                let series_normalize = SeriesNormalize::deserialize(bytes)?;
114                Ok(Arc::new(series_normalize))
115            }
116            name if name == RangeManipulate::name() => {
117                let range_manipulate = RangeManipulate::deserialize(bytes)?;
118                Ok(Arc::new(range_manipulate))
119            }
120            name if name == SeriesDivide::name() => {
121                let series_divide = SeriesDivide::deserialize(bytes)?;
122                Ok(Arc::new(series_divide))
123            }
124            name if name == ScalarCalculate::name() => {
125                let scalar_calculate = ScalarCalculate::deserialize(bytes)?;
126                Ok(Arc::new(scalar_calculate))
127            }
128            name if name == Absent::name() => {
129                let absent = Absent::deserialize(bytes)?;
130                Ok(Arc::new(absent))
131            }
132            name if name == HistogramFold::name() => {
133                let histogram_fold = HistogramFold::deserialize(bytes)?;
134                Ok(Arc::new(histogram_fold))
135            }
136            name if name == UnionDistinctOn::name() => {
137                let union_distinct_on = UnionDistinctOn::deserialize(bytes)?;
138                Ok(Arc::new(union_distinct_on))
139            }
140            name if name == EmptyMetric::name() => Err(DataFusionError::Substrait(
141                "EmptyMetric should not be deserialized".to_string(),
142            )),
143            other => Err(DataFusionError::NotImplemented(format!(
144                "Deserialize logical plan for {}",
145                other
146            ))),
147        }
148    }
149}