substrait/
extension_serializer.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use datafusion::error::Result;
18use datafusion::execution::registry::SerializerRegistry;
19use datafusion_common::DataFusionError;
20use datafusion_expr::UserDefinedLogicalNode;
21use promql::extension_plan::{
22    EmptyMetric, InstantManipulate, RangeManipulate, ScalarCalculate, SeriesDivide, SeriesNormalize,
23};
24
/// Stateless [`SerializerRegistry`] implementation that serializes and
/// deserializes the PromQL extension logical plan nodes imported from
/// `promql::extension_plan`.
#[derive(Debug)]
pub struct ExtensionSerializer;
27
28impl SerializerRegistry for ExtensionSerializer {
29    /// Serialize this node to a byte array. This serialization should not include
30    /// input plans.
31    fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
32        match node.name() {
33            name if name == InstantManipulate::name() => {
34                let instant_manipulate = node
35                    .as_any()
36                    .downcast_ref::<InstantManipulate>()
37                    .expect("Failed to downcast to InstantManipulate");
38                Ok(instant_manipulate.serialize())
39            }
40            name if name == SeriesNormalize::name() => {
41                let series_normalize = node
42                    .as_any()
43                    .downcast_ref::<SeriesNormalize>()
44                    .expect("Failed to downcast to SeriesNormalize");
45                Ok(series_normalize.serialize())
46            }
47            name if name == RangeManipulate::name() => {
48                let range_manipulate = node
49                    .as_any()
50                    .downcast_ref::<RangeManipulate>()
51                    .expect("Failed to downcast to RangeManipulate");
52                Ok(range_manipulate.serialize())
53            }
54            name if name == ScalarCalculate::name() => {
55                let scalar_calculate = node
56                    .as_any()
57                    .downcast_ref::<ScalarCalculate>()
58                    .expect("Failed to downcast to ScalarCalculate");
59                Ok(scalar_calculate.serialize())
60            }
61            name if name == SeriesDivide::name() => {
62                let series_divide = node
63                    .as_any()
64                    .downcast_ref::<SeriesDivide>()
65                    .expect("Failed to downcast to SeriesDivide");
66                Ok(series_divide.serialize())
67            }
68            name if name == EmptyMetric::name() => Err(DataFusionError::Substrait(
69                "EmptyMetric should not be serialized".to_string(),
70            )),
71            other => Err(DataFusionError::NotImplemented(format!(
72                "Serizlize logical plan for {}",
73                other
74            ))),
75        }
76    }
77
78    /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
79    /// bytes.
80    fn deserialize_logical_plan(
81        &self,
82        name: &str,
83        bytes: &[u8],
84    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
85        match name {
86            name if name == InstantManipulate::name() => {
87                let instant_manipulate = InstantManipulate::deserialize(bytes)?;
88                Ok(Arc::new(instant_manipulate))
89            }
90            name if name == SeriesNormalize::name() => {
91                let series_normalize = SeriesNormalize::deserialize(bytes)?;
92                Ok(Arc::new(series_normalize))
93            }
94            name if name == RangeManipulate::name() => {
95                let range_manipulate = RangeManipulate::deserialize(bytes)?;
96                Ok(Arc::new(range_manipulate))
97            }
98            name if name == SeriesDivide::name() => {
99                let series_divide = SeriesDivide::deserialize(bytes)?;
100                Ok(Arc::new(series_divide))
101            }
102            name if name == ScalarCalculate::name() => {
103                let scalar_calculate = ScalarCalculate::deserialize(bytes)?;
104                Ok(Arc::new(scalar_calculate))
105            }
106            name if name == EmptyMetric::name() => Err(DataFusionError::Substrait(
107                "EmptyMetric should not be deserialized".to_string(),
108            )),
109            other => Err(DataFusionError::NotImplemented(format!(
110                "Deserialize logical plan for {}",
111                other
112            ))),
113        }
114    }
115}