1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use catalog::process_manager::Ticket;
19use common_recordbatch::adapter::RecordBatchMetrics;
20use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
21use datatypes::schema::SchemaRef;
22use futures::Stream;
23
24pub struct CancellableStreamWrapper {
25 inner: SendableRecordBatchStream,
26 ticket: Ticket,
27}
28
29impl Unpin for CancellableStreamWrapper {}
30
31impl CancellableStreamWrapper {
32 pub fn new(stream: SendableRecordBatchStream, ticket: Ticket) -> Self {
33 Self {
34 inner: stream,
35 ticket,
36 }
37 }
38}
39
40impl Stream for CancellableStreamWrapper {
41 type Item = common_recordbatch::error::Result<RecordBatch>;
42
43 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44 let this = &mut *self;
45 if this.ticket.cancellation_handle.is_cancelled() {
46 return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail()));
47 }
48
49 if let Poll::Ready(res) = Pin::new(&mut this.inner).poll_next(cx) {
50 return Poll::Ready(res);
51 }
52
53 this.ticket.cancellation_handle.waker().register(cx.waker());
55 if this.ticket.cancellation_handle.is_cancelled() {
57 return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail()));
58 }
59 Poll::Pending
60 }
61}
62
63impl RecordBatchStream for CancellableStreamWrapper {
64 fn schema(&self) -> SchemaRef {
65 self.inner.schema()
66 }
67
68 fn output_ordering(&self) -> Option<&[OrderOption]> {
69 self.inner.output_ordering()
70 }
71
72 fn metrics(&self) -> Option<RecordBatchMetrics> {
73 self.inner.metrics()
74 }
75}
76
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use catalog::process_manager::{ProcessManager, Ticket};
    use common_recordbatch::adapter::RecordBatchMetrics;
    use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::VectorRef;
    use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
    use datatypes::vectors::Int32Vector;
    use futures::{Stream, StreamExt};
    use tokio::time::{sleep, timeout};

    use super::CancellableStreamWrapper;

    /// Schema with a single non-nullable `Int32` column named `test_col`,
    /// shared by the mock stream and the batches it yields.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![ColumnSchema::new(
            "test_col",
            ConcreteDataType::int32_datatype(),
            false,
        )]))
    }

    /// Registers a dummy query on `process_manager` and returns its ticket.
    fn register_test_query(process_manager: &Arc<ProcessManager>) -> Ticket {
        process_manager.register_query(
            "catalog".to_string(),
            vec![],
            "query".to_string(),
            "client".to_string(),
            None,
            None,
        )
    }

    /// In-memory stream that replays a fixed list of items (errors included).
    /// With a delay, the first poll returns `Pending` and schedules its own
    /// wake-up after that delay.
    struct MockRecordBatchStream {
        schema: SchemaRef,
        batches: VecDeque<common_recordbatch::error::Result<RecordBatch>>,
        delay: Option<Duration>,
    }

    impl MockRecordBatchStream {
        fn new(batches: Vec<common_recordbatch::error::Result<RecordBatch>>) -> Self {
            Self {
                schema: test_schema(),
                batches: batches.into(),
                delay: None,
            }
        }

        fn with_delay(mut self, delay: Duration) -> Self {
            self.delay = Some(delay);
            self
        }
    }

    impl Stream for MockRecordBatchStream {
        type Item = common_recordbatch::error::Result<RecordBatch>;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Simulate a slow source: park once and wake the task after `delay`.
            if let Some(delay) = self.delay.take() {
                let waker = cx.waker().clone();
                tokio::spawn(async move {
                    sleep(delay).await;
                    waker.wake();
                });
                return Poll::Pending;
            }

            // Items are handed out by value, so `Err` entries pass through
            // instead of panicking.
            Poll::Ready(self.batches.pop_front())
        }
    }

    impl RecordBatchStream for MockRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            None
        }
    }

    fn create_test_batch() -> RecordBatch {
        RecordBatch::new(
            test_schema(),
            vec![Arc::new(Int32Vector::from_values(0..3)) as VectorRef],
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_stream_completes_normally() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = register_test_query(&process_manager);

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_cancelled_before_start() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = register_test_query(&process_manager);

        ticket.cancellation_handle.cancel();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn test_stream_cancelled_during_execution() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())])
            .with_delay(Duration::from_millis(100));
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = register_test_query(&process_manager);
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            cancellation_handle.cancel();
        });

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn test_stream_completes_before_cancellation() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = register_test_query(&process_manager);
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
            cancellation_handle.cancel();
        });

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_multiple_batches() {
        let mock_stream =
            MockRecordBatchStream::new(vec![Ok(create_test_batch()), Ok(create_test_batch())]);
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = register_test_query(&process_manager);

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_inner_error_is_forwarded() {
        // An error produced by the inner stream is passed through unchanged and
        // does not stop the wrapper from yielding later items.
        let inner_error: common_recordbatch::error::Result<RecordBatch> =
            common_recordbatch::error::StreamCancelledSnafu.fail();
        let mock_stream = MockRecordBatchStream::new(vec![
            Ok(create_test_batch()),
            inner_error,
            Ok(create_test_batch()),
        ]);
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = register_test_query(&process_manager);

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_record_batch_stream_methods() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = register_test_query(&process_manager);

        let cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        let schema = cancellable_stream.schema();
        assert_eq!(schema.column_schemas().len(), 1);
        assert_eq!(schema.column_schemas()[0].name, "test_col");

        assert!(cancellable_stream.output_ordering().is_none());
        assert!(cancellable_stream.metrics().is_none());
    }

    #[tokio::test]
    async fn test_cancellation_during_pending_poll() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())])
            .with_delay(Duration::from_millis(200));
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = register_test_query(&process_manager);
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            cancellation_handle.cancel();
        });

        let result = timeout(Duration::from_millis(300), cancellable_stream.next()).await;
        assert!(matches!(result, Ok(Some(Err(_)))));
    }
}