frontend/
stream_wrapper.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use catalog::process_manager::Ticket;
19use common_recordbatch::adapter::RecordBatchMetrics;
20use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
21use datatypes::schema::SchemaRef;
22use futures::Stream;
23
24pub struct CancellableStreamWrapper {
25    inner: SendableRecordBatchStream,
26    ticket: Ticket,
27}
28
29impl Unpin for CancellableStreamWrapper {}
30
31impl CancellableStreamWrapper {
32    pub fn new(stream: SendableRecordBatchStream, ticket: Ticket) -> Self {
33        Self {
34            inner: stream,
35            ticket,
36        }
37    }
38}
39
40impl Stream for CancellableStreamWrapper {
41    type Item = common_recordbatch::error::Result<RecordBatch>;
42
43    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44        let this = &mut *self;
45        if this.ticket.cancellation_handle.is_cancelled() {
46            return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail()));
47        }
48
49        if let Poll::Ready(res) = Pin::new(&mut this.inner).poll_next(cx) {
50            return Poll::Ready(res);
51        }
52
53        // on pending, register cancellation waker.
54        this.ticket.cancellation_handle.waker().register(cx.waker());
55        // check if canceled again.
56        if this.ticket.cancellation_handle.is_cancelled() {
57            return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail()));
58        }
59        Poll::Pending
60    }
61}
62
63impl RecordBatchStream for CancellableStreamWrapper {
64    fn schema(&self) -> SchemaRef {
65        self.inner.schema()
66    }
67
68    fn output_ordering(&self) -> Option<&[OrderOption]> {
69        self.inner.output_ordering()
70    }
71
72    fn metrics(&self) -> Option<RecordBatchMetrics> {
73        self.inner.metrics()
74    }
75}
76
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use catalog::process_manager::{ProcessManager, Ticket};
    use common_recordbatch::adapter::RecordBatchMetrics;
    use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::VectorRef;
    use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
    use datatypes::vectors::Int32Vector;
    use futures::{Stream, StreamExt};
    use tokio::time::{sleep, timeout};

    use super::CancellableStreamWrapper;

    /// Item type produced by both the mock stream and the wrapper under test.
    type TestItem = common_recordbatch::error::Result<RecordBatch>;

    /// Schema with a single non-nullable `Int32` column named `test_col`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![ColumnSchema::new(
            "test_col",
            ConcreteDataType::int32_datatype(),
            false,
        )]))
    }

    /// A three-row batch (`0, 1, 2`) matching [`test_schema`].
    fn create_test_batch() -> RecordBatch {
        RecordBatch::new(
            test_schema(),
            vec![Arc::new(Int32Vector::from_values(0..3)) as VectorRef],
        )
        .unwrap()
    }

    /// Registers a dummy query with a fresh process manager and returns its ticket.
    fn register_test_query() -> Ticket {
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        process_manager.register_query(
            "catalog".to_string(),
            vec![],
            "query".to_string(),
            "client".to_string(),
            None,
        )
    }

    /// Mock inner stream that yields a fixed list of items (errors included), optionally
    /// after a single simulated asynchronous delay on its first poll.
    struct MockRecordBatchStream {
        schema: SchemaRef,
        /// Remaining items, yielded front to back exactly as given.
        batches: VecDeque<TestItem>,
        /// If set, the next poll returns `Pending` and schedules a wake-up after this long.
        delay: Option<Duration>,
    }

    impl MockRecordBatchStream {
        fn new(batches: Vec<TestItem>) -> Self {
            Self {
                schema: test_schema(),
                batches: batches.into(),
                delay: None,
            }
        }

        fn with_delay(mut self, delay: Duration) -> Self {
            self.delay = Some(delay);
            self
        }
    }

    impl Stream for MockRecordBatchStream {
        type Item = TestItem;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Simulate async work once: stay pending and let a spawned task wake us later.
            if let Some(delay) = self.delay.take() {
                let waker = cx.waker().clone();
                tokio::spawn(async move {
                    sleep(delay).await;
                    waker.wake();
                });
                return Poll::Pending;
            }
            // Moving items out avoids cloning and lets `Err` items through unchanged.
            Poll::Ready(self.batches.pop_front())
        }
    }

    impl RecordBatchStream for MockRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            None
        }
    }

    #[tokio::test]
    async fn test_stream_completes_normally() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let mut cancellable_stream =
            CancellableStreamWrapper::new(Box::pin(mock_stream), register_test_query());

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_cancelled_before_start() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let ticket = register_test_query();

        // Cancel before creating the wrapper
        ticket.cancellation_handle.cancel();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn test_stream_cancelled_during_execution() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())])
            .with_delay(Duration::from_millis(100));
        let ticket = register_test_query();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        // Cancel while the inner stream is still delayed
        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            cancellation_handle.cancel();
        });

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn test_stream_completes_before_cancellation() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let ticket = register_test_query();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        // Try to cancel after the stream should have completed
        tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
            cancellation_handle.cancel();
        });

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_multiple_batches() {
        let mock_stream =
            MockRecordBatchStream::new(vec![Ok(create_test_batch()), Ok(create_test_batch())]);
        let mut cancellable_stream =
            CancellableStreamWrapper::new(Box::pin(mock_stream), register_test_query());

        // First batch
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        // Second batch
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        // End of stream
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_inner_error_is_passed_through() {
        let inner_error: TestItem = common_recordbatch::error::StreamCancelledSnafu.fail();
        let mock_stream = MockRecordBatchStream::new(vec![inner_error, Ok(create_test_batch())]);
        let mut cancellable_stream =
            CancellableStreamWrapper::new(Box::pin(mock_stream), register_test_query());

        // An error from the inner stream is forwarded and does not end the stream.
        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_record_batch_stream_methods() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let cancellable_stream =
            CancellableStreamWrapper::new(Box::pin(mock_stream), register_test_query());

        // Test schema method
        let schema = cancellable_stream.schema();
        assert_eq!(schema.column_schemas().len(), 1);
        assert_eq!(schema.column_schemas()[0].name, "test_col");

        // Test output_ordering method
        assert!(cancellable_stream.output_ordering().is_none());

        // Test metrics method
        assert!(cancellable_stream.metrics().is_none());
    }

    #[tokio::test]
    async fn test_cancellation_during_pending_poll() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())])
            .with_delay(Duration::from_millis(200));
        let ticket = register_test_query();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        // Cancel while the stream is pending
        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            cancellation_handle.cancel();
        });

        // The cancel must wake the pending task well before the inner stream's own wake-up.
        let result = timeout(Duration::from_millis(300), cancellable_stream.next()).await;
        assert!(matches!(result, Ok(Some(Err(_)))));
    }
}