1use std::pin::Pin;
16use std::sync::{Arc, RwLock};
17use std::task::{Context, Poll};
18
19use common_recordbatch::adapter::RecordBatchMetrics;
20use common_recordbatch::error::Result as RecordBatchResult;
21use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
22use common_runtime::JoinHandle;
23use datatypes::schema::SchemaRef;
24use futures_util::Stream;
25use tokio::sync::mpsc;
26
27pub type QueryRuntimeSender = mpsc::Sender<RecordBatchResult<RecordBatch>>;
28pub type QueryRuntimeMetrics = Arc<RwLock<Option<RecordBatchMetrics>>>;
29
30pub struct QueryRuntimeStream {
32 schema: SchemaRef,
33 receiver: mpsc::Receiver<RecordBatchResult<RecordBatch>>,
34 output_ordering: Option<Vec<OrderOption>>,
35 metrics: QueryRuntimeMetrics,
36 producer_handle: Option<JoinHandle<()>>,
37}
38
39impl QueryRuntimeStream {
40 pub fn new(
41 schema: SchemaRef,
42 receiver: mpsc::Receiver<RecordBatchResult<RecordBatch>>,
43 ) -> Self {
44 Self {
45 schema,
46 receiver,
47 output_ordering: None,
48 metrics: Default::default(),
49 producer_handle: None,
50 }
51 }
52
53 pub fn with_output_ordering(mut self, output_ordering: Option<Vec<OrderOption>>) -> Self {
54 self.output_ordering = output_ordering;
55 self
56 }
57
58 pub fn with_metrics(self, metrics: Option<RecordBatchMetrics>) -> Self {
59 *self.metrics.write().unwrap() = metrics;
60 self
61 }
62
63 pub fn with_metrics_store(mut self, metrics: QueryRuntimeMetrics) -> Self {
64 self.metrics = metrics;
65 self
66 }
67
68 pub fn with_producer_handle(mut self, handle: JoinHandle<()>) -> Self {
69 self.producer_handle = Some(handle);
70 self
71 }
72
73 pub fn metrics_store() -> QueryRuntimeMetrics {
74 Default::default()
75 }
76}
77
78impl Stream for QueryRuntimeStream {
79 type Item = RecordBatchResult<RecordBatch>;
80
81 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
82 self.receiver.poll_recv(cx)
83 }
84}
85
86impl Drop for QueryRuntimeStream {
87 fn drop(&mut self) {
88 if let Some(handle) = self.producer_handle.take() {
89 handle.abort();
90 }
91 }
92}
93
94impl RecordBatchStream for QueryRuntimeStream {
95 fn name(&self) -> &str {
96 "QueryRuntimeStream"
97 }
98
99 fn schema(&self) -> SchemaRef {
100 self.schema.clone()
101 }
102
103 fn output_ordering(&self) -> Option<&[OrderOption]> {
104 self.output_ordering.as_deref()
105 }
106
107 fn metrics(&self) -> Option<RecordBatchMetrics> {
108 self.metrics.read().unwrap().clone()
109 }
110}
111
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use common_recordbatch::error::CreateRecordBatchesSnafu;
    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures_util::StreamExt;
    use tokio::sync::oneshot;

    use super::*;

    /// Builds a batch with a single `int32` column named `v` holding the value 1.
    fn test_batch() -> RecordBatch {
        let column = ColumnSchema::new("v", ConcreteDataType::int32_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column]));
        let values: VectorRef = Arc::new(Int32Vector::from_slice([1]));
        RecordBatch::new(schema, vec![values]).unwrap()
    }

    /// Schema of the batch produced by `test_batch`.
    fn test_schema() -> SchemaRef {
        test_batch().schema.clone()
    }

    /// A batch sent before the sender is dropped is yielded, then the stream ends.
    #[tokio::test]
    async fn test_query_runtime_stream_receives_batches() {
        let expected = test_batch();
        let schema = expected.schema.clone();
        let (sender, receiver) = mpsc::channel(1);
        sender.send(Ok(expected)).await.unwrap();
        drop(sender);

        let mut stream = QueryRuntimeStream::new(schema, receiver);
        let received = stream.next().await.unwrap().unwrap();
        assert_eq!(1, received.num_rows());
        assert!(stream.next().await.is_none());
    }

    /// An error sent through the channel surfaces as an `Err` item.
    #[tokio::test]
    async fn test_query_runtime_stream_forwards_errors() {
        let schema = test_schema();
        let failure = CreateRecordBatchesSnafu {
            reason: "test error",
        }
        .build();
        let (sender, receiver) = mpsc::channel(1);
        sender.send(Err(failure)).await.unwrap();
        drop(sender);

        let mut stream = QueryRuntimeStream::new(schema, receiver);
        let item = stream.next().await.unwrap();
        assert!(item.is_err());
    }

    /// Writes made through a shared metrics store are visible via `metrics()`.
    #[tokio::test]
    async fn test_query_runtime_stream_reads_shared_metrics() {
        let schema = test_schema();
        let (sender, receiver) = mpsc::channel(1);
        drop(sender);
        let shared = QueryRuntimeStream::metrics_store();
        let stream =
            QueryRuntimeStream::new(schema, receiver).with_metrics_store(Arc::clone(&shared));

        assert!(stream.metrics().is_none());
        {
            let mut slot = shared.write().unwrap();
            *slot = Some(RecordBatchMetrics {
                elapsed_compute: 42,
                ..Default::default()
            });
        }

        assert_eq!(42, stream.metrics().unwrap().elapsed_compute);
    }

    /// Dropping the stream aborts the task registered as its producer.
    #[tokio::test]
    async fn test_query_runtime_stream_drop_aborts_producer() {
        /// Signals through the wrapped sender when it is dropped, which
        /// happens once the spawned task's future is torn down.
        struct AbortGuard(Option<oneshot::Sender<()>>);

        impl Drop for AbortGuard {
            fn drop(&mut self) {
                let notifier = self.0.take().unwrap();
                let _ = notifier.send(());
            }
        }

        let schema = test_schema();
        // Keep the sender alive so only the abort can end the producer.
        let (_sender, receiver) = mpsc::channel(1);
        let (abort_tx, abort_rx) = oneshot::channel();
        let (started_tx, started_rx) = oneshot::channel();
        let producer = tokio::spawn(async move {
            let _guard = AbortGuard(Some(abort_tx));
            let _ = started_tx.send(());
            std::future::pending::<()>().await;
        });
        // Make sure the guard exists inside the task before dropping the stream.
        started_rx.await.unwrap();

        let stream = QueryRuntimeStream::new(schema, receiver).with_producer_handle(producer);
        drop(stream);

        let notified = tokio::time::timeout(Duration::from_secs(1), abort_rx)
            .await
            .unwrap();
        notified.unwrap();
    }

    /// Once the stream is dropped, sending on the channel fails.
    #[tokio::test]
    async fn test_query_runtime_stream_close_stops_sender() {
        let schema = test_schema();
        let (sender, receiver) = mpsc::channel(1);
        drop(QueryRuntimeStream::new(schema, receiver));

        let outcome = sender.send(Ok(test_batch())).await;
        assert!(outcome.is_err());
    }
}