substrait/
extension_serializer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::error::Result;
18use datafusion::execution::registry::SerializerRegistry;
19use datafusion_common::DataFusionError;
20use datafusion_expr::UserDefinedLogicalNode;
21use promql::extension_plan::{
22    Absent, EmptyMetric, InstantManipulate, RangeManipulate, ScalarCalculate, SeriesDivide,
23    SeriesNormalize,
24};
25
/// [`SerializerRegistry`] for the PromQL extension plan nodes from
/// `promql::extension_plan`.
///
/// It converts `InstantManipulate`, `SeriesNormalize`, `RangeManipulate`,
/// `ScalarCalculate`, `SeriesDivide` and `Absent` nodes to and from bytes. It
/// always refuses `EmptyMetric`. Any other node name is reported as not
/// implemented.
///
/// NOTE(review): presumably plugged into the Substrait plan encoder/decoder
/// (this module lives under `substrait/`); confirm against the callers.
#[derive(Debug)]
pub struct ExtensionSerializer;
28
29impl SerializerRegistry for ExtensionSerializer {
30    /// Serialize this node to a byte array. This serialization should not include
31    /// input plans.
32    fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
33        match node.name() {
34            name if name == InstantManipulate::name() => {
35                let instant_manipulate = node
36                    .as_any()
37                    .downcast_ref::<InstantManipulate>()
38                    .expect("Failed to downcast to InstantManipulate");
39                Ok(instant_manipulate.serialize())
40            }
41            name if name == SeriesNormalize::name() => {
42                let series_normalize = node
43                    .as_any()
44                    .downcast_ref::<SeriesNormalize>()
45                    .expect("Failed to downcast to SeriesNormalize");
46                Ok(series_normalize.serialize())
47            }
48            name if name == RangeManipulate::name() => {
49                let range_manipulate = node
50                    .as_any()
51                    .downcast_ref::<RangeManipulate>()
52                    .expect("Failed to downcast to RangeManipulate");
53                Ok(range_manipulate.serialize())
54            }
55            name if name == ScalarCalculate::name() => {
56                let scalar_calculate = node
57                    .as_any()
58                    .downcast_ref::<ScalarCalculate>()
59                    .expect("Failed to downcast to ScalarCalculate");
60                Ok(scalar_calculate.serialize())
61            }
62            name if name == SeriesDivide::name() => {
63                let series_divide = node
64                    .as_any()
65                    .downcast_ref::<SeriesDivide>()
66                    .expect("Failed to downcast to SeriesDivide");
67                Ok(series_divide.serialize())
68            }
69            name if name == Absent::name() => {
70                let absent = node
71                    .as_any()
72                    .downcast_ref::<Absent>()
73                    .expect("Failed to downcast to Absent");
74                Ok(absent.serialize())
75            }
76            name if name == EmptyMetric::name() => Err(DataFusionError::Substrait(
77                "EmptyMetric should not be serialized".to_string(),
78            )),
79            other => Err(DataFusionError::NotImplemented(format!(
80                "Serizlize logical plan for {}",
81                other
82            ))),
83        }
84    }
85
86    /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
87    /// bytes.
88    fn deserialize_logical_plan(
89        &self,
90        name: &str,
91        bytes: &[u8],
92    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
93        match name {
94            name if name == InstantManipulate::name() => {
95                let instant_manipulate = InstantManipulate::deserialize(bytes)?;
96                Ok(Arc::new(instant_manipulate))
97            }
98            name if name == SeriesNormalize::name() => {
99                let series_normalize = SeriesNormalize::deserialize(bytes)?;
100                Ok(Arc::new(series_normalize))
101            }
102            name if name == RangeManipulate::name() => {
103                let range_manipulate = RangeManipulate::deserialize(bytes)?;
104                Ok(Arc::new(range_manipulate))
105            }
106            name if name == SeriesDivide::name() => {
107                let series_divide = SeriesDivide::deserialize(bytes)?;
108                Ok(Arc::new(series_divide))
109            }
110            name if name == ScalarCalculate::name() => {
111                let scalar_calculate = ScalarCalculate::deserialize(bytes)?;
112                Ok(Arc::new(scalar_calculate))
113            }
114            name if name == Absent::name() => {
115                let absent = Absent::deserialize(bytes)?;
116                Ok(Arc::new(absent))
117            }
118            name if name == EmptyMetric::name() => Err(DataFusionError::Substrait(
119                "EmptyMetric should not be deserialized".to_string(),
120            )),
121            other => Err(DataFusionError::NotImplemented(format!(
122                "Deserialize logical plan for {}",
123                other
124            ))),
125        }
126    }
127}