Skip to main content

common_recordbatch/
lib.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![feature(never_type)]
16
17pub mod adapter;
18pub mod cursor;
19pub mod error;
20pub mod filter;
21pub mod recordbatch;
22pub mod util;
23
24use std::fmt;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28
29use adapter::RecordBatchMetrics;
30use arc_swap::ArcSwapOption;
31use common_base::readable_size::ReadableSize;
32use common_error::ext::BoxedError;
33use common_memory_manager::{
34    MemoryGuard, MemoryManager, MemoryMetrics, OnExhaustedPolicy, PermitGranularity,
35};
36use common_telemetry::tracing::Span;
37pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
38use datatypes::arrow::array::{ArrayRef, AsArray, StringBuilder};
39use datatypes::arrow::compute::SortOptions;
40pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
41use datatypes::arrow::util::pretty;
42use datatypes::prelude::{ConcreteDataType, VectorRef};
43use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
44use datatypes::types::{JsonFormat, jsonb_to_string};
45use error::Result;
46use futures::task::{Context, Poll};
47use futures::{Stream, TryStreamExt};
48pub use recordbatch::RecordBatch;
49use snafu::{IntoError, ResultExt, ensure};
50
51use crate::error::NewDfRecordBatchSnafu;
52
/// A stream of [RecordBatch]es carrying a schema, an optional output ordering,
/// and optional runtime metrics.
pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
    /// Human-readable name of the stream, mainly for debugging/logging.
    fn name(&self) -> &str {
        "RecordBatchStream"
    }

    /// Schema shared by every batch this stream yields.
    fn schema(&self) -> SchemaRef;

    /// Ordering guaranteed for the output rows, if any; `None` means unordered.
    fn output_ordering(&self) -> Option<&[OrderOption]>;

    /// Runtime metrics collected so far, if the implementation records any.
    fn metrics(&self) -> Option<RecordBatchMetrics>;
}
64
/// A pinned, boxed, `Send`-able [RecordBatchStream] trait object.
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;

/// Sort requirement on a single column: the column name plus Arrow
/// [SortOptions] (direction and null placement).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderOption {
    /// Name of the column the ordering applies to.
    pub name: String,
    /// Arrow sort options for that column.
    pub options: SortOptions,
}
72
/// A wrapper that maps a [RecordBatchStream] to a new [RecordBatchStream] by applying a function to each [RecordBatch].
///
/// The mapper function is applied to each [RecordBatch] in the stream.
/// The schema of the new [RecordBatchStream] is the same as the schema of the inner [RecordBatchStream] after applying the schema mapper function.
/// The output ordering of the new [RecordBatchStream] is the same as the output ordering of the inner [RecordBatchStream].
/// The metrics of the new [RecordBatchStream] is the same as the metrics of the inner [RecordBatchStream] if it is not `None`.
pub struct SendableRecordBatchMapper {
    /// The wrapped stream whose batches are (optionally) rewritten.
    inner: SendableRecordBatchStream,
    /// The mapper function is applied to each [RecordBatch] in the stream.
    /// The original schema and the mapped schema are passed to the mapper function.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// The schema of the batches after mapping (inner schema transformed by the
    /// schema mapper function).
    schema: SchemaRef,
    /// Whether the mapper function needs to be applied at all; when `false`,
    /// batches are forwarded untouched.
    apply_mapper: bool,
}
89
90/// Maps the json type to string in the batch.
91///
92/// The json type is mapped to string by converting the json value to string.
93/// The batch is updated to have the same number of columns as the original batch,
94/// but with the json type mapped to string.
95pub fn map_json_type_to_string(
96    batch: RecordBatch,
97    original_schema: &SchemaRef,
98    mapped_schema: &SchemaRef,
99) -> Result<RecordBatch> {
100    let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
101    for (vector, schema) in batch.columns().iter().zip(original_schema.column_schemas()) {
102        if let ConcreteDataType::Json(j) = &schema.data_type {
103            if matches!(&j.format, JsonFormat::Jsonb) {
104                let mut string_vector_builder = StringBuilder::new();
105                let binary_vector = vector.as_binary::<i32>();
106                for value in binary_vector.iter() {
107                    let Some(value) = value else {
108                        string_vector_builder.append_null();
109                        continue;
110                    };
111                    let string_value =
112                        jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
113                            from_type: schema.data_type.clone(),
114                            to_type: ConcreteDataType::string_datatype(),
115                        })?;
116                    string_vector_builder.append_value(string_value);
117                }
118
119                let string_vector = string_vector_builder.finish();
120                vectors.push(Arc::new(string_vector) as ArrayRef);
121            } else {
122                vectors.push(vector.clone());
123            }
124        } else {
125            vectors.push(vector.clone());
126        }
127    }
128
129    let record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
130        mapped_schema.arrow_schema().clone(),
131        vectors,
132    )
133    .context(NewDfRecordBatchSnafu)?;
134    Ok(RecordBatch::from_df_record_batch(
135        mapped_schema.clone(),
136        record_batch,
137    ))
138}
139
140/// Maps the json type to string in the schema.
141///
142/// The json type is mapped to string by converting the json value to string.
143/// The schema is updated to have the same number of columns as the original schema,
144/// but with the json type mapped to string.
145///
146/// Returns the new schema and whether the schema needs to be mapped to string.
147pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
148    let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
149    let mut apply_mapper = false;
150    for column in schema.column_schemas() {
151        if matches!(column.data_type, ConcreteDataType::Json(_)) {
152            new_columns.push(ColumnSchema::new(
153                column.name.clone(),
154                ConcreteDataType::string_datatype(),
155                column.is_nullable(),
156            ));
157            apply_mapper = true;
158        } else {
159            new_columns.push(column.clone());
160        }
161    }
162    (Arc::new(Schema::new(new_columns)), apply_mapper)
163}
164
165impl SendableRecordBatchMapper {
166    /// Creates a new [SendableRecordBatchMapper] with the given inner [RecordBatchStream], mapper function, and schema mapper function.
167    pub fn new(
168        inner: SendableRecordBatchStream,
169        mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
170        schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
171    ) -> Self {
172        let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
173        Self {
174            inner,
175            mapper,
176            schema: mapped_schema,
177            apply_mapper,
178        }
179    }
180}
181
impl RecordBatchStream for SendableRecordBatchMapper {
    fn name(&self) -> &str {
        "SendableRecordBatchMapper"
    }

    /// Returns the mapped (post-transformation) schema, not the inner stream's.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    // Mapping rewrites column contents, not row order, so the inner ordering
    // is reported unchanged.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}
199
impl Stream for SendableRecordBatchMapper {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.apply_mapper {
            // Apply the mapper to each ready batch, handing it both the inner
            // (original) schema and the mapped schema so it can translate.
            Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
                opt.map(|result| {
                    result
                        .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
                })
            })
        } else {
            // The schema mapper reported nothing to rewrite: forward batches untouched.
            Pin::new(&mut self.inner).poll_next(cx)
        }
    }
}
216
/// EmptyRecordBatchStream can be used to create a RecordBatchStream
/// that will produce no results. It still reports a schema, which callers
/// may rely on even when no rows exist.
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}
223
impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream that reports `schema` but yields no batches.
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}
230
impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    // An empty stream has no rows, hence no ordering to advertise.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}
244
impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Always reports end-of-stream immediately.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}
252
/// An in-memory collection of [RecordBatch]es that all share one schema.
#[derive(Debug, PartialEq)]
pub struct RecordBatches {
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
}
258
259impl RecordBatches {
260    pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
261        schema: SchemaRef,
262        columns: I,
263    ) -> Result<Self> {
264        let batches = vec![RecordBatch::new(schema.clone(), columns)?];
265        Ok(Self { schema, batches })
266    }
267
268    pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
269        let schema = stream.schema();
270        let batches = stream.try_collect::<Vec<_>>().await?;
271        Ok(Self { schema, batches })
272    }
273
274    #[inline]
275    pub fn empty() -> Self {
276        Self {
277            schema: Arc::new(Schema::new(vec![])),
278            batches: vec![],
279        }
280    }
281
282    pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
283        self.batches.iter()
284    }
285
286    pub fn pretty_print(&self) -> Result<String> {
287        let df_batches = &self
288            .iter()
289            .map(|x| x.df_record_batch().clone())
290            .collect::<Vec<_>>();
291        let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;
292
293        Ok(result.to_string())
294    }
295
296    pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
297        for batch in &batches {
298            ensure!(
299                batch.schema == schema,
300                error::CreateRecordBatchesSnafu {
301                    reason: format!(
302                        "expect RecordBatch schema equals {:?}, actual: {:?}",
303                        schema, batch.schema
304                    )
305                }
306            )
307        }
308        Ok(Self { schema, batches })
309    }
310
311    pub fn schema(&self) -> SchemaRef {
312        self.schema.clone()
313    }
314
315    pub fn take(self) -> Vec<RecordBatch> {
316        self.batches
317    }
318
319    pub fn as_stream(&self) -> SendableRecordBatchStream {
320        Box::pin(SimpleRecordBatchStream {
321            inner: RecordBatches {
322                schema: self.schema(),
323                batches: self.batches.clone(),
324            },
325            index: 0,
326        })
327    }
328}
329
impl IntoIterator for RecordBatches {
    type Item = RecordBatch;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    /// Consumes the collection, yielding the batches in insertion order.
    fn into_iter(self) -> Self::IntoIter {
        self.batches.into_iter()
    }
}
338
/// A [RecordBatchStream] that replays a fixed, in-memory set of batches.
pub struct SimpleRecordBatchStream {
    /// The batches to emit, together with their shared schema.
    inner: RecordBatches,
    /// Index of the next batch to yield.
    index: usize,
}
343
impl RecordBatchStream for SimpleRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    // Batches are replayed in insertion order; no sort order is claimed.
    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}
357
358impl Stream for SimpleRecordBatchStream {
359    type Item = Result<RecordBatch>;
360
361    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362        Poll::Ready(if self.index < self.inner.batches.len() {
363            let batch = self.inner.batches[self.index].clone();
364            self.index += 1;
365            Some(Ok(batch))
366        } else {
367            None
368        })
369    }
370}
371
/// Adapt a [Stream] of [RecordBatch] to a [RecordBatchStream].
pub struct RecordBatchStreamWrapper<S> {
    pub schema: SchemaRef,
    pub stream: S,
    /// Ordering guarantee advertised by the stream, if any.
    pub output_ordering: Option<Vec<OrderOption>>,
    /// Metrics slot producers can fill in while the stream runs; read by `metrics()`.
    pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
    /// Tracing span entered for every poll of the stream.
    pub span: Span,
}
380
impl<S> RecordBatchStreamWrapper<S> {
    /// Creates a [RecordBatchStreamWrapper] without output ordering requirement.
    ///
    /// The current tracing span is captured here so that later polls (possibly
    /// executed elsewhere) are attributed to the creating context.
    pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
        RecordBatchStreamWrapper {
            schema,
            stream,
            output_ordering: None,
            metrics: Default::default(),
            span: Span::current(),
        }
    }
}
393
impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
    for RecordBatchStreamWrapper<S>
{
    fn name(&self) -> &str {
        "RecordBatchStreamWrapper"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.output_ordering.as_deref()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        // Snapshot the swappable metrics slot; `None` until a producer stores metrics.
        self.metrics.load().as_ref().map(|s| s.as_ref().clone())
    }
}
413
impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Poll the inner stream inside the captured tracing span so work done
        // during the poll is attributed to the context that created the stream.
        let _entered = self.span.clone().entered();
        Pin::new(&mut self.stream).poll_next(ctx)
    }
}
422
/// Memory tracker for RecordBatch streams. Clone to share the same limit across queries.
///
/// Each stream acquires quota independently from this tracker.
#[derive(Clone)]
pub struct QueryMemoryTracker {
    /// Shared accounting backend (kilobyte permit granularity, see the builder).
    manager: MemoryManager<CallbackMemoryMetrics>,
    /// User callbacks surfaced through the manager's metrics hooks.
    metrics: CallbackMemoryMetrics,
    /// Behavior when an acquisition cannot be satisfied immediately.
    on_exhausted_policy: OnExhaustedPolicy,
}
432
impl fmt::Debug for QueryMemoryTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Callbacks are opaque closures, so only their presence is printed.
        f.debug_struct("QueryMemoryTracker")
            .field("current", &self.current())
            .field("limit", &self.limit())
            .field("on_exhausted_policy", &self.on_exhausted_policy)
            .field("on_update", &self.metrics.has_on_update())
            .field("on_exhausted", &self.metrics.has_on_exhausted())
            .field("on_rejected", &self.metrics.has_on_rejected())
            .finish()
    }
}
445
446impl QueryMemoryTracker {
447    /// Create a builder for a query memory tracker.
448    pub fn builder(
449        limit: usize,
450        on_exhausted_policy: OnExhaustedPolicy,
451    ) -> QueryMemoryTrackerBuilder {
452        QueryMemoryTrackerBuilder {
453            limit,
454            on_exhausted_policy,
455            on_update: None,
456            on_exhausted: None,
457            on_reject: None,
458        }
459    }
460
461    fn new_stream_tracker(&self) -> StreamMemoryTracker {
462        StreamMemoryTracker {
463            tracker: self.clone(),
464            guard: self.manager.try_acquire(0).unwrap(),
465            tracked_bytes: 0,
466        }
467    }
468    /// Get the current memory usage in bytes.
469    pub fn current(&self) -> usize {
470        self.manager.used_bytes() as usize
471    }
472
473    fn limit(&self) -> usize {
474        self.manager.limit_bytes() as usize
475    }
476
477    fn reject_error(
478        &self,
479        current: usize,
480        additional: usize,
481        stream_tracked: usize,
482    ) -> error::Error {
483        let limit = self.limit();
484        let msg = format!(
485            "{} requested, {} used globally ({}%), {} used by this stream, hard limit: {}",
486            ReadableSize(additional as u64),
487            ReadableSize(current as u64),
488            (current * 100).checked_div(limit).unwrap_or(0),
489            ReadableSize(stream_tracked as u64),
490            ReadableSize(limit as u64)
491        );
492        error::ExceedMemoryLimitSnafu { msg }.build()
493    }
494
495    fn inc_rejected(&self) {
496        self.metrics.inc_rejected();
497    }
498}
499
/// Builder for constructing a [`QueryMemoryTracker`] with optional callbacks.
pub struct QueryMemoryTrackerBuilder {
    /// Hard limit in bytes; `0` disables limiting.
    limit: usize,
    /// What to do when a request cannot be satisfied immediately.
    on_exhausted_policy: OnExhaustedPolicy,
    /// Invoked with the new total after each successful usage change.
    on_update: Option<UpdateCallback>,
    /// Invoked when the non-blocking fast path fails.
    on_exhausted: Option<UnitCallback>,
    /// Invoked when a request is ultimately rejected.
    on_reject: Option<RejectCallback>,
}
508
impl QueryMemoryTrackerBuilder {
    /// Set a callback to be called whenever the usage changes successfully.
    /// The callback receives the new total usage in bytes.
    ///
    /// # Note
    /// The callback is called after both successful `track()` and stream drop.
    /// Usage is exact in unlimited mode and 1KB-aligned in limited mode.
    pub fn on_update<F>(mut self, on_update: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.on_update = Some(Arc::new(on_update));
        self
    }

    /// Set a callback to be called when memory is unavailable for immediate acquisition.
    ///
    /// # Note
    /// This is called when the non-blocking allocation fast path fails.
    /// Requests using `OnExhaustedPolicy::Wait` may still succeed after waiting.
    /// It is never called when `limit == 0` (unlimited mode).
    pub fn on_exhausted<F>(mut self, on_exhausted: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_exhausted = Some(Arc::new(on_exhausted));
        self
    }

    /// Set a callback to be called when the request ultimately fails due to memory pressure.
    pub fn on_reject<F>(mut self, on_reject: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_reject = Some(Arc::new(on_reject));
        self
    }

    /// Build a [`QueryMemoryTracker`] from this builder.
    pub fn build(self) -> QueryMemoryTracker {
        // The callbacks are adapted into the manager's metrics hooks so the
        // manager itself drives usage/exhaustion notifications.
        let metrics = CallbackMemoryMetrics::new(self.on_update, self.on_exhausted, self.on_reject);
        let manager = MemoryManager::with_granularity(
            self.limit as u64,
            PermitGranularity::Kilobyte,
            metrics.clone(),
        );

        QueryMemoryTracker {
            manager,
            metrics,
            on_exhausted_policy: self.on_exhausted_policy,
        }
    }
}
563
/// Per-stream memory accounting handle: holds a guard on the shared manager
/// and remembers how many bytes this particular stream has reserved.
struct StreamMemoryTracker {
    /// Shared tracker the quota is drawn from.
    tracker: QueryMemoryTracker,
    /// Guard holding this stream's reservation with the memory manager.
    guard: MemoryGuard<CallbackMemoryMetrics>,
    /// Bytes this stream has successfully reserved so far.
    tracked_bytes: usize,
}

/// Result of a quota acquisition against the memory manager.
type MemoryAcquireResult = std::result::Result<(), common_memory_manager::Error>;
571
impl StreamMemoryTracker {
    /// Forwards a rejection event to the shared tracker's metrics.
    fn inc_rejected(&self) {
        self.tracker.inc_rejected();
    }

    /// Tries to reserve `additional` bytes without waiting.
    ///
    /// On success the stream's tracked total is bumped; on failure a detailed
    /// error describing global and per-stream usage is returned.
    fn try_track(&mut self, additional: usize) -> Result<()> {
        if self.guard.try_acquire_additional(additional as u64) {
            self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
            Ok(())
        } else {
            Err(self.reject_error(additional))
        }
    }

    /// Reserves `additional` bytes according to the tracker's exhausted policy,
    /// possibly waiting for quota.
    ///
    /// Takes and returns `self` by value so the tracker can be moved into the
    /// pending future while the owning stream sits in its waiting state.
    async fn track_with_policy(mut self, additional: usize) -> (Self, MemoryAcquireResult) {
        let result = self
            .guard
            .acquire_additional_with_policy(additional as u64, self.tracker.on_exhausted_policy)
            .await;
        if result.is_ok() {
            self.tracked_bytes = self.tracked_bytes.saturating_add(additional);
        }
        (self, result)
    }

    /// Builds the rejection error for a failed immediate acquisition.
    fn reject_error(&self, additional: usize) -> error::Error {
        let current = self.tracker.current();
        self.tracker
            .reject_error(current, additional, self.tracked_bytes)
    }

    /// Converts a failure from the waiting (policy-driven) path into this
    /// crate's error type, preserving timeout details when available.
    fn wait_error(&self, additional: usize, source: common_memory_manager::Error) -> error::Error {
        match source {
            common_memory_manager::Error::MemoryLimitExceeded { .. } => {
                self.reject_error(additional)
            }
            common_memory_manager::Error::MemoryAcquireTimeout { waited, .. } => {
                let current = self.tracker.current();
                let limit = self.tracker.limit();
                let msg = format!(
                    "timed out waiting {:?} for {}, {} used globally ({}%), {} used by this stream, hard limit: {}",
                    waited,
                    ReadableSize(additional as u64),
                    ReadableSize(current as u64),
                    (current * 100).checked_div(limit).unwrap_or(0),
                    ReadableSize(self.tracked_bytes as u64),
                    ReadableSize(limit as u64)
                );
                error::ExceedMemoryLimitSnafu { msg }.build()
            }
            // Any other failure is surfaced as an opaque external error.
            error => error::ExternalSnafu.into_error(BoxedError::new(error)),
        }
    }
}
626
/// Future produced while a stream waits for memory quota: resolves to the
/// stream tracker, the batch held back during the wait, the requested byte
/// count, and the acquisition result.
type PendingTrackFuture = Pin<
    Box<dyn Future<Output = (StreamMemoryTracker, RecordBatch, usize, MemoryAcquireResult)> + Send>,
>;
630
/// Cheaply clonable bundle of user callbacks, exposed to the memory manager
/// through the [MemoryMetrics] trait.
#[derive(Clone)]
struct CallbackMemoryMetrics {
    inner: Arc<CallbackMemoryMetricsInner>,
}

/// Callback invoked with the new total usage in bytes.
type UpdateCallback = Arc<dyn Fn(usize) + Send + Sync>;
/// Parameterless callback (exhaustion / rejection notifications).
type UnitCallback = Arc<dyn Fn() + Send + Sync>;
/// Alias making the rejection callback's role explicit at use sites.
type RejectCallback = UnitCallback;

/// Shared storage for the optional callbacks.
struct CallbackMemoryMetricsInner {
    on_update: Option<UpdateCallback>,
    on_exhausted: Option<UnitCallback>,
    on_reject: Option<RejectCallback>,
}
645
646impl CallbackMemoryMetrics {
647    fn new(
648        on_update: Option<UpdateCallback>,
649        on_exhausted: Option<UnitCallback>,
650        on_reject: Option<RejectCallback>,
651    ) -> Self {
652        Self {
653            inner: Arc::new(CallbackMemoryMetricsInner {
654                on_update,
655                on_exhausted,
656                on_reject,
657            }),
658        }
659    }
660
661    fn has_on_update(&self) -> bool {
662        self.inner.on_update.is_some()
663    }
664
665    fn has_on_exhausted(&self) -> bool {
666        self.inner.on_exhausted.is_some()
667    }
668
669    fn has_on_rejected(&self) -> bool {
670        self.inner.on_reject.is_some()
671    }
672
673    fn inc_rejected(&self) {
674        if let Some(callback) = &self.inner.on_reject {
675            callback();
676        }
677    }
678}
679
impl MemoryMetrics for CallbackMemoryMetrics {
    // Limit changes are not surfaced to user callbacks.
    fn set_limit(&self, _: i64) {}

    fn set_in_use(&self, bytes: i64) {
        if let Some(callback) = &self.inner.on_update {
            // Clamp negatives so the callback always sees a valid usize.
            callback(bytes.max(0) as usize);
        }
    }

    fn inc_exhausted(&self, _: &str) {
        if let Some(callback) = &self.inner.on_exhausted {
            callback();
        }
    }
}
695
/// A wrapper stream that tracks memory usage of RecordBatches.
pub struct MemoryTrackedStream {
    inner: SendableRecordBatchStream,
    // `Some` in the ready state; taken (left `None`) while the pending
    // acquisition future in `waiting` owns the tracker.
    tracker: Option<StreamMemoryTracker>,
    // Waiting stores a batch that has already been pulled from the inner stream but has not yet
    // acquired additional quota. This keeps `poll_next()` non-blocking and allows bounded waits,
    // at the cost of temporarily holding one untracked batch per blocked stream in memory.
    waiting: Option<PendingTrackFuture>,
}
705
impl MemoryTrackedStream {
    /// Wraps `inner` so that each produced batch is charged against `tracker`.
    pub fn new(inner: SendableRecordBatchStream, tracker: QueryMemoryTracker) -> Self {
        Self {
            inner,
            tracker: Some(tracker.new_stream_tracker()),
            waiting: None,
        }
    }

    /// Returns the per-stream tracker in the ready state.
    ///
    /// Invariant: the tracker is `Some` exactly when `waiting` is `None`,
    /// because entering the waiting state moves the tracker into the future.
    fn ready_tracker_mut(&mut self) -> &mut StreamMemoryTracker {
        debug_assert!(
            self.waiting.is_none(),
            "a ready tracker must not coexist with a waiting future"
        );
        self.tracker.as_mut().unwrap()
    }

    /// Transitions from the ready state to the waiting state, moving the
    /// tracker (and the already-pulled `batch`) into the pending future.
    fn enter_waiting(&mut self, batch: RecordBatch, additional: usize) {
        debug_assert!(
            self.waiting.is_none(),
            "enter_waiting should only be called from the ready state"
        );
        debug_assert!(
            self.tracker.is_some(),
            "enter_waiting requires a tracker in the ready state"
        );
        let tracker = self.tracker.take().unwrap();
        self.waiting = Some(Self::start_waiting(tracker, batch, additional));
    }

    /// Builds the future that waits for `additional` bytes of quota while
    /// holding the batch and owning the tracker for the duration of the wait.
    fn start_waiting(
        tracker: StreamMemoryTracker,
        batch: RecordBatch,
        additional: usize,
    ) -> PendingTrackFuture {
        Box::pin(async move {
            let (tracker, result) = tracker.track_with_policy(additional).await;
            (tracker, batch, additional, result)
        })
    }

    /// Polls the in-flight quota acquisition; on completion, restores the
    /// ready state and emits either the held batch or an error.
    fn poll_waiting(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
        let future = self.waiting.as_mut().unwrap();
        match future.as_mut().poll(cx) {
            Poll::Ready((tracker, batch, additional, result)) => {
                let output = match result {
                    Ok(()) => Ok(batch),
                    Err(error) => {
                        tracker.inc_rejected();
                        Err(tracker.wait_error(additional, error))
                    }
                };
                // Back to the ready state: the tracker returns from the future.
                self.waiting = None;
                self.tracker = Some(tracker);
                Poll::Ready(Some(output))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    /// Accounts for a freshly pulled `batch`: releases it immediately on
    /// success, fails, or enters the waiting state depending on the policy.
    fn poll_batch(
        &mut self,
        batch: RecordBatch,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let additional = batch.buffer_memory_size();
        let tracker = self.ready_tracker_mut();

        if let Err(error) = tracker.try_track(additional) {
            match tracker.tracker.on_exhausted_policy {
                OnExhaustedPolicy::Fail => {
                    tracker.inc_rejected();
                    return Poll::Ready(Some(Err(error)));
                }
                // `Wait` is a deliberate tradeoff: the batch has already been materialized, so we
                // keep it in memory while waiting for quota instead of failing immediately. Under
                // contention, real memory usage can therefore exceed `scan_memory_limit` by up to
                // one buffered batch per blocked stream.
                OnExhaustedPolicy::Wait { .. } => {
                    self.enter_waiting(batch, additional);
                    return self.poll_waiting(cx);
                }
            }
        }

        Poll::Ready(Some(Ok(batch)))
    }
}
794
impl Stream for MemoryTrackedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Finish any in-flight quota wait before pulling more data from the inner stream.
        if self.waiting.is_some() {
            return self.poll_waiting(cx);
        }

        match Pin::new(&mut self.inner).poll_next(cx) {
            // A fresh batch must be accounted for before being released downstream.
            Poll::Ready(Some(Ok(batch))) => self.poll_batch(batch, cx),
            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Tracking is one-for-one with the inner batches, so its hint is accurate.
        self.inner.size_hint()
    }
}
815
impl RecordBatchStream for MemoryTrackedStream {
    // Memory tracking is transparent: schema, ordering, and metrics all
    // delegate to the wrapped stream.
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}
829
830#[cfg(test)]
831mod tests {
832    use std::sync::Arc;
833    use std::sync::atomic::{AtomicUsize, Ordering};
834    use std::time::Duration;
835
836    use common_memory_manager::{OnExhaustedPolicy, PermitGranularity};
837    use datatypes::prelude::{ConcreteDataType, VectorRef};
838    use datatypes::schema::{ColumnSchema, Schema};
839    use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};
840    use futures::StreamExt;
841    use tokio::time::{sleep, timeout};
842
843    use super::*;
844
    /// Builds a one-row, one-column string batch whose payload is `bytes` bytes long.
    fn large_string_batch(bytes: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "payload",
            ConcreteDataType::string_datatype(),
            false,
        )]));
        let payload = "x".repeat(bytes);
        let vector: VectorRef = Arc::new(StringVector::from(vec![payload]));
        RecordBatch::new(schema, vec![vector]).unwrap()
    }
855
    /// Converts `bytes` to tracker permits and back, yielding the size the
    /// tracker actually accounts for at kilobyte permit granularity.
    fn aligned_tracked_bytes(bytes: usize) -> usize {
        PermitGranularity::Kilobyte
            .permits_to_bytes(PermitGranularity::Kilobyte.bytes_to_permits(bytes as u64))
            as usize
    }
861
    #[test]
    fn test_recordbatches_try_from_columns() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        // A string column cannot satisfy an int32 schema: construction must fail.
        let result = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        );
        assert!(result.is_err());

        // Matching column type: the single batch round-trips through `take()`.
        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
        assert_eq!(r.take(), expected);
    }
880
    #[test]
    fn test_recordbatches_try_new() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);

        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();

        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();

        // Mixing batches of different schemas must be rejected with a descriptive error.
        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!(
                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
            )
        );

        // A homogeneous set succeeds and pretty-prints as an ASCII table.
        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
        let expected = "\
+---+-------+
| a | b     |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
        assert_eq!(batches.pretty_print().unwrap(), expected);

        assert_eq!(schema1, batches.schema());
        assert_eq!(vec![batch1], batches.take());
    }
919
920    #[tokio::test]
921    async fn test_simple_recordbatch_stream() {
922        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
923        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
924        let schema = Arc::new(Schema::new(vec![column_a, column_b]));
925
926        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
927        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
928        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();
929
930        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
931        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
932        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();
933
934        let recordbatches =
935            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
936        let stream = recordbatches.as_stream();
937        let collected = util::collect(stream).await.unwrap();
938        assert_eq!(collected.len(), 2);
939        assert_eq!(collected[0], batch1);
940        assert_eq!(collected[1], batch2);
941    }
942
943    const MB: usize = 1024 * 1024;
944
945    #[test]
946    fn test_query_memory_tracker_basic() {
947        let tracker =
948            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
949
950        let mut stream1 = tracker.new_stream_tracker();
951        assert!(stream1.try_track(5 * MB).is_ok());
952        assert_eq!(tracker.current(), 5 * MB);
953
954        let mut stream2 = tracker.new_stream_tracker();
955        assert!(stream2.try_track(4 * MB).is_ok());
956        assert_eq!(tracker.current(), 9 * MB);
957
958        drop(stream1);
959        drop(stream2);
960        assert_eq!(tracker.current(), 0);
961    }
962
963    #[test]
964    fn test_query_memory_tracker_shared_global_limit() {
965        let tracker =
966            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
967        let mut stream1 = tracker.new_stream_tracker();
968        let mut stream2 = tracker.new_stream_tracker();
969
970        assert!(stream1.try_track(3 * MB).is_ok());
971        assert_eq!(tracker.current(), 3 * MB);
972        assert!(stream2.try_track(6 * MB).is_ok());
973        assert_eq!(tracker.current(), 9 * MB);
974
975        let err = stream2.try_track(2 * MB).unwrap_err();
976        let err_msg = err.to_string();
977        assert!(err_msg.contains("6.0MiB used by this stream"));
978        assert!(err_msg.contains("9.0MiB used globally (90%)"));
979        assert!(err_msg.contains("hard limit: 10.0MiB"));
980        assert_eq!(tracker.current(), 9 * MB);
981
982        drop(stream1);
983        assert_eq!(tracker.current(), 6 * MB);
984        drop(stream2);
985        assert_eq!(tracker.current(), 0);
986    }
987
988    #[test]
989    fn test_query_memory_tracker_hard_limit() {
990        let tracker =
991            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
992        let mut stream = tracker.new_stream_tracker();
993
994        assert!(stream.try_track(9 * MB).is_ok());
995        assert_eq!(tracker.current(), 9 * MB);
996
997        assert!(stream.try_track(2 * MB).is_err());
998        assert_eq!(tracker.current(), 9 * MB);
999
1000        assert!(stream.try_track(MB).is_ok());
1001        assert_eq!(tracker.current(), 10 * MB);
1002
1003        assert!(stream.try_track(MB).is_err());
1004        assert_eq!(tracker.current(), 10 * MB);
1005
1006        drop(stream);
1007        assert_eq!(tracker.current(), 0);
1008    }
1009
1010    #[test]
1011    fn test_query_memory_tracker_unlimited() {
1012        let tracker = Arc::new(QueryMemoryTracker::builder(0, OnExhaustedPolicy::Fail).build());
1013        let mut stream = tracker.new_stream_tracker();
1014
1015        assert!(stream.try_track(10 * MB).is_ok());
1016        assert_eq!(tracker.current(), 10 * MB);
1017        drop(stream);
1018        assert_eq!(tracker.current(), 0);
1019    }
1020
1021    #[test]
1022    fn test_query_memory_tracker_rounds_to_kilobytes() {
1023        let tracker =
1024            Arc::new(QueryMemoryTracker::builder(10 * MB, OnExhaustedPolicy::Fail).build());
1025        let mut stream = tracker.new_stream_tracker();
1026
1027        assert!(stream.try_track(1_537).is_ok());
1028        assert_eq!(tracker.current(), 2 * 1024);
1029
1030        drop(stream);
1031        assert_eq!(tracker.current(), 0);
1032    }
1033
    #[tokio::test]
    async fn test_memory_tracked_stream_waits_for_capacity() {
        // Counters observing the tracker's callback hooks.
        let exhausted = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let exhausted_counter = exhausted.clone();
        let rejected_counter = rejected.clone();
        // 1 MiB limit with a Wait policy: an exhausted stream blocks up to
        // 200 ms for capacity instead of failing immediately.
        let tracker = QueryMemoryTracker::builder(
            MB,
            OnExhaustedPolicy::Wait {
                timeout: Duration::from_millis(200),
            },
        )
        .on_exhausted(move || {
            exhausted_counter.fetch_add(1, Ordering::Relaxed);
        })
        .on_reject(move || {
            rejected_counter.fetch_add(1, Ordering::Relaxed);
        })
        .build();
        // ~700 KiB batch: one fits under the limit, two do not.
        let batch = large_string_batch(700 * 1024);
        let expected_bytes = aligned_tracked_bytes(batch.buffer_memory_size());

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        // First stream acquires its batch's (KB-aligned) memory successfully.
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);
        assert_eq!(tracker.current(), expected_bytes);

        let stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        // Poll the second stream on a separate task; it should block waiting
        // for capacity while stream1 still holds most of the budget.
        let waiter = tokio::spawn(async move {
            let mut stream2 = stream2;
            stream2.next().await.unwrap()
        });

        // Give the waiter time to start polling; it must not have completed.
        sleep(Duration::from_millis(50)).await;
        assert!(!waiter.is_finished());

        // Releasing stream1's memory unblocks the waiter within its timeout.
        drop(stream1);
        let second = waiter.await.unwrap().unwrap();
        assert_eq!(second.num_rows(), 1);
        // Exactly one exhaustion event, and no rejection (the wait succeeded).
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
        assert_eq!(rejected.load(Ordering::Relaxed), 0);
    }
1086
    #[tokio::test]
    async fn test_memory_tracked_stream_wait_times_out() {
        // Counters observing the tracker's callback hooks.
        let exhausted = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let exhausted_counter = exhausted.clone();
        let rejected_counter = rejected.clone();
        // Wait policy with a short 50 ms timeout so the wait expires while
        // the first stream is still holding its memory.
        let tracker = QueryMemoryTracker::builder(
            MB,
            OnExhaustedPolicy::Wait {
                timeout: Duration::from_millis(50),
            },
        )
        .on_exhausted(move || {
            exhausted_counter.fetch_add(1, Ordering::Relaxed);
        })
        .on_reject(move || {
            rejected_counter.fetch_add(1, Ordering::Relaxed);
        })
        .build();
        // ~700 KiB batch: one fits under the 1 MiB limit, two do not.
        let batch = large_string_batch(700 * 1024);

        let mut stream1 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
                .unwrap()
                .as_stream(),
            tracker.clone(),
        );
        let first = stream1.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);

        // stream1 stays alive, so the second stream's wait can never succeed.
        let mut stream2 = MemoryTrackedStream::new(
            RecordBatches::try_new(batch.schema.clone(), vec![batch])
                .unwrap()
                .as_stream(),
            tracker,
        );
        // Outer 1 s timeout only guards the test against hanging; the
        // tracker's own 50 ms wait should expire well before it.
        let result = timeout(Duration::from_secs(1), stream2.next())
            .await
            .unwrap();
        let error = result.unwrap().unwrap_err();
        assert!(error.to_string().contains("timed out waiting"));
        // The expired wait counts as one exhaustion and one rejection.
        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
        assert_eq!(rejected.load(Ordering::Relaxed), 1);
    }
1131
1132    #[tokio::test]
1133    async fn test_memory_tracked_stream_fail_policy_rejects_immediately() {
1134        let exhausted = Arc::new(AtomicUsize::new(0));
1135        let rejected = Arc::new(AtomicUsize::new(0));
1136        let exhausted_counter = exhausted.clone();
1137        let rejected_counter = rejected.clone();
1138        let tracker = QueryMemoryTracker::builder(MB, OnExhaustedPolicy::Fail)
1139            .on_exhausted(move || {
1140                exhausted_counter.fetch_add(1, Ordering::Relaxed);
1141            })
1142            .on_reject(move || {
1143                rejected_counter.fetch_add(1, Ordering::Relaxed);
1144            })
1145            .build();
1146        let batch = large_string_batch(700 * 1024);
1147
1148        let mut stream1 = MemoryTrackedStream::new(
1149            RecordBatches::try_new(batch.schema.clone(), vec![batch.clone()])
1150                .unwrap()
1151                .as_stream(),
1152            tracker.clone(),
1153        );
1154        let first = stream1.next().await.unwrap().unwrap();
1155        assert_eq!(first.num_rows(), 1);
1156
1157        let mut stream2 = MemoryTrackedStream::new(
1158            RecordBatches::try_new(batch.schema.clone(), vec![batch])
1159                .unwrap()
1160                .as_stream(),
1161            tracker,
1162        );
1163        let result = stream2.next().await.unwrap();
1164        assert!(result.is_err());
1165        assert_eq!(exhausted.load(Ordering::Relaxed), 1);
1166        assert_eq!(rejected.load(Ordering::Relaxed), 1);
1167    }
1168}