1use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use catalog::process_manager::Ticket;
19use common_recordbatch::adapter::RecordBatchMetrics;
20use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
21use datatypes::schema::SchemaRef;
22use futures::Stream;
23
24pub struct CancellableStreamWrapper {
25 inner: SendableRecordBatchStream,
26 ticket: Ticket,
27}
28
29impl Unpin for CancellableStreamWrapper {}
30
31impl CancellableStreamWrapper {
32 pub fn new(stream: SendableRecordBatchStream, ticket: Ticket) -> Self {
33 Self {
34 inner: stream,
35 ticket,
36 }
37 }
38}
39
40impl Stream for CancellableStreamWrapper {
41 type Item = common_recordbatch::error::Result<RecordBatch>;
42
43 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44 let this = &mut *self;
45 if this.ticket.cancellation_handle.is_cancelled() {
46 return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail()));
47 }
48
49 if let Poll::Ready(res) = Pin::new(&mut this.inner).poll_next(cx) {
50 return Poll::Ready(res);
51 }
52
53 this.ticket.cancellation_handle.waker().register(cx.waker());
55 if this.ticket.cancellation_handle.is_cancelled() {
57 return Poll::Ready(Some(common_recordbatch::error::StreamCancelledSnafu.fail()));
58 }
59 Poll::Pending
60 }
61}
62
63impl RecordBatchStream for CancellableStreamWrapper {
64 fn schema(&self) -> SchemaRef {
65 self.inner.schema()
66 }
67
68 fn output_ordering(&self) -> Option<&[OrderOption]> {
69 self.inner.output_ordering()
70 }
71
72 fn metrics(&self) -> Option<RecordBatchMetrics> {
73 self.inner.metrics()
74 }
75}
76
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;

    use catalog::process_manager::{ProcessManager, Ticket};
    use common_recordbatch::adapter::RecordBatchMetrics;
    use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
    use datatypes::data_type::ConcreteDataType;
    use datatypes::prelude::VectorRef;
    use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
    use datatypes::vectors::Int32Vector;
    use futures::{Stream, StreamExt};
    use tokio::time::{sleep, timeout};

    use super::CancellableStreamWrapper;

    /// Item type produced by every stream in these tests.
    type BatchResult = common_recordbatch::error::Result<RecordBatch>;

    /// In-memory inner stream: yields `batches` in order, then ends.
    ///
    /// If a `delay` is set, the first poll returns `Pending` and a spawned task
    /// wakes the poller after `delay`; later polls never delay again.
    struct MockRecordBatchStream {
        schema: SchemaRef,
        batches: VecDeque<BatchResult>,
        delay: Option<Duration>,
    }

    impl MockRecordBatchStream {
        fn new(batches: Vec<BatchResult>) -> Self {
            Self {
                schema: test_schema(),
                batches: batches.into(),
                delay: None,
            }
        }

        fn with_delay(mut self, delay: Duration) -> Self {
            self.delay = Some(delay);
            self
        }
    }

    impl Stream for MockRecordBatchStream {
        type Item = BatchResult;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            if let Some(delay) = self.delay.take() {
                let waker = cx.waker().clone();
                tokio::spawn(async move {
                    sleep(delay).await;
                    waker.wake();
                });
                return Poll::Pending;
            }

            // Items are moved out rather than cloned, so `Err` entries are
            // forwarded as-is instead of panicking on `unwrap`.
            Poll::Ready(self.batches.pop_front())
        }
    }

    impl RecordBatchStream for MockRecordBatchStream {
        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            None
        }
    }

    /// Single non-nullable `int32` column named `test_col`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![ColumnSchema::new(
            "test_col",
            ConcreteDataType::int32_datatype(),
            false,
        )]))
    }

    /// A three-row batch (`0, 1, 2`) matching [`test_schema`].
    fn create_test_batch() -> RecordBatch {
        RecordBatch::new(
            test_schema(),
            vec![Arc::new(Int32Vector::from_values(0..3)) as VectorRef],
        )
        .unwrap()
    }

    /// Registers a dummy query and returns its ticket together with the
    /// manager. Callers bind the manager to a named variable declared before
    /// the wrapper, so it outlives the ticket.
    fn register_test_query() -> (Arc<ProcessManager>, Ticket) {
        let process_manager = Arc::new(ProcessManager::new("".to_string(), None));
        let ticket = process_manager.register_query(
            "catalog".to_string(),
            vec![],
            "query".to_string(),
            "client".to_string(),
            None,
        );
        (process_manager, ticket)
    }

    #[tokio::test]
    async fn test_stream_completes_normally() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_query();
        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_stream_cancelled_before_start() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_query();

        ticket.cancellation_handle.cancel();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn test_stream_cancelled_during_execution() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())])
            .with_delay(Duration::from_millis(100));
        let (_manager, ticket) = register_test_query();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            cancellation_handle.cancel();
        });

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
    }

    #[tokio::test]
    async fn test_stream_completes_before_cancellation() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_query();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        tokio::spawn(async move {
            sleep(Duration::from_millis(100)).await;
            cancellation_handle.cancel();
        });

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
    }

    #[tokio::test]
    async fn test_multiple_batches() {
        let mock_stream =
            MockRecordBatchStream::new(vec![Ok(create_test_batch()), Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_query();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_inner_error_is_forwarded() {
        // An error produced by the inner stream is passed through, and the
        // wrapper keeps yielding the items that follow it.
        let inner_error: BatchResult = common_recordbatch::error::StreamCancelledSnafu.fail();
        let mock_stream = MockRecordBatchStream::new(vec![inner_error, Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_query();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        assert!(matches!(cancellable_stream.next().await, Some(Err(_))));
        assert!(matches!(cancellable_stream.next().await, Some(Ok(_))));
        assert!(cancellable_stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_record_batch_stream_methods() {
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())]);
        let (_manager, ticket) = register_test_query();

        let cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        let schema = cancellable_stream.schema();
        assert_eq!(schema.column_schemas().len(), 1);
        assert_eq!(schema.column_schemas()[0].name, "test_col");

        assert!(cancellable_stream.output_ordering().is_none());
        assert!(cancellable_stream.metrics().is_none());
    }

    #[tokio::test]
    async fn test_cancellation_during_pending_poll() {
        // The inner stream only wakes itself long after the timeout below, so the
        // pending `next()` can finish in time only if `cancel()` wakes the waker
        // the wrapper registered while the inner stream was pending.
        let mock_stream = MockRecordBatchStream::new(vec![Ok(create_test_batch())])
            .with_delay(Duration::from_secs(60));
        let (_manager, ticket) = register_test_query();
        let cancellation_handle = ticket.cancellation_handle.clone();

        let mut cancellable_stream = CancellableStreamWrapper::new(Box::pin(mock_stream), ticket);

        tokio::spawn(async move {
            sleep(Duration::from_millis(50)).await;
            cancellation_handle.cancel();
        });

        let result = timeout(Duration::from_secs(1), cancellable_stream.next()).await;
        assert!(
            matches!(result, Ok(Some(Err(_)))),
            "cancellation did not wake the pending poll with an error"
        );
    }
}