1#![feature(never_type)]
16
17pub mod adapter;
18pub mod cursor;
19pub mod error;
20pub mod filter;
21pub mod recordbatch;
22pub mod util;
23
24use std::fmt;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28
29use adapter::RecordBatchMetrics;
30use arc_swap::ArcSwapOption;
31use common_base::readable_size::ReadableSize;
32use common_error::ext::BoxedError;
33use common_memory_manager::{
34 MemoryGuard, MemoryManager, MemoryMetrics, OnExhaustedPolicy, PermitGranularity,
35};
36use common_telemetry::tracing::Span;
37pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
38use datatypes::arrow::array::{ArrayRef, AsArray, StringBuilder};
39use datatypes::arrow::compute::SortOptions;
40pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
41use datatypes::arrow::util::pretty;
42use datatypes::prelude::{ConcreteDataType, VectorRef};
43use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
44use datatypes::types::{JsonFormat, jsonb_to_string};
45use error::Result;
46use futures::task::{Context, Poll};
47use futures::{Stream, TryStreamExt};
48pub use recordbatch::RecordBatch;
49use snafu::{IntoError, ResultExt, ensure};
50
51use crate::error::NewDfRecordBatchSnafu;
52
/// A stream of [`RecordBatch`]es that also exposes the schema, output
/// ordering, and execution metrics of the data it yields.
pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
    /// Name of the stream, mainly for debugging/diagnostic purposes.
    fn name(&self) -> &str {
        "RecordBatchStream"
    }

    /// Schema shared by every batch this stream yields.
    fn schema(&self) -> SchemaRef;

    /// Ordering guaranteed by the producer, if any.
    fn output_ordering(&self) -> Option<&[OrderOption]>;

    /// Runtime metrics collected during execution, if available.
    fn metrics(&self) -> Option<RecordBatchMetrics>;
}
64
/// A boxed, `Send`-able [`RecordBatchStream`] trait object.
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;
66
/// A sort specification for a single column: the column name plus its
/// arrow [`SortOptions`] (direction and null placement).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderOption {
    /// Name of the column the ordering applies to.
    pub name: String,
    /// Sort direction and null ordering for that column.
    pub options: SortOptions,
}
72
/// Wraps a [`SendableRecordBatchStream`] and lazily applies a mapping
/// function to every batch it yields.
pub struct SendableRecordBatchMapper {
    /// The underlying stream being mapped.
    inner: SendableRecordBatchStream,
    /// Maps a batch from the original schema to the mapped schema.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// Schema of the batches produced after mapping.
    schema: SchemaRef,
    /// Whether `mapper` needs to run at all; when `false` batches are
    /// forwarded untouched.
    apply_mapper: bool,
}
89
90pub fn map_json_type_to_string(
96 batch: RecordBatch,
97 original_schema: &SchemaRef,
98 mapped_schema: &SchemaRef,
99) -> Result<RecordBatch> {
100 let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
101 for (vector, schema) in batch.columns().iter().zip(original_schema.column_schemas()) {
102 if let ConcreteDataType::Json(j) = &schema.data_type {
103 if matches!(&j.format, JsonFormat::Jsonb) {
104 let mut string_vector_builder = StringBuilder::new();
105 let binary_vector = vector.as_binary::<i32>();
106 for value in binary_vector.iter() {
107 let Some(value) = value else {
108 string_vector_builder.append_null();
109 continue;
110 };
111 let string_value =
112 jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
113 from_type: schema.data_type.clone(),
114 to_type: ConcreteDataType::string_datatype(),
115 })?;
116 string_vector_builder.append_value(string_value);
117 }
118
119 let string_vector = string_vector_builder.finish();
120 vectors.push(Arc::new(string_vector) as ArrayRef);
121 } else {
122 vectors.push(vector.clone());
123 }
124 } else {
125 vectors.push(vector.clone());
126 }
127 }
128
129 let record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
130 mapped_schema.arrow_schema().clone(),
131 vectors,
132 )
133 .context(NewDfRecordBatchSnafu)?;
134 Ok(RecordBatch::from_df_record_batch(
135 mapped_schema.clone(),
136 record_batch,
137 ))
138}
139
140pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
148 let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
149 let mut apply_mapper = false;
150 for column in schema.column_schemas() {
151 if matches!(column.data_type, ConcreteDataType::Json(_)) {
152 new_columns.push(ColumnSchema::new(
153 column.name.clone(),
154 ConcreteDataType::string_datatype(),
155 column.is_nullable(),
156 ));
157 apply_mapper = true;
158 } else {
159 new_columns.push(column.clone());
160 }
161 }
162 (Arc::new(Schema::new(new_columns)), apply_mapper)
163}
164
impl SendableRecordBatchMapper {
    /// Creates a new mapper over `inner`.
    ///
    /// `schema_mapper` derives the output schema from the inner stream's
    /// schema and reports whether `mapper` actually needs to be applied
    /// to each batch.
    pub fn new(
        inner: SendableRecordBatchStream,
        mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
        schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
    ) -> Self {
        let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
        Self {
            inner,
            mapper,
            schema: mapped_schema,
            apply_mapper,
        }
    }
}
181
impl RecordBatchStream for SendableRecordBatchMapper {
    fn name(&self) -> &str {
        "SendableRecordBatchMapper"
    }

    /// Returns the mapped (output) schema, not the inner stream's schema.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    // Ordering and metrics are unaffected by the mapping, so delegate.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}

impl Stream for SendableRecordBatchMapper {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.apply_mapper {
            // Run each successful batch through the mapper, converting from
            // the inner schema to the mapped schema; errors and end-of-stream
            // pass through unchanged.
            Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
                opt.map(|result| {
                    result
                        .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
                })
            })
        } else {
            // Mapping disabled: forward the inner stream directly.
            Pin::new(&mut self.inner).poll_next(cx)
        }
    }
}
216
/// A [`RecordBatchStream`] that yields no batches, only a schema.
pub struct EmptyRecordBatchStream {
    /// Schema reported to consumers even though no rows are produced.
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Creates an empty stream carrying the given `schema`.
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    // An empty stream has no ordering and collects no metrics.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}
244
245impl Stream for EmptyRecordBatchStream {
246 type Item = Result<RecordBatch>;
247
248 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
249 Poll::Ready(None)
250 }
251}
252
/// An in-memory collection of [`RecordBatch`]es that all share one schema.
#[derive(Debug, PartialEq)]
pub struct RecordBatches {
    /// Schema common to every batch in `batches`.
    schema: SchemaRef,
    /// The buffered batches, in order.
    batches: Vec<RecordBatch>,
}
258
259impl RecordBatches {
260 pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
261 schema: SchemaRef,
262 columns: I,
263 ) -> Result<Self> {
264 let batches = vec![RecordBatch::new(schema.clone(), columns)?];
265 Ok(Self { schema, batches })
266 }
267
268 pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
269 let schema = stream.schema();
270 let batches = stream.try_collect::<Vec<_>>().await?;
271 Ok(Self { schema, batches })
272 }
273
274 #[inline]
275 pub fn empty() -> Self {
276 Self {
277 schema: Arc::new(Schema::new(vec![])),
278 batches: vec![],
279 }
280 }
281
282 pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
283 self.batches.iter()
284 }
285
286 pub fn pretty_print(&self) -> Result<String> {
287 let df_batches = &self
288 .iter()
289 .map(|x| x.df_record_batch().clone())
290 .collect::<Vec<_>>();
291 let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;
292
293 Ok(result.to_string())
294 }
295
296 pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
297 for batch in &batches {
298 ensure!(
299 batch.schema == schema,
300 error::CreateRecordBatchesSnafu {
301 reason: format!(
302 "expect RecordBatch schema equals {:?}, actual: {:?}",
303 schema, batch.schema
304 )
305 }
306 )
307 }
308 Ok(Self { schema, batches })
309 }
310
311 pub fn schema(&self) -> SchemaRef {
312 self.schema.clone()
313 }
314
315 pub fn take(self) -> Vec<RecordBatch> {
316 self.batches
317 }
318
319 pub fn as_stream(&self) -> SendableRecordBatchStream {
320 Box::pin(SimpleRecordBatchStream {
321 inner: RecordBatches {
322 schema: self.schema(),
323 batches: self.batches.clone(),
324 },
325 index: 0,
326 })
327 }
328}
329
impl IntoIterator for RecordBatches {
    type Item = RecordBatch;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    /// Consumes the collection, yielding the batches in order.
    fn into_iter(self) -> Self::IntoIter {
        self.batches.into_iter()
    }
}

/// A [`RecordBatchStream`] backed by an in-memory [`RecordBatches`],
/// replaying the buffered batches one by one.
pub struct SimpleRecordBatchStream {
    /// The buffered batches to replay.
    inner: RecordBatches,
    /// Index of the next batch to yield.
    index: usize,
}

impl RecordBatchStream for SimpleRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    // The buffered batches carry no ordering guarantee or metrics.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}
357
358impl Stream for SimpleRecordBatchStream {
359 type Item = Result<RecordBatch>;
360
361 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362 Poll::Ready(if self.index < self.inner.batches.len() {
363 let batch = self.inner.batches[self.index].clone();
364 self.index += 1;
365 Some(Ok(batch))
366 } else {
367 None
368 })
369 }
370}
371
/// Adapts any `Stream` of [`RecordBatch`] results into a
/// [`RecordBatchStream`], attaching a schema, optional output ordering,
/// a metrics slot, and a tracing span entered on every poll.
pub struct RecordBatchStreamWrapper<S> {
    pub schema: SchemaRef,
    pub stream: S,
    pub output_ordering: Option<Vec<OrderOption>>,
    /// Metrics slot that the producer may fill while the stream runs.
    pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
    /// Span captured at construction; re-entered on each `poll_next`.
    pub span: Span,
}

impl<S> RecordBatchStreamWrapper<S> {
    /// Creates a wrapper with no ordering info, an empty metrics slot,
    /// and the current tracing span.
    pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
        RecordBatchStreamWrapper {
            schema,
            stream,
            output_ordering: None,
            metrics: Default::default(),
            span: Span::current(),
        }
    }
}
393
impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
    for RecordBatchStreamWrapper<S>
{
    fn name(&self) -> &str {
        "RecordBatchStreamWrapper"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.output_ordering.as_deref()
    }

    /// Snapshots the latest metrics published into the slot, if any.
    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.metrics.load().as_ref().map(|s| s.as_ref().clone())
    }
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Re-enter the captured span so work done while polling is
        // attributed to the originating trace.
        let _entered = self.span.clone().entered();
        Pin::new(&mut self.stream).poll_next(ctx)
    }
}
422
/// Tracks memory used by a query's record-batch streams against a shared
/// global limit. Cloning is cheap; clones share the same underlying
/// manager state.
#[derive(Clone)]
pub struct QueryMemoryTracker {
    /// Shared manager that enforces the limit.
    manager: MemoryManager<CallbackMemoryMetrics>,
    /// Callback-backed metrics sink.
    metrics: CallbackMemoryMetrics,
    /// Behavior when the limit is exhausted (fail fast or wait).
    on_exhausted_policy: OnExhaustedPolicy,
}

impl fmt::Debug for QueryMemoryTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("QueryMemoryTracker")
            .field("current", &self.current())
            .field("limit", &self.limit())
            .field("on_exhausted_policy", &self.on_exhausted_policy)
            // Callbacks are not `Debug`; report only their presence.
            .field("on_update", &self.metrics.has_on_update())
            .field("on_exhausted", &self.metrics.has_on_exhausted())
            .field("on_rejected", &self.metrics.has_on_rejected())
            .finish()
    }
}
445
446impl QueryMemoryTracker {
447 pub fn builder(
449 limit: usize,
450 on_exhausted_policy: OnExhaustedPolicy,
451 ) -> QueryMemoryTrackerBuilder {
452 QueryMemoryTrackerBuilder {
453 limit,
454 on_exhausted_policy,
455 on_update: None,
456 on_exhausted: None,
457 on_reject: None,
458 }
459 }
460
461 fn new_stream_tracker(&self) -> StreamMemoryTracker {
462 StreamMemoryTracker {
463 tracker: self.clone(),
464 guard: self.manager.try_acquire(0).unwrap(),
465 tracked_bytes: 0,
466 }
467 }
468 pub fn current(&self) -> usize {
470 self.manager.used_bytes() as usize
471 }
472
473 fn limit(&self) -> usize {
474 self.manager.limit_bytes() as usize
475 }
476
477 fn reject_error(
478 &self,
479 current: usize,
480 additional: usize,
481 stream_tracked: usize,
482 ) -> error::Error {
483 let limit = self.limit();
484 let msg = format!(
485 "{} requested, {} used globally ({}%), {} used by this stream, hard limit: {}",
486 ReadableSize(additional as u64),
487 ReadableSize(current as u64),
488 (current * 100).checked_div(limit).unwrap_or(0),
489 ReadableSize(stream_tracked as u64),
490 ReadableSize(limit as u64)
491 );
492 error::ExceedMemoryLimitSnafu { msg }.build()
493 }
494
495 fn inc_rejected(&self) {
496 self.metrics.inc_rejected();
497 }
498}
499
/// Builder for [`QueryMemoryTracker`]; configure optional callbacks
/// before calling `build`.
pub struct QueryMemoryTrackerBuilder {
    /// Global limit in bytes (0 disables the limit).
    limit: usize,
    /// Policy applied when the limit is exhausted.
    on_exhausted_policy: OnExhaustedPolicy,
    on_update: Option<UpdateCallback>,
    on_exhausted: Option<UnitCallback>,
    on_reject: Option<RejectCallback>,
}
508
509impl QueryMemoryTrackerBuilder {
510 pub fn on_update<F>(mut self, on_update: F) -> Self
517 where
518 F: Fn(usize) + Send + Sync + 'static,
519 {
520 self.on_update = Some(Arc::new(on_update));
521 self
522 }
523
524 pub fn on_exhausted<F>(mut self, on_exhausted: F) -> Self
531 where
532 F: Fn() + Send + Sync + 'static,
533 {
534 self.on_exhausted = Some(Arc::new(on_exhausted));
535 self
536 }
537
538 pub fn on_reject<F>(mut self, on_reject: F) -> Self
540 where
541 F: Fn() + Send + Sync + 'static,
542 {
543 self.on_reject = Some(Arc::new(on_reject));
544 self
545 }
546
547 pub fn build(self) -> QueryMemoryTracker {
549 let metrics = CallbackMemoryMetrics::new(self.on_update, self.on_exhausted, self.on_reject);
550 let manager = MemoryManager::with_granularity(
551 self.limit as u64,
552 PermitGranularity::Kilobyte,
553 metrics.clone(),
554 );
555
556 QueryMemoryTracker {
557 manager,
558 metrics,
559 on_exhausted_policy: self.on_exhausted_policy,
560 }
561 }
562}
563
/// Per-stream view of a [`QueryMemoryTracker`]: holds the stream's memory
/// guard and remembers how many bytes this particular stream has charged.
struct StreamMemoryTracker {
    /// The query-level tracker shared by all streams of the query.
    tracker: QueryMemoryTracker,
    /// Guard holding this stream's reservation in the memory manager.
    guard: MemoryGuard<CallbackMemoryMetrics>,
    /// Bytes this stream has successfully tracked so far.
    tracked_bytes: usize,
}

/// Outcome of acquiring memory from the underlying manager.
type MemoryAcquireResult = std::result::Result<(), common_memory_manager::Error>;
571
impl StreamMemoryTracker {
    /// Bumps the query-level rejected metric.
    fn inc_rejected(&self) {
        self.tracker.inc_rejected();
    }

    /// Tries to charge `additional` bytes without waiting; on failure
    /// returns a detailed limit-exceeded error.
    fn try_track(&mut self, additional: usize) -> Result<()> {
        if self.guard.try_acquire_additional(additional as u64) {
            self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
            Ok(())
        } else {
            Err(self.reject_error(additional))
        }
    }

    /// Charges `additional` bytes following the tracker's exhausted
    /// policy (may wait). Consumes and returns `self` so the caller can
    /// move the tracker into and back out of a pending future.
    async fn track_with_policy(mut self, additional: usize) -> (Self, MemoryAcquireResult) {
        let result = self
            .guard
            .acquire_additional_with_policy(additional as u64, self.tracker.on_exhausted_policy)
            .await;
        if result.is_ok() {
            self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
        }
        (self, result)
    }

    /// Builds the rejection error with current global usage attached.
    fn reject_error(&self, additional: usize) -> error::Error {
        let current = self.tracker.current();
        self.tracker
            .reject_error(current, additional, self.tracked_bytes)
    }

    /// Converts a manager-level acquire failure into this crate's error,
    /// preserving timeout details where available.
    fn wait_error(&self, additional: usize, source: common_memory_manager::Error) -> error::Error {
        match source {
            common_memory_manager::Error::MemoryLimitExceeded { .. } => {
                self.reject_error(additional)
            }
            common_memory_manager::Error::MemoryAcquireTimeout { waited, .. } => {
                let current = self.tracker.current();
                let limit = self.tracker.limit();
                let msg = format!(
                    "timed out waiting {:?} for {}, {} used globally ({}%), {} used by this stream, hard limit: {}",
                    waited,
                    ReadableSize(additional as u64),
                    ReadableSize(current as u64),
                    (current * 100).checked_div(limit).unwrap_or(0),
                    ReadableSize(self.tracked_bytes as u64),
                    ReadableSize(limit as u64)
                );
                error::ExceedMemoryLimitSnafu { msg }.build()
            }
            // Any other manager error is surfaced as an external error.
            error => error::ExternalSnafu.into_error(BoxedError::new(error)),
        }
    }
}
626
/// Future resolved once a pending memory acquisition completes, carrying
/// the stream tracker, the batch being charged, the requested byte count,
/// and the acquisition outcome.
type PendingTrackFuture = Pin<
    Box<dyn Future<Output = (StreamMemoryTracker, RecordBatch, usize, MemoryAcquireResult)> + Send>,
>;

/// Cheaply clonable [`MemoryMetrics`] implementation that forwards events
/// to optional user-supplied callbacks.
#[derive(Clone)]
struct CallbackMemoryMetrics {
    inner: Arc<CallbackMemoryMetricsInner>,
}

/// Callback receiving the new global usage in bytes.
type UpdateCallback = Arc<dyn Fn(usize) + Send + Sync>;
/// Parameterless event callback.
type UnitCallback = Arc<dyn Fn() + Send + Sync>;
/// Callback fired when a memory request is rejected.
type RejectCallback = UnitCallback;

/// Storage for the optional callbacks, shared behind an `Arc`.
struct CallbackMemoryMetricsInner {
    on_update: Option<UpdateCallback>,
    on_exhausted: Option<UnitCallback>,
    on_reject: Option<RejectCallback>,
}
645
646impl CallbackMemoryMetrics {
647 fn new(
648 on_update: Option<UpdateCallback>,
649 on_exhausted: Option<UnitCallback>,
650 on_reject: Option<RejectCallback>,
651 ) -> Self {
652 Self {
653 inner: Arc::new(CallbackMemoryMetricsInner {
654 on_update,
655 on_exhausted,
656 on_reject,
657 }),
658 }
659 }
660
661 fn has_on_update(&self) -> bool {
662 self.inner.on_update.is_some()
663 }
664
665 fn has_on_exhausted(&self) -> bool {
666 self.inner.on_exhausted.is_some()
667 }
668
669 fn has_on_rejected(&self) -> bool {
670 self.inner.on_reject.is_some()
671 }
672
673 fn inc_rejected(&self) {
674 if let Some(callback) = &self.inner.on_reject {
675 callback();
676 }
677 }
678}
679
impl MemoryMetrics for CallbackMemoryMetrics {
    // The limit is fixed at construction time; nothing to report here.
    fn set_limit(&self, _: i64) {}

    /// Forwards usage changes to the update callback, clamped at zero.
    fn set_in_use(&self, bytes: i64) {
        if let Some(callback) = &self.inner.on_update {
            callback(bytes.max(0) as usize);
        }
    }

    /// Forwards limit-exhausted events to the exhausted callback.
    fn inc_exhausted(&self, _: &str) {
        if let Some(callback) = &self.inner.on_exhausted {
            callback();
        }
    }
}
695
/// Stream adapter that charges each yielded batch's buffer size against a
/// [`QueryMemoryTracker`] before handing the batch to the consumer.
pub struct MemoryTrackedStream {
    /// The stream whose batches are being accounted for.
    inner: SendableRecordBatchStream,
    /// Present while in the "ready" state; taken while an acquisition
    /// future is pending.
    tracker: Option<StreamMemoryTracker>,
    /// Pending acquisition (wait policy); mutually exclusive with `tracker`.
    waiting: Option<PendingTrackFuture>,
}
705
impl MemoryTrackedStream {
    /// Wraps `inner` so every yielded batch is charged against `tracker`.
    pub fn new(inner: SendableRecordBatchStream, tracker: QueryMemoryTracker) -> Self {
        Self {
            inner,
            tracker: Some(tracker.new_stream_tracker()),
            waiting: None,
        }
    }

    /// Returns the stream tracker while in the ready state (no pending
    /// acquisition future).
    fn ready_tracker_mut(&mut self) -> &mut StreamMemoryTracker {
        debug_assert!(
            self.waiting.is_none(),
            "a ready tracker must not coexist with a waiting future"
        );
        self.tracker.as_mut().unwrap()
    }

    /// Transitions from the ready state into the waiting state, moving
    /// the tracker into a future that waits for `additional` bytes while
    /// holding on to `batch`.
    fn enter_waiting(&mut self, batch: RecordBatch, additional: usize) {
        debug_assert!(
            self.waiting.is_none(),
            "enter_waiting should only be called from the ready state"
        );
        debug_assert!(
            self.tracker.is_some(),
            "enter_waiting requires a tracker in the ready state"
        );
        let tracker = self.tracker.take().unwrap();
        self.waiting = Some(Self::start_waiting(tracker, batch, additional));
    }

    /// Builds the future performing the policy-driven acquisition; it
    /// hands the tracker back together with the held batch and outcome.
    fn start_waiting(
        tracker: StreamMemoryTracker,
        batch: RecordBatch,
        additional: usize,
    ) -> PendingTrackFuture {
        Box::pin(async move {
            let (tracker, result) = tracker.track_with_policy(additional).await;
            (tracker, batch, additional, result)
        })
    }

    /// Polls the pending acquisition; on completion restores the ready
    /// state and yields either the held batch or the acquisition error.
    fn poll_waiting(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
        let future = self.waiting.as_mut().unwrap();
        match future.as_mut().poll(cx) {
            Poll::Ready((tracker, batch, additional, result)) => {
                let output = match result {
                    Ok(()) => Ok(batch),
                    Err(error) => {
                        tracker.inc_rejected();
                        Err(tracker.wait_error(additional, error))
                    }
                };
                self.waiting = None;
                self.tracker = Some(tracker);
                Poll::Ready(Some(output))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Charges a freshly produced batch, failing fast or starting an
    /// asynchronous wait depending on the exhausted policy.
    fn poll_batch(
        &mut self,
        batch: RecordBatch,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let additional = batch.buffer_memory_size();
        let tracker = self.ready_tracker_mut();

        if let Err(error) = tracker.try_track(additional) {
            match tracker.tracker.on_exhausted_policy {
                OnExhaustedPolicy::Fail => {
                    tracker.inc_rejected();
                    return Poll::Ready(Some(Err(error)));
                }
                OnExhaustedPolicy::Wait { .. } => {
                    // The fast path failed; queue an asynchronous
                    // acquisition that may wait for capacity to free up.
                    self.enter_waiting(batch, additional);
                    return self.poll_waiting(cx);
                }
            }
        }

        Poll::Ready(Some(Ok(batch)))
    }
}
794
impl Stream for MemoryTrackedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Finish any in-flight memory acquisition before pulling more data.
        if self.waiting.is_some() {
            return self.poll_waiting(cx);
        }

        match Pin::new(&mut self.inner).poll_next(cx) {
            Poll::Ready(Some(Ok(batch))) => self.poll_batch(batch, cx),
            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl RecordBatchStream for MemoryTrackedStream {
    // Tracking is transparent: schema, ordering and metrics delegate to
    // the wrapped stream.
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}
829
830#[cfg(test)]
831mod tests {
832 use std::sync::Arc;
833 use std::sync::atomic::{AtomicUsize, Ordering};
834 use std::time::Duration;
835
836 use common_memory_manager::{OnExhaustedPolicy, PermitGranularity};
837 use datatypes::prelude::{ConcreteDataType, VectorRef};
838 use datatypes::schema::{ColumnSchema, Schema};
839 use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};
840 use futures::StreamExt;
841 use tokio::time::{sleep, timeout};
842
843 use super::*;
844
    // Builds a single-row batch whose one string cell is `bytes` long, so
    // memory accounting in tests is deterministic.
    fn large_string_batch(bytes: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "payload",
            ConcreteDataType::string_datatype(),
            false,
        )]));
        let payload = "x".repeat(bytes);
        let vector: VectorRef = Arc::new(StringVector::from(vec![payload]));
        RecordBatch::new(schema, vec![vector]).unwrap()
    }

    // Rounds `bytes` up to the tracker's kilobyte permit granularity.
    fn aligned_tracked_bytes(bytes: usize) -> usize {
        PermitGranularity::Kilobyte
            .permits_to_bytes(PermitGranularity::Kilobyte.bytes_to_permits(bytes as u64))
            as usize
    }
861
    #[test]
    fn test_recordbatches_try_from_columns() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        // A column whose type does not match the schema must be rejected.
        let result = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        );
        assert!(result.is_err());

        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
        assert_eq!(r.take(), expected);
    }

    #[test]
    fn test_recordbatches_try_new() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);

        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();

        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();

        // Mixing batches with different schemas must fail, and the error
        // names both the expected and the actual schema.
        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!(
                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
            )
        );

        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
        let expected = "\
+---+-------+
| a | b     |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
        assert_eq!(batches.pretty_print().unwrap(), expected);

        assert_eq!(schema1, batches.schema());
        assert_eq!(vec![batch1], batches.take());
    }
919
    #[tokio::test]
    async fn test_simple_recordbatch_stream() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column_a, column_b]));

        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();

        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();

        // Streaming the collection must reproduce the batches unchanged.
        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
        let stream = recordbatches.as_stream();
        let collected = util::collect(stream).await.unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0], batch1);
        assert_eq!(collected[1], batch2);
    }

    const MB: usize = 1024 * 1024;

    #[test]
    fn test_query_memory_tracker_basic() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());

        let mut stream1 = tracker.new_stream_tracker();
        assert!(stream1.try_track(5 * MB).is_ok());
        assert_eq!(tracker.current(), 5 * MB);

        let mut stream2 = tracker.new_stream_tracker();
        assert!(stream2.try_track(4 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        // Dropping a stream tracker releases its share of the budget.
        drop(stream1);
        drop(stream2);
        assert_eq!(tracker.current(), 0);
    }
962
    #[test]
    fn test_query_memory_tracker_shared_global_limit() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream1 = tracker.new_stream_tracker();
        let mut stream2 = tracker.new_stream_tracker();

        assert!(stream1.try_track(3 * MB).is_ok());
        assert_eq!(tracker.current(), 3 * MB);
        assert!(stream2.try_track(6 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        // Exceeding the shared limit fails, and the error reports both the
        // per-stream and the global usage.
        let err = stream2.try_track(2 * MB).unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("6.0MiB used by this stream"));
        assert!(err_msg.contains("9.0MiB used globally (90%)"));
        assert!(err_msg.contains("hard limit: 10.0MiB"));
        assert_eq!(tracker.current(), 9 * MB);

        drop(stream1);
        assert_eq!(tracker.current(), 6 * MB);
        drop(stream2);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_hard_limit() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        assert!(stream.try_track(9 * MB).is_ok());
        assert_eq!(tracker.current(), 9 * MB);

        // A failed request must not change the tracked amount.
        assert!(stream.try_track(2 * MB).is_err());
        assert_eq!(tracker.current(), 9 * MB);

        // Filling up to exactly the limit is still allowed.
        assert!(stream.try_track(MB).is_ok());
        assert_eq!(tracker.current(), 10 * MB);

        assert!(stream.try_track(MB).is_err());
        assert_eq!(tracker.current(), 10 * MB);

        drop(stream);
        assert_eq!(tracker.current(), 0);
    }
1009
    #[test]
    fn test_query_memory_tracker_unlimited() {
        // A limit of 0 disables enforcement entirely.
        let tracker = Arc::new(QueryMemoryTracker::builder(0, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        assert!(stream.try_track(10 * MB).is_ok());
        assert_eq!(tracker.current(), 10 * MB);
        drop(stream);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_rounds_to_kilobytes() {
        let tracker =
            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
        let mut stream = tracker.new_stream_tracker();

        // Requests are rounded up to whole kilobyte permits.
        assert!(stream.try_track(1_537).is_ok());
        assert_eq!(tracker.current(), 2 * 1024);

        drop(stream);
        assert_eq!(tracker.current(), 0);
    }
1033
    // With the wait policy, a second stream must block until the first one
    // releases its memory, then complete successfully without rejections.
    #[tokio::test]
    async fn test_memory_tracked_stream_waits_for_capacity() {
        let exhausted = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let exhausted_counter = exhausted.clone();
        let rejected_counter = rejected.clone();
        let tracker = QueryMemoryTracker::builder(
            MB,
            OnExhaustedPolicy::Wait {
                timeout: Duration::from_millis(200),
            },
        )
        .on_exhausted(move || {
            exhausted_counter.fetch_add(1, Ordering::Relaxed);
        })
        .on_reject(move || {
            rejected_counter.fetch_add(1, Ordering::Relaxed);
        })
        .build();
        let batch = large_string_batch(700 * 1024);
        let expected_bytes = aligned_tracked_bytes(batch.buffer_memory_size());

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);
        assert_eq!(tracker.current(), expected_bytes);

        let stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let waiter = tokio::spawn(async move {
            let mut stream2 = stream2;
            stream2.next().await.unwrap()
        });

        // The second stream cannot finish while the first holds the budget.
        sleep(Duration::from_millis(50)).await;
        assert!(!waiter.is_finished());

        drop(stream1);
        let second = waiter.await.unwrap().unwrap();
        assert_eq!(second.num_rows(), 1);
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
        assert_eq!(rejected.load(Ordering::Relaxed), 0);
    }
1086
    // With a short wait timeout and no capacity released, the waiting
    // stream must fail with a timeout error and count as rejected.
    #[tokio::test]
    async fn test_memory_tracked_stream_wait_times_out() {
        let exhausted = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let exhausted_counter = exhausted.clone();
        let rejected_counter = rejected.clone();
        let tracker = QueryMemoryTracker::builder(
            MB,
            OnExhaustedPolicy::Wait {
                timeout: Duration::from_millis(50),
            },
        )
        .on_exhausted(move || {
            exhausted_counter.fetch_add(1, Ordering::Relaxed);
        })
        .on_reject(move || {
            rejected_counter.fetch_add(1, Ordering::Relaxed);
        })
        .build();
        let batch = large_string_batch(700 * 1024);

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);

        let mut stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker,
        );
        let result = timeout(Duration::from_secs(1), stream2.next())
            .await
            .unwrap();
        let error = result.unwrap().unwrap_err();
        assert!(error.to_string().contains("timed out waiting"));
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
        assert_eq!(rejected.load(Ordering::Relaxed), 1);
    }
1131
    // With the fail policy, the second stream is rejected immediately
    // instead of waiting for capacity.
    #[tokio::test]
    async fn test_memory_tracked_stream_fail_policy_rejects_immediately() {
        let exhausted = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let exhausted_counter = exhausted.clone();
        let rejected_counter = rejected.clone();
        let tracker = QueryMemoryTracker::builder(MB, OnExhaustedPolicy::Fail)
            .on_exhausted(move || {
                exhausted_counter.fetch_add(1, Ordering::Relaxed);
            })
            .on_reject(move || {
                rejected_counter.fetch_add(1, Ordering::Relaxed);
            })
            .build();
        let batch = large_string_batch(700 * 1024);

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);

        let mut stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker,
        );
        let result = stream2.next().await.unwrap();
        assert!(result.is_err());
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
        assert_eq!(rejected.load(Ordering::Relaxed), 1);
    }
1168}