Skip to main content

datanode/
query_stream.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::sync::{Arc, RwLock};
17use std::task::{Context, Poll};
18
19use common_recordbatch::adapter::RecordBatchMetrics;
20use common_recordbatch::error::Result as RecordBatchResult;
21use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
22use common_runtime::JoinHandle;
23use datatypes::schema::SchemaRef;
24use futures_util::Stream;
25use tokio::sync::mpsc;
26
/// Sending half of a channel carrying record batch results; its item type matches the receiver
/// held by [`QueryRuntimeStream`].
pub type QueryRuntimeSender = mpsc::Sender<RecordBatchResult<RecordBatch>>;
/// Reference-counted, lock-protected slot for an optional metrics snapshot; its default value
/// holds `None`.
pub type QueryRuntimeMetrics = Arc<RwLock<Option<RecordBatchMetrics>>>;
29
/// A record batch stream backed by batches produced on the datanode query runtime.
///
/// Items are read from an mpsc channel. The stream may also hold the handle of the task that
/// produces them, in which case that task is aborted when the stream is dropped.
pub struct QueryRuntimeStream {
    /// Schema returned by `schema()`.
    schema: SchemaRef,
    /// Receiving half of the channel the stream items are polled from.
    receiver: mpsc::Receiver<RecordBatchResult<RecordBatch>>,
    /// Ordering returned by `output_ordering()`; `None` unless set with `with_output_ordering`.
    output_ordering: Option<Vec<OrderOption>>,
    /// Metrics slot read by `metrics()`; may be shared when installed via `with_metrics_store`.
    metrics: QueryRuntimeMetrics,
    /// Handle of the producer task, if one was attached; aborted in `Drop`.
    producer_handle: Option<JoinHandle<()>>,
}
38
39impl QueryRuntimeStream {
40    pub fn new(
41        schema: SchemaRef,
42        receiver: mpsc::Receiver<RecordBatchResult<RecordBatch>>,
43    ) -> Self {
44        Self {
45            schema,
46            receiver,
47            output_ordering: None,
48            metrics: Default::default(),
49            producer_handle: None,
50        }
51    }
52
53    pub fn with_output_ordering(mut self, output_ordering: Option<Vec<OrderOption>>) -> Self {
54        self.output_ordering = output_ordering;
55        self
56    }
57
58    pub fn with_metrics(self, metrics: Option<RecordBatchMetrics>) -> Self {
59        *self.metrics.write().unwrap() = metrics;
60        self
61    }
62
63    pub fn with_metrics_store(mut self, metrics: QueryRuntimeMetrics) -> Self {
64        self.metrics = metrics;
65        self
66    }
67
68    pub fn with_producer_handle(mut self, handle: JoinHandle<()>) -> Self {
69        self.producer_handle = Some(handle);
70        self
71    }
72
73    pub fn metrics_store() -> QueryRuntimeMetrics {
74        Default::default()
75    }
76}
77
78impl Stream for QueryRuntimeStream {
79    type Item = RecordBatchResult<RecordBatch>;
80
81    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
82        self.receiver.poll_recv(cx)
83    }
84}
85
86impl Drop for QueryRuntimeStream {
87    fn drop(&mut self) {
88        if let Some(handle) = self.producer_handle.take() {
89            handle.abort();
90        }
91    }
92}
93
94impl RecordBatchStream for QueryRuntimeStream {
95    fn name(&self) -> &str {
96        "QueryRuntimeStream"
97    }
98
99    fn schema(&self) -> SchemaRef {
100        self.schema.clone()
101    }
102
103    fn output_ordering(&self) -> Option<&[OrderOption]> {
104        self.output_ordering.as_deref()
105    }
106
107    fn metrics(&self) -> Option<RecordBatchMetrics> {
108        self.metrics.read().unwrap().clone()
109    }
110}
111
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use common_recordbatch::error::CreateRecordBatchesSnafu;
    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures_util::StreamExt;

    use super::*;

    /// Builds a one-row batch with a single non-nullable `int32` column named `v`.
    fn test_batch() -> RecordBatch {
        let column = ColumnSchema::new("v", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));
        let values: VectorRef = Arc::new(Int32Vector::from_slice([1]));
        RecordBatch::new(schema, vec![values]).unwrap()
    }

    /// Returns the schema of [`test_batch`].
    fn test_schema() -> SchemaRef {
        test_batch().schema.clone()
    }

    /// A batch sent before the sender is dropped is yielded, then the stream ends.
    #[tokio::test]
    async fn test_query_runtime_stream_receives_batches() {
        let expected = test_batch();
        let schema = expected.schema.clone();
        let (sender, receiver) = mpsc::channel(1);
        sender.send(Ok(expected)).await.unwrap();
        drop(sender);

        let mut stream = QueryRuntimeStream::new(schema, receiver);
        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(1, received.num_rows());
        assert!(stream.next().await.is_none());
    }

    /// An error sent by the producer is surfaced as an `Err` item.
    #[tokio::test]
    async fn test_query_runtime_stream_forwards_errors() {
        let schema = test_schema();
        let (sender, receiver) = mpsc::channel(1);
        let error = CreateRecordBatchesSnafu {
            reason: "test error",
        }
        .build();
        sender.send(Err(error)).await.unwrap();
        drop(sender);

        let mut stream = QueryRuntimeStream::new(schema, receiver);
        let first = stream.next().await.unwrap();
        assert!(first.is_err());
    }

    /// Writes to a shared metrics store are visible through the stream.
    #[tokio::test]
    async fn test_query_runtime_stream_reads_shared_metrics() {
        let schema = test_schema();
        let (sender, receiver) = mpsc::channel(1);
        drop(sender);
        let shared = QueryRuntimeStream::metrics_store();
        let stream = QueryRuntimeStream::new(schema, receiver).with_metrics_store(shared.clone());

        assert!(stream.metrics().is_none());

        let snapshot = RecordBatchMetrics {
            elapsed_compute: 42,
            ..Default::default()
        };
        *shared.write().unwrap() = Some(snapshot);

        assert_eq!(42, stream.metrics().unwrap().elapsed_compute);
    }

    /// Dropping the stream aborts the attached producer task.
    #[tokio::test]
    async fn test_query_runtime_stream_drop_aborts_producer() {
        /// Fires its oneshot sender when dropped, i.e. when the owning task is torn down.
        struct AbortGuard(Option<tokio::sync::oneshot::Sender<()>>);

        impl Drop for AbortGuard {
            fn drop(&mut self) {
                let notifier = self.0.take().unwrap();
                let _ = notifier.send(());
            }
        }

        let schema = test_schema();
        // The sender is kept alive so the channel itself stays open for the whole test.
        let (_sender, receiver) = mpsc::channel(1);
        let (aborted_tx, aborted_rx) = tokio::sync::oneshot::channel();
        let (started_tx, started_rx) = tokio::sync::oneshot::channel();
        let producer = tokio::spawn(async move {
            let _guard = AbortGuard(Some(aborted_tx));
            let _ = started_tx.send(());
            std::future::pending::<()>().await;
        });
        // Make sure the guard exists inside the task before the stream is dropped.
        started_rx.await.unwrap();

        let stream = QueryRuntimeStream::new(schema, receiver).with_producer_handle(producer);
        drop(stream);

        let one_second = std::time::Duration::from_secs(1);
        let signalled = tokio::time::timeout(one_second, aborted_rx).await;
        signalled.unwrap().unwrap();
    }

    /// Once the stream is dropped, sending on the channel fails.
    #[tokio::test]
    async fn test_query_runtime_stream_close_stops_sender() {
        let schema = test_schema();
        let (sender, receiver) = mpsc::channel(1);
        drop(QueryRuntimeStream::new(schema, receiver));

        let outcome = sender.send(Ok(test_batch())).await;
        assert!(outcome.is_err());
    }
}