frontend/
stream_wrapper.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use catalog::process_manager::Ticket;
19use common_recordbatch::adapter::RecordBatchMetrics;
20use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
21use datatypes::schema::SchemaRef;
22use futures::Stream;
23
24pub struct CancellableStreamWrapper {
25    inner: SendableRecordBatchStream,
26    ticket: Ticket,
27}
28
29impl Unpin for CancellableStreamWrapper {}
30
31impl CancellableStreamWrapper {
32    pub fn new(stream: SendableRecordBatchStream, ticket: Ticket) -> Self {
33        Self {
34            inner: stream,
35            ticket,
36        }
37    }
38}
39
40impl Stream for CancellableStreamWrapper {
41    type Item = common_recordbatch::error::Result<RecordBatch>;
42
43    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44        let this = &mut *self;
45        if this.ticket.cancellation_handle.is_cancelled() {
46            return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail()));
47        }
48
49        if let Poll::Ready(res) = Pin::new(&mut this.inner).poll_next(cx) {
50            return Poll::Ready(res);
51        }
52
53        // on pending, register cancellation waker.
54        this.ticket.cancellation_handle.waker().register(cx.waker());
55        // check if canceled again.
56        if this.ticket.cancellation_handle.is_cancelled() {
57            return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail()));
58        }
59        Poll::Pending
60    }
61}
62
63impl RecordBatchStream for CancellableStreamWrapper {
64    fn schema(&self) -> SchemaRef {
65        self.inner.schema()
66    }
67
68    fn output_ordering(&self) -> Option<&[OrderOption]> {
69        self.inner.output_ordering()
70    }
71
72    fn metrics(&self) -> Option<RecordBatchMetrics> {
73        self.inner.metrics()
74    }
75}
76
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use catalog::process_manager::{ProcessManager, Ticket};
    use common_recordbatch::adapter::RecordBatchMetrics;
    use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::VectorRef;
    use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
    use datatypes::vectors::Int32Vector;
    use futures::{Stream, StreamExt};
    use tokio::time::{sleep, timeout};

    use super::CancellableStreamWrapper;

    type BatchResult = common_recordbatch::error::Result<RecordBatch>;

    /// Schema with a single non-nullable `Int32` column named `test_col`,
    /// shared by the mock stream and the test batches.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![ColumnSchema::new(
            "test_col",
            ConcreteDataType::int32_datatype(),
            false,
        )]))
    }

    /// Registers a dummy query and returns its ticket.
    ///
    /// The process manager is returned as well so each test keeps it alive for
    /// as long as the ticket. NOTE(review): whether `Ticket` already holds a
    /// strong reference to its manager is not visible here.
    fn register_test_ticket() -> (Arc<ProcessManager>, Ticket) {
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = process_manager.register_query(
            "catalog".to_string(),
            vec![],
            "query".to_string(),
            "client".to_string(),
            None,
            None,
        );
        (process_manager, ticket)
    }

    /// Mock stream that yields a fixed sequence of results. It can optionally
    /// return `Pending` once before producing its first item.
    struct MockRecordBatchStream {
        schema: SchemaRef,
        // Remaining items, yielded front to back. Errors are forwarded as-is
        // instead of panicking.
        batches: VecDeque<BatchResult>,
        // When set, the next poll returns `Pending` and schedules a wake-up
        // after this delay.
        delay: Option<Duration>,
    }

    impl MockRecordBatchStream {
        fn new(batches: Vec<BatchResult>) -> Self {
            Self {
                schema: test_schema(),
                batches: batches.into(),
                delay: None,
            }
        }

        fn with_delay(mut self, delay: Duration) -> Self {
            self.delay = Some(delay);
            self
        }
    }

    impl Stream for MockRecordBatchStream {
        type Item = BatchResult;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Simulate a slow source: stay pending once, then wake the task later.
            if let Some(delay) = self.delay.take() {
                let waker = cx.waker().clone();
                tokio::spawn(async move {
                    sleep(delay).await;
                    waker.wake();
                });
                return Poll::Pending;
            }

            // Moving items out avoids cloning and lets `Err` entries pass through.
            Poll::Ready(self.batches.pop_front())
        }
    }

    impl RecordBatchStream for MockRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            None
        }
    }

    fn create_test_batch() -> RecordBatch {
        RecordBatch::new(
            test_schema(),
            vec![Arc::new(Int32Vector::from_values(0..3)) as VectorRef],
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_stream_completes_normally() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_ticket();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_cancelled_before_start() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_ticket();

        // Cancel before creating the wrapper.
        ticket.cancellation_handle.cancel();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn test_stream_cancelled_during_execution() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())])
            .with_delay(Duration::from_millis(100));
        let (_manager, ticket) = register_test_ticket();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        // Cancel while the inner stream is still pending.
        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            cancellation_handle.cancel();
        });

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn test_stream_completes_before_cancellation() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_ticket();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        // Try to cancel after the stream should already have produced its batch.
        tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
            cancellation_handle.cancel();
        });

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_multiple_batches() {
        let mock_stream =
            MockRecordBatchStream::new(vec![Ok(create_test_batch()), Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_ticket();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        // First batch.
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        // Second batch.
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        // End of stream.
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_inner_errors_are_forwarded() {
        // An error produced by the wrapped stream must pass through untouched,
        // and later items must still be delivered. StreamCancelled is used here
        // only because its context selector takes no fields; any error works.
        let mock_stream = MockRecordBatchStream::new(vec![
            common_recordbatch::error::StreamCancelledSnafu.fail(),
            Ok(create_test_batch()),
        ]);
        let (_manager, ticket) = register_test_ticket();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_record_batch_stream_methods() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_ticket();

        let cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        // Schema is forwarded from the inner stream.
        let schema = cancellable_stream.schema();
        assert_eq!(schema.column_schemas().len(), 1);
        assert_eq!(schema.column_schemas()[0].name, "test_col");

        // Ordering and metrics are forwarded too (the mock reports none).
        assert!(cancellable_stream.output_ordering().is_none());
        assert!(cancellable_stream.metrics().is_none());
    }

    #[tokio::test]
    async fn test_cancellation_during_pending_poll() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())])
            .with_delay(Duration::from_millis(200));
        let (_manager, ticket) = register_test_ticket();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        // Cancel while the stream is pending.
        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            cancellation_handle.cancel();
        });

        let stream_result = timeout(Duration::from_millis(300), cancellable_stream.next())
            .await
            .expect("cancellation should resolve the pending poll before the timeout");
        assert!(matches!(stream_result, Some(Err(_))));
    }
}