common_recordbatch/lib.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![feature(never_type)]

pub mod adapter;
pub mod cursor;
pub mod error;
pub mod filter;
pub mod recordbatch;
pub mod util;

use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

use adapter::RecordBatchMetrics;
use arc_swap::ArcSwapOption;
use common_base::readable_size::ReadableSize;
use common_telemetry::tracing::Span;
pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
use datatypes::arrow::array::{ArrayRef, AsArray, StringBuilder};
use datatypes::arrow::compute::SortOptions;
pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow::util::pretty;
use datatypes::prelude::{ConcreteDataType, VectorRef};
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::types::{JsonFormat, jsonb_to_string};
use error::Result;
use futures::task::{Context, Poll};
use futures::{Stream, TryStreamExt};
pub use recordbatch::RecordBatch;
use snafu::{ResultExt, ensure};

use crate::error::NewDfRecordBatchSnafu;
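
/// A stream of [RecordBatch] results that also exposes its [SchemaRef], an
/// optional output ordering, and runtime metrics; the crate-side analogue of
/// DataFusion's `RecordBatchStream` trait.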
pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
    fn name(&self) -> &str {
        "RecordBatchStream"
    }

    fn schema(&self) -> SchemaRef;

    fn output_ordering(&self) -> Option<&[OrderOption]>;

    fn metrics(&self) -> Option<RecordBatchMetrics>;
}

pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderOption {
    pub name: String,
    pub options: SortOptions,
}

/// A wrapper that transforms a [RecordBatchStream] into a new [RecordBatchStream] by applying a mapper function to each [RecordBatch].
///
/// The schema of the new stream is the inner stream's schema after applying the schema mapper function,
/// while the output ordering and metrics are delegated to the inner stream unchanged.
pub struct SendableRecordBatchMapper {
    inner: SendableRecordBatchStream,
    /// Applied to each [RecordBatch]; receives the batch together with the original schema and the mapped schema.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// The mapped schema, i.e. the inner stream's schema after applying the schema mapper function.
    schema: SchemaRef,
    /// Whether the mapper needs to be applied; when `false`, batches are passed through untouched.
    apply_mapper: bool,
}

/// Converts all JSON columns in the batch to string columns.
///
/// Only columns stored in the [JsonFormat::Jsonb] encoding are decoded into their textual JSON
/// representation (e.g. the JSONB bytes for `{"a":1}` become the string `{"a":1}`); all other
/// columns are passed through unchanged, so the mapped batch keeps the original column count and order.
pub fn map_json_type_to_string(
    batch: RecordBatch,
    original_schema: &SchemaRef,
    mapped_schema: &SchemaRef,
) -> Result<RecordBatch> {
    let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
    for (vector, schema) in batch.columns().iter().zip(original_schema.column_schemas()) {
        if let ConcreteDataType::Json(j) = &schema.data_type {
            if matches!(&j.format, JsonFormat::Jsonb) {
                let mut string_vector_builder = StringBuilder::new();
                let binary_vector = vector.as_binary::<i32>();
                for value in binary_vector.iter() {
                    let Some(value) = value else {
                        string_vector_builder.append_null();
                        continue;
                    };
                    let string_value =
                        jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
                            from_type: schema.data_type.clone(),
                            to_type: ConcreteDataType::string_datatype(),
                        })?;
                    string_vector_builder.append_value(string_value);
                }

                let string_vector = string_vector_builder.finish();
                vectors.push(Arc::new(string_vector) as ArrayRef);
            } else {
                vectors.push(vector.clone());
            }
        } else {
            vectors.push(vector.clone());
        }
    }

    let record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
        mapped_schema.arrow_schema().clone(),
        vectors,
    )
    .context(NewDfRecordBatchSnafu)?;
    Ok(RecordBatch::from_df_record_batch(
        mapped_schema.clone(),
        record_batch,
    ))
}

/// Maps JSON columns to string columns in the schema.
///
/// Every column of JSON type is replaced by a string column with the same name and nullability;
/// all other columns are kept as-is, so the mapped schema has the same column count and order.
///
/// Returns the new schema and whether any column was mapped, i.e. whether the batch mapper needs to run.
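///
/// A minimal sketch of the schema side (illustrative only; `json_datatype()` is
/// assumed to construct the default JSONB-encoded JSON type):
///
/// ```ignore
/// let schema = Arc::new(Schema::new(vec![
///     ColumnSchema::new("payload", ConcreteDataType::json_datatype(), true),
/// ]));
/// let (mapped, apply_mapper) = map_json_type_to_string_schema(schema);
/// assert!(apply_mapper);
/// assert_eq!(
///     mapped.column_schemas()[0].data_type,
///     ConcreteDataType::string_datatype()
/// );
/// ```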
pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
    let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
    let mut apply_mapper = false;
    for column in schema.column_schemas() {
        if matches!(column.data_type, ConcreteDataType::Json(_)) {
            new_columns.push(ColumnSchema::new(
                column.name.clone(),
                ConcreteDataType::string_datatype(),
                column.is_nullable(),
            ));
            apply_mapper = true;
        } else {
            new_columns.push(column.clone());
        }
    }
    (Arc::new(Schema::new(new_columns)), apply_mapper)
}

impl SendableRecordBatchMapper {
    /// Creates a new [SendableRecordBatchMapper] with the given inner [RecordBatchStream], mapper function, and schema mapper function.
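    ///
    /// A minimal sketch wiring up the JSON-to-string mappers above (`inner` is
    /// assumed to be an existing [SendableRecordBatchStream]):
    ///
    /// ```ignore
    /// let mapped: SendableRecordBatchStream = Box::pin(SendableRecordBatchMapper::new(
    ///     inner,
    ///     map_json_type_to_string,
    ///     map_json_type_to_string_schema,
    /// ));
    /// ```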
    pub fn new(
        inner: SendableRecordBatchStream,
        mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
        schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
    ) -> Self {
        let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
        Self {
            inner,
            mapper,
            schema: mapped_schema,
            apply_mapper,
        }
    }
}

impl RecordBatchStream for SendableRecordBatchMapper {
    fn name(&self) -> &str {
        "SendableRecordBatchMapper"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}

impl Stream for SendableRecordBatchMapper {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.apply_mapper {
            Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
                opt.map(|result| {
                    result
                        .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
                })
            })
        } else {
            Pin::new(&mut self.inner).poll_next(cx)
        }
    }
}

/// EmptyRecordBatchStream can be used to create a [RecordBatchStream]
/// that produces no results.
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream
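    ///
    /// A one-line usage sketch (`schema` is assumed to exist):
    ///
    /// ```ignore
    /// let empty: SendableRecordBatchStream =
    ///     Box::pin(EmptyRecordBatchStream::new(schema));
    /// ```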
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}

#[derive(Debug, PartialEq)]
pub struct RecordBatches {
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
}

impl RecordBatches {
    pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
        schema: SchemaRef,
        columns: I,
    ) -> Result<Self> {
        let batches = vec![RecordBatch::new(schema.clone(), columns)?];
        Ok(Self { schema, batches })
    }

    pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
        let schema = stream.schema();
        let batches = stream.try_collect::<Vec<_>>().await?;
        Ok(Self { schema, batches })
    }

    #[inline]
    pub fn empty() -> Self {
        Self {
            schema: Arc::new(Schema::new(vec![])),
            batches: vec![],
        }
    }

    pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
        self.batches.iter()
    }

    pub fn pretty_print(&self) -> Result<String> {
        let df_batches = &self
            .iter()
            .map(|x| x.df_record_batch().clone())
            .collect::<Vec<_>>();
        let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;

        Ok(result.to_string())
    }

    pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
        for batch in &batches {
            ensure!(
                batch.schema == schema,
                error::CreateRecordBatchesSnafu {
                    reason: format!(
                        "expect RecordBatch schema equals {:?}, actual: {:?}",
                        schema, batch.schema
                    )
                }
            )
        }
        Ok(Self { schema, batches })
    }

    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    pub fn take(self) -> Vec<RecordBatch> {
        self.batches
    }
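
    /// Returns a [SendableRecordBatchStream] over clones of these batches,
    /// backed by a [SimpleRecordBatchStream].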
    pub fn as_stream(&self) -> SendableRecordBatchStream {
        Box::pin(SimpleRecordBatchStream {
            inner: RecordBatches {
                schema: self.schema(),
                batches: self.batches.clone(),
            },
            index: 0,
        })
    }
}

impl IntoIterator for RecordBatches {
    type Item = RecordBatch;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.batches.into_iter()
    }
}

pub struct SimpleRecordBatchStream {
    inner: RecordBatches,
    index: usize,
}

impl RecordBatchStream for SimpleRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for SimpleRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(if self.index < self.inner.batches.len() {
            let batch = self.inner.batches[self.index].clone();
            self.index += 1;
            Some(Ok(batch))
        } else {
            None
        })
    }
}

/// Adapts a [Stream] of [RecordBatch]es to a [RecordBatchStream].
pub struct RecordBatchStreamWrapper<S> {
    pub schema: SchemaRef,
    pub stream: S,
    pub output_ordering: Option<Vec<OrderOption>>,
    pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
    pub span: Span,
}

impl<S> RecordBatchStreamWrapper<S> {
    /// Creates a [RecordBatchStreamWrapper] without output ordering requirement.
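    ///
    /// A minimal sketch turning an in-memory `Vec<RecordBatch>` into a
    /// [SendableRecordBatchStream] (`schema` and `batches` are assumed to exist):
    ///
    /// ```ignore
    /// let stream = futures::stream::iter(batches.into_iter().map(Ok));
    /// let wrapped: SendableRecordBatchStream =
    ///     Box::pin(RecordBatchStreamWrapper::new(schema, stream));
    /// ```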
    pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
        RecordBatchStreamWrapper {
            schema,
            stream,
            output_ordering: None,
            metrics: Default::default(),
            span: Span::current(),
        }
    }
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
    for RecordBatchStreamWrapper<S>
{
    fn name(&self) -> &str {
        "RecordBatchStreamWrapper"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.output_ordering.as_deref()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.metrics.load().as_ref().map(|s| s.as_ref().clone())
    }
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let _entered = self.span.clone().entered();
        Pin::new(&mut self.stream).poll_next(ctx)
    }
}

/// Memory permit for a stream, providing privileged access or rate limiting.
///
/// The permit tracks whether this stream has privileged Top-K status.
/// When dropped, it automatically releases any privileged slot it holds.
pub struct MemoryPermit {
    tracker: QueryMemoryTracker,
    is_privileged: AtomicBool,
}

impl MemoryPermit {
    /// Check if this permit currently has privileged status.
    pub fn is_privileged(&self) -> bool {
        self.is_privileged.load(Ordering::Acquire)
    }

    /// Ensure this permit has privileged status by acquiring a slot if available.
    /// Returns true if privileged (either already privileged or just acquired privilege).
    fn ensure_privileged(&self) -> bool {
        if self.is_privileged.load(Ordering::Acquire) {
            return true;
        }

        // Try to claim a privileged slot
        self.tracker
            .privileged_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
                if count < self.tracker.privileged_slots {
                    Some(count + 1)
                } else {
                    None
                }
            })
            .map(|_| {
                self.is_privileged.store(true, Ordering::Release);
                true
            })
            .unwrap_or(false)
    }

    /// Track additional memory usage with this permit.
    /// Returns an error if the limit would be exceeded.
    ///
    /// # Arguments
    /// * `additional` - Additional memory size to track in bytes
    /// * `stream_tracked` - Total memory already tracked by this stream
    ///
    /// # Behavior
    /// - Privileged streams: can push global memory usage up to the full limit
    /// - Standard-tier streams: can push global memory usage up to limit * standard_tier_memory_fraction (default: 0.7)
    /// - Standard-tier streams automatically attempt to acquire privilege if slots become available
    /// - The configured limit is an absolute hard limit - no stream can exceed it
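    ///
    /// As done by [MemoryTrackedStream] for each polled batch:
    ///
    /// ```ignore
    /// permit.track(batch.buffer_memory_size(), total_tracked)?;
    /// ```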
    pub fn track(&self, additional: usize, stream_tracked: usize) -> Result<()> {
        // Ensure privileged status if possible
        let is_privileged = self.ensure_privileged();

        self.tracker
            .track_internal(additional, is_privileged, stream_tracked)
    }

    /// Release tracked memory.
    ///
    /// # Arguments
    /// * `amount` - Amount of memory to release in bytes
    pub fn release(&self, amount: usize) {
        self.tracker.release(amount);
    }
}

impl Drop for MemoryPermit {
    fn drop(&mut self) {
        // Release privileged slot if we had one
        if self.is_privileged.load(Ordering::Acquire) {
            self.tracker
                .privileged_count
                .fetch_sub(1, Ordering::Release);
        }
    }
}

/// Memory tracker for RecordBatch streams. Clone to share the same limit across queries.
///
/// Implements a two-tier memory allocation strategy:
/// - **Privileged tier**: First N streams (default: 20) can use up to the full memory limit
/// - **Standard tier**: Remaining streams are restricted to a fraction of the limit (default: 70%)
/// - Privilege is granted on a first-come-first-served basis
/// - The configured limit is an absolute hard cap - no stream can exceed it
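///
/// For example, with `limit = 1000` bytes and the defaults, a privileged stream may
/// push global usage up to 1000 bytes, while a standard-tier stream is rejected as
/// soon as the global total would exceed 700 bytes.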
#[derive(Clone)]
pub struct QueryMemoryTracker {
    current: Arc<AtomicUsize>,
    limit: usize,
    standard_tier_memory_fraction: f64,
    privileged_count: Arc<AtomicUsize>,
    privileged_slots: usize,
    on_update: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    on_reject: Option<Arc<dyn Fn() + Send + Sync>>,
}

impl fmt::Debug for QueryMemoryTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("QueryMemoryTracker")
            .field("current", &self.current.load(Ordering::Acquire))
            .field("limit", &self.limit)
            .field(
                "standard_tier_memory_fraction",
                &self.standard_tier_memory_fraction,
            )
            .field(
                "privileged_count",
                &self.privileged_count.load(Ordering::Acquire),
            )
            .field("privileged_slots", &self.privileged_slots)
            .field("on_update", &self.on_update.is_some())
            .field("on_reject", &self.on_reject.is_some())
            .finish()
    }
}

impl QueryMemoryTracker {
    // Default privileged slots when max_concurrent_queries is 0.
    const DEFAULT_PRIVILEGED_SLOTS: usize = 20;
    // Shared 0.7 ratio: 70% of max_concurrent_queries get privileged slots, and
    // standard-tier streams may use up to 70% of the memory limit.
    const DEFAULT_PRIVILEGED_TIER_RATIO: f64 = 0.7;

    /// Create a new memory tracker with the given limit and max_concurrent_queries.
    /// Calculates privileged slots as 70% of max_concurrent_queries (or 20 if max_concurrent_queries is 0).
    ///
    /// # Arguments
    /// * `limit` - Maximum memory usage in bytes (hard limit for all streams). 0 means unlimited.
    /// * `max_concurrent_queries` - Maximum number of concurrent queries (0 = unlimited).
    pub fn new(limit: usize, max_concurrent_queries: usize) -> Self {
        let privileged_slots = Self::calculate_privileged_slots(max_concurrent_queries);
        Self::with_privileged_slots(limit, privileged_slots)
    }

    /// Create a new memory tracker with custom privileged slots limit.
    pub fn with_privileged_slots(limit: usize, privileged_slots: usize) -> Self {
        Self::with_config(limit, privileged_slots, Self::DEFAULT_PRIVILEGED_TIER_RATIO)
    }

    /// Create a new memory tracker with full configuration.
    ///
    /// # Arguments
    /// * `limit` - Maximum memory usage in bytes (hard limit for all streams). 0 means unlimited.
    /// * `privileged_slots` - Maximum number of streams that can get privileged status.
    /// * `standard_tier_memory_fraction` - Memory fraction for standard-tier streams (range: [0.0, 1.0]).
    ///
    /// # Panics
    /// Panics if `standard_tier_memory_fraction` is not in the range [0.0, 1.0].
    pub fn with_config(
        limit: usize,
        privileged_slots: usize,
        standard_tier_memory_fraction: f64,
    ) -> Self {
        assert!(
            (0.0..=1.0).contains(&standard_tier_memory_fraction),
            "standard_tier_memory_fraction must be in [0.0, 1.0], got {}",
            standard_tier_memory_fraction
        );

        Self {
            current: Arc::new(AtomicUsize::new(0)),
            limit,
            standard_tier_memory_fraction,
            privileged_count: Arc::new(AtomicUsize::new(0)),
            privileged_slots,
            on_update: None,
            on_reject: None,
        }
    }

    /// Register a new permit for memory tracking.
    /// The first `privileged_slots` permits get privileged status automatically.
    /// The returned permit can be shared across multiple streams of the same query.
    pub fn register_permit(&self) -> MemoryPermit {
        // Try to claim a privileged slot
        let is_privileged = self
            .privileged_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
                if count < self.privileged_slots {
                    Some(count + 1)
                } else {
                    None
                }
            })
            .is_ok();

        MemoryPermit {
            tracker: self.clone(),
            is_privileged: AtomicBool::new(is_privileged),
        }
    }

    /// Set a callback to be called whenever the usage changes successfully.
    /// The callback receives the new total usage in bytes.
    ///
    /// # Note
    /// The callback is called after both successful `track()` and `release()` operations.
    /// It is called even when `limit == 0` (unlimited mode) to track actual usage.
    pub fn with_on_update<F>(mut self, on_update: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.on_update = Some(Arc::new(on_update));
        self
    }

    /// Set a callback to be called when memory allocation is rejected.
    ///
    /// # Note
    /// This is only called when `track()` fails due to exceeding the limit.
    /// It is never called when `limit == 0` (unlimited mode).
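    ///
    /// A hypothetical sketch exporting usage to metrics (`MEMORY_GAUGE` and
    /// `REJECT_COUNTER` are placeholders, not part of this crate):
    ///
    /// ```ignore
    /// let tracker = QueryMemoryTracker::new(512 * 1024 * 1024, 0)
    ///     .with_on_update(|bytes| MEMORY_GAUGE.set(bytes as i64))
    ///     .with_on_reject(|| REJECT_COUNTER.inc());
    /// ```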
    pub fn with_on_reject<F>(mut self, on_reject: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_reject = Some(Arc::new(on_reject));
        self
    }

    /// Get the current memory usage in bytes.
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Acquire)
    }
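
    /// E.g. `max_concurrent_queries = 10` yields `(10.0 * 0.7) as usize = 7`
    /// privileged slots, while `0` falls back to `DEFAULT_PRIVILEGED_SLOTS` (20).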
    fn calculate_privileged_slots(max_concurrent_queries: usize) -> usize {
        if max_concurrent_queries == 0 {
            Self::DEFAULT_PRIVILEGED_SLOTS
        } else {
            ((max_concurrent_queries as f64 * Self::DEFAULT_PRIVILEGED_TIER_RATIO) as usize).max(1)
        }
    }

    /// Internal method to track additional memory usage.
    ///
    /// Called by `MemoryPermit::track()`. Use `MemoryPermit::track()` instead of calling this directly.
    fn track_internal(
        &self,
        additional: usize,
        is_privileged: bool,
        stream_tracked: usize,
    ) -> Result<()> {
        // Calculate the effective global limit based on stream privilege:
        // privileged streams can push global usage up to the full limit,
        // standard-tier streams only up to a fraction of it.
        let effective_limit = if is_privileged {
            self.limit
        } else {
            (self.limit as f64 * self.standard_tier_memory_fraction) as usize
        };

        let mut new_total = 0;
        let result = self
            .current
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                new_total = current.saturating_add(additional);

                if self.limit == 0 {
                    // Unlimited mode
                    return Some(new_total);
                }

                // Check if the new global total exceeds the effective limit.
                // The configured limit is an absolute hard limit - no stream can exceed it.
                if new_total <= effective_limit {
                    Some(new_total)
                } else {
                    None
                }
            });

        match result {
            Ok(_) => {
                if let Some(callback) = &self.on_update {
                    callback(new_total);
                }
                Ok(())
            }
            Err(current) => {
                if let Some(callback) = &self.on_reject {
                    callback();
                }
                let msg = format!(
                    "{} requested, {} used globally ({}%), {} used by this stream (privileged: {}), effective limit: {} ({}%), hard limit: {}",
                    ReadableSize(additional as u64),
                    ReadableSize(current as u64),
                    if self.limit > 0 {
                        current * 100 / self.limit
                    } else {
                        0
                    },
                    ReadableSize(stream_tracked as u64),
                    is_privileged,
                    ReadableSize(effective_limit as u64),
                    if self.limit > 0 {
                        effective_limit * 100 / self.limit
                    } else {
                        0
                    },
                    ReadableSize(self.limit as u64)
                );
                error::ExceedMemoryLimitSnafu { msg }.fail()
            }
        }
    }

    /// Release tracked memory.
    ///
    /// # Arguments
    /// * `amount` - Amount of memory to release in bytes
    pub fn release(&self, amount: usize) {
        if let Ok(old_value) =
            self.current
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some(current.saturating_sub(amount))
                })
            && let Some(callback) = &self.on_update
        {
            callback(old_value.saturating_sub(amount));
        }
    }
}

/// A wrapper stream that tracks memory usage of RecordBatches.
pub struct MemoryTrackedStream {
    inner: SendableRecordBatchStream,
    permit: Arc<MemoryPermit>,
    // Total tracked size, released when stream drops.
    total_tracked: usize,
}

impl MemoryTrackedStream {
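    /// Wraps `inner` so that each polled batch's [RecordBatch::buffer_memory_size] is
    /// tracked against `permit`; everything tracked is released when the stream drops.
    ///
    /// A minimal sketch, assuming `inner` and a shared `tracker` already exist:
    ///
    /// ```ignore
    /// let permit = Arc::new(tracker.register_permit());
    /// let tracked: SendableRecordBatchStream =
    ///     Box::pin(MemoryTrackedStream::new(inner, permit));
    /// ```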
    pub fn new(inner: SendableRecordBatchStream, permit: Arc<MemoryPermit>) -> Self {
        Self {
            inner,
            permit,
            total_tracked: 0,
        }
    }
}

impl Stream for MemoryTrackedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner).poll_next(cx) {
            Poll::Ready(Some(Ok(batch))) => {
                let additional = batch.buffer_memory_size();

                if let Err(e) = self.permit.track(additional, self.total_tracked) {
                    return Poll::Ready(Some(Err(e)));
                }

                self.total_tracked += additional;

                Poll::Ready(Some(Ok(batch)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl Drop for MemoryTrackedStream {
    fn drop(&mut self) {
        if self.total_tracked > 0 {
            self.permit.release(self.total_tracked);
        }
    }
}

impl RecordBatchStream for MemoryTrackedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};

    use super::*;

    #[test]
    fn test_recordbatches_try_from_columns() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let result = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        );
        assert!(result.is_err());

        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
        assert_eq!(r.take(), expected);
    }

    #[test]
    fn test_recordbatches_try_new() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);

        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();

        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();

        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!(
                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
            )
        );

        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
        let expected = "\
+---+-------+
| a | b     |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
        assert_eq!(batches.pretty_print().unwrap(), expected);

        assert_eq!(schema1, batches.schema());
        assert_eq!(vec![batch1], batches.take());
    }

    #[tokio::test]
    async fn test_simple_recordbatch_stream() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column_a, column_b]));

        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();

        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();

        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
        let stream = recordbatches.as_stream();
        let collected = util::collect(stream).await.unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0], batch1);
        assert_eq!(collected[1], batch2);
    }

    #[test]
    fn test_query_memory_tracker_basic() {
        let tracker = Arc::new(QueryMemoryTracker::new(1000, 0));

        // Register first stream - should get privileged status
        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Privileged stream can use up to limit
        assert!(permit1.track(500, 0).is_ok());
        assert_eq!(tracker.current(), 500);

        // Register second stream - also privileged
        let permit2 = tracker.register_permit();
        assert!(permit2.is_privileged());
        // Can add more but cannot exceed hard limit (1000)
        assert!(permit2.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 900);

        permit1.release(500);
        permit2.release(400);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_privileged_limit() {
        // Privileged slots = 2 for easy testing
        // Limit: 1000, standard-tier fraction: 0.7 (default)
        // Privileged can push global to 1000, standard-tier can push global to 700
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 2));

        // First 2 streams are privileged
        let permit1 = tracker.register_permit();
        let permit2 = tracker.register_permit();
        assert!(permit1.is_privileged());
        assert!(permit2.is_privileged());

        // Third stream is standard-tier (not privileged)
        let permit3 = tracker.register_permit();
        assert!(!permit3.is_privileged());

        // Privileged stream uses some memory
        assert!(permit1.track(300, 0).is_ok());
        assert_eq!(tracker.current(), 300);

        // Standard-tier can add up to 400 (total becomes 700, its effective limit)
        assert!(permit3.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 700);

        // Standard-tier stream cannot push global beyond 700
        let err = permit3.track(100, 400).unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("400B used by this stream"));
        assert!(err_msg.contains("effective limit: 700B (70%)"));
        assert!(err_msg.contains("700B used globally (70%)"));
        assert_eq!(tracker.current(), 700);

        permit1.release(300);
        permit3.release(400);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_promotion() {
        // Privileged slots = 1 for easy testing
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 1));

        // First stream is privileged
        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Second stream is standard-tier (can only push global usage to 700)
        let permit2 = tracker.register_permit();
        assert!(!permit2.is_privileged());

        // Standard-tier can track within its 700-byte effective limit
        assert!(permit2.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 400);

        // Drop first permit to release privileged slot
        drop(permit1);

        // Second stream can now be promoted and use more memory
        assert!(permit2.track(500, 400).is_ok());
        assert!(permit2.is_privileged());
        assert_eq!(tracker.current(), 900);

        permit2.release(900);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_privileged_hard_limit() {
        // Test that the configured limit is an absolute hard limit for all streams
        // Privileged: can use full limit (1000)
        // Standard-tier: can use 0.7x limit (700 with defaults)
        let tracker = Arc::new(QueryMemoryTracker::new(1000, 0));

        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Privileged can use up to full limit (1000)
        assert!(permit1.track(900, 0).is_ok());
        assert_eq!(tracker.current(), 900);

        // Privileged cannot exceed hard limit (1000)
        assert!(permit1.track(200, 900).is_err());
        assert_eq!(tracker.current(), 900);

        // Can add within hard limit
        assert!(permit1.track(100, 900).is_ok());
        assert_eq!(tracker.current(), 1000);

        // Cannot exceed even by 1 byte
        assert!(permit1.track(1, 1000).is_err());
        assert_eq!(tracker.current(), 1000);

        permit1.release(1000);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_standard_tier_fraction() {
        // Test standard-tier streams use fraction of limit
        // Limit: 1000, default fraction: 0.7, so standard-tier can use 700
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 1));

        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        let permit2 = tracker.register_permit();
        assert!(!permit2.is_privileged());

        // Standard-tier can use up to 700 (1000 * 0.7 default)
        assert!(permit2.track(600, 0).is_ok());
        assert_eq!(tracker.current(), 600);

        // Cannot exceed standard-tier limit (700)
        assert!(permit2.track(200, 600).is_err());
        assert_eq!(tracker.current(), 600);

        // Can add within standard-tier limit
        assert!(permit2.track(100, 600).is_ok());
        assert_eq!(tracker.current(), 700);

        // Cannot exceed standard-tier limit
        assert!(permit2.track(1, 700).is_err());
        assert_eq!(tracker.current(), 700);

        permit2.release(700);
        assert_eq!(tracker.current(), 0);
    }
}
1051}