use std::sync::Arc;
use common_error::ext::BoxedError;
use common_function::function_registry::FUNCTION_REGISTRY;
use common_function::scalars::udf::create_udf;
use common_query::logical_plan::SubstraitPlanDecoder;
use datafusion::catalog::CatalogProviderList;
use datafusion::common::DataFusionError;
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::LogicalPlan;
use datafusion_expr::UserDefinedLogicalNode;
use greptime_proto::substrait_extension::MergeScan as PbMergeScan;
use prost::Message;
use session::context::QueryContextRef;
use snafu::ResultExt;
use substrait::extension_serializer::ExtensionSerializer;
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use crate::dist_plan::MergeScanLogicalPlan;
use crate::error::DataFusionSnafu;
/// [`SerializerRegistry`] implementation that encodes [`MergeScanLogicalPlan`] nodes itself
/// and forwards every other user-defined logical node to [`ExtensionSerializer`].
pub struct DefaultSerializer;
impl SerializerRegistry for DefaultSerializer {
    /// Serializes a user-defined logical plan node into bytes.
    ///
    /// A [`MergeScanLogicalPlan`] is encoded as a [`PbMergeScan`] protobuf message carrying
    /// its placeholder flag and the Substrait encoding of its input plan. Any other node is
    /// delegated to [`ExtensionSerializer`].
    ///
    /// # Errors
    ///
    /// Returns [`DataFusionError::Internal`] when a node reports the merge scan name but is
    /// not actually a [`MergeScanLogicalPlan`], and [`DataFusionError::External`] when the
    /// input plan cannot be encoded.
    fn serialize_logical_plan(&self, node: &dyn UserDefinedLogicalNode) -> Result<Vec<u8>> {
        if node.name() != MergeScanLogicalPlan::name() {
            return ExtensionSerializer.serialize_logical_plan(node);
        }

        // A matching name does not prove the concrete type, so surface a mismatch as an
        // error to the caller instead of panicking.
        let merge_scan = node
            .as_any()
            .downcast_ref::<MergeScanLogicalPlan>()
            .ok_or_else(|| {
                DataFusionError::Internal("Failed to downcast to MergeScanLogicalPlan".to_string())
            })?;

        let is_placeholder = merge_scan.is_placeholder();
        // The nested input plan is encoded with this same serializer.
        let input = DFLogicalSubstraitConvertor
            .encode(merge_scan.input(), DefaultSerializer)
            .map_err(|e| DataFusionError::External(Box::new(e)))?
            .to_vec();

        Ok(PbMergeScan {
            is_placeholder,
            input,
        }
        .encode_to_vec())
    }

    /// Deserializes a user-defined logical plan node called `name` from `bytes`.
    ///
    /// # Errors
    ///
    /// Always returns [`DataFusionError::Substrait`] when `name` is the merge scan node
    /// name; that node is not decoded here. Any other name is delegated to
    /// [`ExtensionSerializer`].
    fn deserialize_logical_plan(
        &self,
        name: &str,
        bytes: &[u8],
    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
        if name == MergeScanLogicalPlan::name() {
            return Err(DataFusionError::Substrait(format!(
                "Unsupported plan node: {name}"
            )));
        }

        ExtensionSerializer.deserialize_logical_plan(name, bytes)
    }
}
/// Decodes Substrait-encoded messages into DataFusion [`LogicalPlan`]s.
pub struct DefaultPlanDecoder {
/// Session state passed to the Substrait decoder and used to optimize decoded plans.
session_state: SessionState,
}
impl DefaultPlanDecoder {
    /// Creates a decoder that owns `session_state`.
    ///
    /// Before the state is stored, every function returned by `FUNCTION_REGISTRY` is
    /// converted into a UDF bound to `query_ctx` and registered on the session state.
    ///
    /// # Errors
    ///
    /// Returns an error (with `DataFusionSnafu` context) as soon as one registration fails.
    pub fn new(
        mut session_state: SessionState,
        query_ctx: &QueryContextRef,
    ) -> crate::error::Result<Self> {
        for function in FUNCTION_REGISTRY.functions() {
            // The conversion target is inferred from what `register_udf` accepts.
            let scalar_udf = create_udf(function, query_ctx.clone(), Default::default()).into();
            session_state
                .register_udf(Arc::new(scalar_udf))
                .context(DataFusionSnafu)?;
        }

        Ok(Self { session_state })
    }
}
#[async_trait::async_trait]
impl SubstraitPlanDecoder for DefaultPlanDecoder {
    /// Decodes a Substrait `message` into a [`LogicalPlan`].
    ///
    /// The plan is decoded against `catalog_list` using a clone of this decoder's session
    /// state. When `optimize` is `true` the decoded plan is additionally run through the
    /// session state's optimizer; otherwise it is returned as decoded.
    ///
    /// # Errors
    ///
    /// Decoding failures are boxed and reported with `DecodePlanSnafu` context; optimizer
    /// failures are reported with `GeneralDataFusionSnafu` context.
    async fn decode(
        &self,
        message: bytes::Bytes,
        catalog_list: Arc<dyn CatalogProviderList>,
        optimize: bool,
    ) -> common_query::error::Result<LogicalPlan> {
        let state_for_decode = self.session_state.clone();
        let decoded = DFLogicalSubstraitConvertor
            .decode(message, catalog_list.clone(), state_for_decode)
            .await
            .map_err(BoxedError::new)
            .context(common_query::error::DecodePlanSnafu)?;

        // Callers that only need the raw plan skip the optimizer entirely.
        if !optimize {
            return Ok(decoded);
        }

        self.session_state
            .optimize(&decoded)
            .context(common_query::error::GeneralDataFusionSnafu)
    }
}
#[cfg(test)]
mod tests {
use session::context::QueryContext;
use super::*;
use crate::dummy_catalog::DummyCatalogList;
use crate::optimizer::test_util::mock_table_provider;
use crate::plan::tests::mock_plan;
use crate::QueryEngineFactory;
/// Round-trips a mock plan: encodes it with `DefaultSerializer`, decodes it with the
/// engine's plan decoder (without optimization), and checks the decoded plan's debug
/// rendering.
#[tokio::test]
async fn test_serializer_decode_plan() {
// Build a query engine on top of an in-memory catalog manager.
let catalog_list = catalog::memory::new_memory_catalog_manager().unwrap();
let factory = QueryEngineFactory::new(catalog_list, None, None, None, None, false);
let engine = factory.query_engine();
// Encode the mock plan into Substrait bytes.
let plan = mock_plan();
let bytes = DFLogicalSubstraitConvertor
.encode(&plan, DefaultSerializer)
.unwrap();
// Obtain a plan decoder from the engine context.
let plan_decoder = engine
.engine_context(QueryContext::arc())
.new_plan_decoder()
.unwrap();
// Decode against a dummy catalog list backed by the mock table provider.
let table_provider = Arc::new(mock_table_provider(1.into()));
let catalog_list = Arc::new(DummyCatalogList::with_table_provider(table_provider));
// `false` disables the optimization pass, so the plan is compared as decoded.
let decode_plan = plan_decoder
.decode(bytes, catalog_list, false)
.await
.unwrap();
// The expected text is whitespace-sensitive: it is compared verbatim with the
// `{:?}` rendering of the decoded plan.
assert_eq!(
"Filter: devices.k0 > Int32(500)
TableScan: devices projection=[k0, ts, v0]",
format!("{:?}", decode_plan),
);
}
}