common_recordbatch/
lib.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![feature(never_type)]

pub mod adapter;
pub mod cursor;
pub mod error;
pub mod filter;
pub mod recordbatch;
pub mod util;

use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

use adapter::RecordBatchMetrics;
use arc_swap::ArcSwapOption;
use common_base::readable_size::ReadableSize;
pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
use datatypes::arrow::array::{ArrayRef, AsArray, StringBuilder};
use datatypes::arrow::compute::SortOptions;
pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow::util::pretty;
use datatypes::prelude::{ConcreteDataType, VectorRef};
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::types::{JsonFormat, jsonb_to_string};
use error::Result;
use futures::task::{Context, Poll};
use futures::{Stream, TryStreamExt};
pub use recordbatch::RecordBatch;
use snafu::{ResultExt, ensure};

use crate::error::NewDfRecordBatchSnafu;

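/// A [Stream] of [RecordBatch]es that also exposes the stream's [SchemaRef],
/// an optional output ordering, and optional runtime metrics.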
pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
    fn name(&self) -> &str {
        "RecordBatchStream"
    }

    fn schema(&self) -> SchemaRef;

    fn output_ordering(&self) -> Option<&[OrderOption]>;

    fn metrics(&self) -> Option<RecordBatchMetrics>;
}

pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderOption {
    pub name: String,
    pub options: SortOptions,
}

/// A wrapper that maps a [RecordBatchStream] into a new [RecordBatchStream] by applying a
/// function to each [RecordBatch].
///
/// The schema of the new stream is the inner stream's schema after applying the schema
/// mapper function. The output ordering and the metrics of the new stream are forwarded
/// from the inner stream unchanged.
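///
/// A minimal construction sketch using the JSON-to-string helpers in this module
/// (illustrative only, so it is not compiled as a doc-test):
///
/// ```ignore
/// // `inner` is any existing SendableRecordBatchStream.
/// let mapped: SendableRecordBatchStream = Box::pin(SendableRecordBatchMapper::new(
///     inner,
///     map_json_type_to_string,
///     map_json_type_to_string_schema,
/// ));
/// ```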
pub struct SendableRecordBatchMapper {
    inner: SendableRecordBatchStream,
    /// The function applied to each [RecordBatch] in the stream. It receives the batch
    /// together with the original schema and the mapped schema.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// The schema of the inner stream after applying the schema mapper function.
    schema: SchemaRef,
    /// Whether the mapper function needs to be applied; `false` when the schema mapper
    /// left the schema untouched.
    apply_mapper: bool,
}

/// Maps JSON columns to string columns in the given batch.
///
/// Each JSONB value is rendered as a string; all other columns, including JSON columns in
/// non-JSONB formats, are passed through untouched. The returned batch has the same
/// number of columns as the original.
pub fn map_json_type_to_string(
    batch: RecordBatch,
    original_schema: &SchemaRef,
    mapped_schema: &SchemaRef,
) -> Result<RecordBatch> {
    let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
    for (vector, schema) in batch.columns().iter().zip(original_schema.column_schemas()) {
        if let ConcreteDataType::Json(j) = &schema.data_type {
            if matches!(&j.format, JsonFormat::Jsonb) {
                let mut string_vector_builder = StringBuilder::new();
                let binary_vector = vector.as_binary::<i32>();
                for value in binary_vector.iter() {
                    let Some(value) = value else {
                        string_vector_builder.append_null();
                        continue;
                    };
                    let string_value =
                        jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
                            from_type: schema.data_type.clone(),
                            to_type: ConcreteDataType::string_datatype(),
                        })?;
                    string_vector_builder.append_value(string_value);
                }

                let string_vector = string_vector_builder.finish();
                vectors.push(Arc::new(string_vector) as ArrayRef);
            } else {
                vectors.push(vector.clone());
            }
        } else {
            vectors.push(vector.clone());
        }
    }

    let record_batch = datatypes::arrow::record_batch::RecordBatch::try_new(
        mapped_schema.arrow_schema().clone(),
        vectors,
    )
    .context(NewDfRecordBatchSnafu)?;
    Ok(RecordBatch::from_df_record_batch(
        mapped_schema.clone(),
        record_batch,
    ))
}

/// Maps JSON columns to string columns in the given schema.
///
/// The returned schema has the same number of columns as the original, with each JSON
/// column replaced by a string column of the same name and nullability.
///
/// Returns the new schema and whether any column was mapped (i.e. whether the batch
/// mapper function needs to be applied).
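///
/// A usage sketch (illustrative; assumes `schema` contains at least one JSON column):
///
/// ```ignore
/// let (mapped_schema, apply_mapper) = map_json_type_to_string_schema(schema);
/// assert!(apply_mapper); // true because a JSON column was rewritten to string
/// ```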
pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
    let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
    let mut apply_mapper = false;
    for column in schema.column_schemas() {
        if matches!(column.data_type, ConcreteDataType::Json(_)) {
            new_columns.push(ColumnSchema::new(
                column.name.clone(),
                ConcreteDataType::string_datatype(),
                column.is_nullable(),
            ));
            apply_mapper = true;
        } else {
            new_columns.push(column.clone());
        }
    }
    (Arc::new(Schema::new(new_columns)), apply_mapper)
}

impl SendableRecordBatchMapper {
    /// Creates a new [SendableRecordBatchMapper] with the given inner [RecordBatchStream],
    /// mapper function, and schema mapper function.
    pub fn new(
        inner: SendableRecordBatchStream,
        mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
        schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
    ) -> Self {
        let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
        Self {
            inner,
            mapper,
            schema: mapped_schema,
            apply_mapper,
        }
    }
}

impl RecordBatchStream for SendableRecordBatchMapper {
    fn name(&self) -> &str {
        "SendableRecordBatchMapper"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}

impl Stream for SendableRecordBatchMapper {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.apply_mapper {
            Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
                opt.map(|result| {
                    result
                        .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
                })
            })
        } else {
            Pin::new(&mut self.inner).poll_next(cx)
        }
    }
}

/// EmptyRecordBatchStream can be used to create a RecordBatchStream
/// that will produce no results
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}

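/// A set of [RecordBatch]es that all share the same schema.
///
/// A construction sketch, mirroring this module's unit tests (illustrative only):
///
/// ```ignore
/// let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
///     "a",
///     ConcreteDataType::int32_datatype(),
///     false,
/// )]));
/// let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
/// let batches = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
/// println!("{}", batches.pretty_print().unwrap());
/// ```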
#[derive(Debug, PartialEq)]
pub struct RecordBatches {
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
}

impl RecordBatches {
    pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
        schema: SchemaRef,
        columns: I,
    ) -> Result<Self> {
        let batches = vec![RecordBatch::new(schema.clone(), columns)?];
        Ok(Self { schema, batches })
    }

    pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
        let schema = stream.schema();
        let batches = stream.try_collect::<Vec<_>>().await?;
        Ok(Self { schema, batches })
    }

    #[inline]
    pub fn empty() -> Self {
        Self {
            schema: Arc::new(Schema::new(vec![])),
            batches: vec![],
        }
    }

    pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
        self.batches.iter()
    }

    pub fn pretty_print(&self) -> Result<String> {
        let df_batches = &self
            .iter()
            .map(|x| x.df_record_batch().clone())
            .collect::<Vec<_>>();
        let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;

        Ok(result.to_string())
    }

    pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
        for batch in &batches {
            ensure!(
                batch.schema == schema,
                error::CreateRecordBatchesSnafu {
                    reason: format!(
                        "expect RecordBatch schema equals {:?}, actual: {:?}",
                        schema, batch.schema
                    )
                }
            )
        }
        Ok(Self { schema, batches })
    }

    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    pub fn take(self) -> Vec<RecordBatch> {
        self.batches
    }

    pub fn as_stream(&self) -> SendableRecordBatchStream {
        Box::pin(SimpleRecordBatchStream {
            inner: RecordBatches {
                schema: self.schema(),
                batches: self.batches.clone(),
            },
            index: 0,
        })
    }
}

impl IntoIterator for RecordBatches {
    type Item = RecordBatch;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.batches.into_iter()
    }
}

pub struct SimpleRecordBatchStream {
    inner: RecordBatches,
    index: usize,
}

impl RecordBatchStream for SimpleRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for SimpleRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(if self.index < self.inner.batches.len() {
            let batch = self.inner.batches[self.index].clone();
            self.index += 1;
            Some(Ok(batch))
        } else {
            None
        })
    }
}

/// Adapt a [Stream] of [RecordBatch] to a [RecordBatchStream].
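///
/// A wrapping sketch (illustrative; assumes `schema` and a `stream` of
/// `Result<RecordBatch>` items already exist):
///
/// ```ignore
/// let wrapper = RecordBatchStreamWrapper::new(schema, stream);
/// let sendable: SendableRecordBatchStream = Box::pin(wrapper);
/// ```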
pub struct RecordBatchStreamWrapper<S> {
    pub schema: SchemaRef,
    pub stream: S,
    pub output_ordering: Option<Vec<OrderOption>>,
    pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
}

impl<S> RecordBatchStreamWrapper<S> {
    /// Creates a [RecordBatchStreamWrapper] without output ordering requirement.
    pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
        RecordBatchStreamWrapper {
            schema,
            stream,
            output_ordering: None,
            metrics: Default::default(),
        }
    }
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
    for RecordBatchStreamWrapper<S>
{
    fn name(&self) -> &str {
        "RecordBatchStreamWrapper"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.output_ordering.as_deref()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.metrics.load().as_ref().map(|s| s.as_ref().clone())
    }
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(ctx)
    }
}

/// Memory permit for a stream, providing either privileged or rate-limited
/// (standard-tier) access.
///
/// The permit tracks whether its stream holds one of the first-come-first-served
/// privileged slots. When dropped, it automatically releases any privileged slot it holds.
pub struct MemoryPermit {
    tracker: QueryMemoryTracker,
    is_privileged: AtomicBool,
}

impl MemoryPermit {
    /// Check if this permit currently has privileged status.
    pub fn is_privileged(&self) -> bool {
        self.is_privileged.load(Ordering::Acquire)
    }

    /// Ensure this permit has privileged status by acquiring a slot if available.
    /// Returns true if privileged (either already privileged or just acquired privilege).
    fn ensure_privileged(&self) -> bool {
        if self.is_privileged.load(Ordering::Acquire) {
            return true;
        }

        // Try to claim a privileged slot
        self.tracker
            .privileged_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
                if count < self.tracker.privileged_slots {
                    Some(count + 1)
                } else {
                    None
                }
            })
            .map(|_| {
                self.is_privileged.store(true, Ordering::Release);
                true
            })
            .unwrap_or(false)
    }

    /// Track additional memory usage with this permit.
    /// Returns an error if the limit would be exceeded.
    ///
    /// # Arguments
    /// * `additional` - Additional memory size to track in bytes
    /// * `stream_tracked` - Total memory already tracked by this stream
    ///
    /// # Behavior
    /// - Privileged streams: can push global memory usage up to the full limit
    /// - Standard-tier streams: can push global memory usage up to
    ///   `limit * standard_tier_memory_fraction` (default: 0.7)
    /// - Standard-tier streams automatically attempt to acquire privilege if slots become available
    /// - The configured limit is an absolute hard limit; no stream can exceed it
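    ///
    /// For example, with `limit = 1000` bytes and the default fraction of `0.7`, a
    /// standard-tier stream is rejected once the global total would exceed 700 bytes,
    /// while a privileged stream may push it up to the full 1000 bytes.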
    pub fn track(&self, additional: usize, stream_tracked: usize) -> Result<()> {
        // Ensure privileged status if possible
        let is_privileged = self.ensure_privileged();

        self.tracker
            .track_internal(additional, is_privileged, stream_tracked)
    }

    /// Release tracked memory.
    ///
    /// # Arguments
    /// * `amount` - Amount of memory to release in bytes
    pub fn release(&self, amount: usize) {
        self.tracker.release(amount);
    }
}

impl Drop for MemoryPermit {
    fn drop(&mut self) {
        // Release privileged slot if we had one
        if self.is_privileged.load(Ordering::Acquire) {
            self.tracker
                .privileged_count
                .fetch_sub(1, Ordering::Release);
        }
    }
}

/// Memory tracker for RecordBatch streams. Clone to share the same limit across queries.
///
/// Implements a two-tier memory allocation strategy:
/// - **Privileged tier**: the first N streams (default: 20) can use up to the full memory limit
/// - **Standard tier**: remaining streams are restricted to a fraction of the limit (default: 70%)
/// - Privilege is granted on a first-come-first-served basis
/// - The configured limit is an absolute hard cap - no stream can exceed it
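///
/// A usage sketch, mirroring this module's unit tests (illustrative only):
///
/// ```ignore
/// let tracker = QueryMemoryTracker::new(1000, 0); // 1000-byte hard limit
/// let permit = tracker.register_permit();         // early permits are privileged
/// permit.track(500, 0).unwrap();                  // global usage becomes 500 bytes
/// assert_eq!(tracker.current(), 500);
/// permit.release(500);                            // global usage back to zero
/// ```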
#[derive(Clone)]
pub struct QueryMemoryTracker {
    current: Arc<AtomicUsize>,
    limit: usize,
    standard_tier_memory_fraction: f64,
    privileged_count: Arc<AtomicUsize>,
    privileged_slots: usize,
    on_update: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    on_reject: Option<Arc<dyn Fn() + Send + Sync>>,
}

impl fmt::Debug for QueryMemoryTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("QueryMemoryTracker")
            .field("current", &self.current.load(Ordering::Acquire))
            .field("limit", &self.limit)
            .field(
                "standard_tier_memory_fraction",
                &self.standard_tier_memory_fraction,
            )
            .field(
                "privileged_count",
                &self.privileged_count.load(Ordering::Acquire),
            )
            .field("privileged_slots", &self.privileged_slots)
            .field("on_update", &self.on_update.is_some())
            .field("on_reject", &self.on_reject.is_some())
            .finish()
    }
}

impl QueryMemoryTracker {
    // Default number of privileged slots when max_concurrent_queries is 0.
    const DEFAULT_PRIVILEGED_SLOTS: usize = 20;
    // Ratio for the privileged tier: 70% of queries get privileged access, and the
    // standard tier may use 70% of the memory limit.
    const DEFAULT_PRIVILEGED_TIER_RATIO: f64 = 0.7;

    /// Create a new memory tracker with the given limit and max_concurrent_queries.
    /// Calculates privileged slots as 70% of max_concurrent_queries (or 20 if max_concurrent_queries is 0).
    ///
    /// # Arguments
    /// * `limit` - Maximum memory usage in bytes (hard limit for all streams). 0 means unlimited.
    /// * `max_concurrent_queries` - Maximum number of concurrent queries (0 = unlimited).
    pub fn new(limit: usize, max_concurrent_queries: usize) -> Self {
        let privileged_slots = Self::calculate_privileged_slots(max_concurrent_queries);
        Self::with_privileged_slots(limit, privileged_slots)
    }

    /// Create a new memory tracker with custom privileged slots limit.
    pub fn with_privileged_slots(limit: usize, privileged_slots: usize) -> Self {
        Self::with_config(limit, privileged_slots, Self::DEFAULT_PRIVILEGED_TIER_RATIO)
    }

    /// Create a new memory tracker with full configuration.
    ///
    /// # Arguments
    /// * `limit` - Maximum memory usage in bytes (hard limit for all streams). 0 means unlimited.
    /// * `privileged_slots` - Maximum number of streams that can get privileged status.
    /// * `standard_tier_memory_fraction` - Memory fraction for standard-tier streams (range: [0.0, 1.0]).
    ///
    /// # Panics
    /// Panics if `standard_tier_memory_fraction` is not in the range [0.0, 1.0].
    pub fn with_config(
        limit: usize,
        privileged_slots: usize,
        standard_tier_memory_fraction: f64,
    ) -> Self {
        assert!(
            (0.0..=1.0).contains(&standard_tier_memory_fraction),
            "standard_tier_memory_fraction must be in [0.0, 1.0], got {}",
            standard_tier_memory_fraction
        );

        Self {
            current: Arc::new(AtomicUsize::new(0)),
            limit,
            standard_tier_memory_fraction,
            privileged_count: Arc::new(AtomicUsize::new(0)),
            privileged_slots,
            on_update: None,
            on_reject: None,
        }
    }

    /// Register a new permit for memory tracking.
    /// The first `privileged_slots` permits get privileged status automatically.
    /// The returned permit can be shared across multiple streams of the same query.
    pub fn register_permit(&self) -> MemoryPermit {
        // Try to claim a privileged slot
        let is_privileged = self
            .privileged_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
                if count < self.privileged_slots {
                    Some(count + 1)
                } else {
                    None
                }
            })
            .is_ok();

        MemoryPermit {
            tracker: self.clone(),
            is_privileged: AtomicBool::new(is_privileged),
        }
    }

    /// Set a callback to be called whenever the usage changes successfully.
    /// The callback receives the new total usage in bytes.
    ///
    /// # Note
    /// The callback is called after both successful `track()` and `release()` operations.
    /// It is called even when `limit == 0` (unlimited mode) to track actual usage.
    pub fn with_on_update<F>(mut self, on_update: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.on_update = Some(Arc::new(on_update));
        self
    }

    /// Set a callback to be called when memory allocation is rejected.
    ///
    /// # Note
    /// This is only called when `track()` fails due to exceeding the limit.
    /// It is never called when `limit == 0` (unlimited mode).
    pub fn with_on_reject<F>(mut self, on_reject: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_reject = Some(Arc::new(on_reject));
        self
    }

    /// Get the current memory usage in bytes.
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Acquire)
    }

    fn calculate_privileged_slots(max_concurrent_queries: usize) -> usize {
        if max_concurrent_queries == 0 {
            Self::DEFAULT_PRIVILEGED_SLOTS
        } else {
            ((max_concurrent_queries as f64 * Self::DEFAULT_PRIVILEGED_TIER_RATIO) as usize).max(1)
        }
    }

    /// Internal method to track additional memory usage.
    ///
    /// Called by `MemoryPermit::track()`. Use `MemoryPermit::track()` instead of calling this directly.
    fn track_internal(
        &self,
        additional: usize,
        is_privileged: bool,
        stream_tracked: usize,
    ) -> Result<()> {
        // Calculate the effective global limit based on stream privilege:
        // privileged streams can push global usage up to the full limit, while
        // standard-tier streams can only push it up to a fraction of the limit.
        let effective_limit = if is_privileged {
            self.limit
        } else {
            (self.limit as f64 * self.standard_tier_memory_fraction) as usize
        };

        let mut new_total = 0;
        let result = self
            .current
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                new_total = current.saturating_add(additional);

                if self.limit == 0 {
                    // Unlimited mode
                    return Some(new_total);
                }

                // Check whether the new global total exceeds the effective limit.
                // The configured limit is an absolute hard limit; no stream can exceed it.
                if new_total <= effective_limit {
                    Some(new_total)
                } else {
                    None
                }
            });

        match result {
            Ok(_) => {
                if let Some(callback) = &self.on_update {
                    callback(new_total);
                }
                Ok(())
            }
            Err(current) => {
                if let Some(callback) = &self.on_reject {
                    callback();
                }
                let msg = format!(
                    "{} requested, {} used globally ({}%), {} used by this stream (privileged: {}), effective limit: {} ({}%), hard limit: {}",
                    ReadableSize(additional as u64),
                    ReadableSize(current as u64),
                    if self.limit > 0 {
                        current * 100 / self.limit
                    } else {
                        0
                    },
                    ReadableSize(stream_tracked as u64),
                    is_privileged,
                    ReadableSize(effective_limit as u64),
                    if self.limit > 0 {
                        effective_limit * 100 / self.limit
                    } else {
                        0
                    },
                    ReadableSize(self.limit as u64)
                );
                error::ExceedMemoryLimitSnafu { msg }.fail()
            }
        }
    }

    /// Release tracked memory.
    ///
    /// # Arguments
    /// * `amount` - Amount of memory to release in bytes
    pub fn release(&self, amount: usize) {
        if let Ok(old_value) =
            self.current
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some(current.saturating_sub(amount))
                })
            && let Some(callback) = &self.on_update
        {
            callback(old_value.saturating_sub(amount));
        }
    }
}

/// A wrapper stream that tracks memory usage of RecordBatches.
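///
/// A wrapping sketch (illustrative; assumes an existing `tracker` and `stream`):
///
/// ```ignore
/// let permit = Arc::new(tracker.register_permit());
/// let tracked: SendableRecordBatchStream = Box::pin(MemoryTrackedStream::new(stream, permit));
/// ```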
pub struct MemoryTrackedStream {
    inner: SendableRecordBatchStream,
    permit: Arc<MemoryPermit>,
    // Total tracked size, released when stream drops.
    total_tracked: usize,
}

impl MemoryTrackedStream {
    pub fn new(inner: SendableRecordBatchStream, permit: Arc<MemoryPermit>) -> Self {
        Self {
            inner,
            permit,
            total_tracked: 0,
        }
    }
}

impl Stream for MemoryTrackedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner).poll_next(cx) {
            Poll::Ready(Some(Ok(batch))) => {
                let additional = batch.buffer_memory_size();

                if let Err(e) = self.permit.track(additional, self.total_tracked) {
                    return Poll::Ready(Some(Err(e)));
                }

                self.total_tracked += additional;

                Poll::Ready(Some(Ok(batch)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl Drop for MemoryTrackedStream {
    fn drop(&mut self) {
        if self.total_tracked > 0 {
            self.permit.release(self.total_tracked);
        }
    }
}

impl RecordBatchStream for MemoryTrackedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};

    use super::*;

    #[test]
    fn test_recordbatches_try_from_columns() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let result = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        );
        assert!(result.is_err());

        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
        assert_eq!(r.take(), expected);
    }

    #[test]
    fn test_recordbatches_try_new() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);

        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();

        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();

        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!(
                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
            )
        );

        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
        let expected = "\
+---+-------+
| a | b     |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
        assert_eq!(batches.pretty_print().unwrap(), expected);

        assert_eq!(schema1, batches.schema());
        assert_eq!(vec![batch1], batches.take());
    }

    #[tokio::test]
    async fn test_simple_recordbatch_stream() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column_a, column_b]));

        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();

        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();

        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
        let stream = recordbatches.as_stream();
        let collected = util::collect(stream).await.unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0], batch1);
        assert_eq!(collected[1], batch2);
    }

    #[test]
    fn test_query_memory_tracker_basic() {
        let tracker = Arc::new(QueryMemoryTracker::new(1000, 0));

        // Register first stream - should get privileged status
        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Privileged stream can use up to limit
        assert!(permit1.track(500, 0).is_ok());
        assert_eq!(tracker.current(), 500);

        // Register second stream - also privileged
        let permit2 = tracker.register_permit();
        assert!(permit2.is_privileged());
        // Can add more but cannot exceed hard limit (1000)
        assert!(permit2.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 900);

        permit1.release(500);
        permit2.release(400);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_privileged_limit() {
        // Privileged slots = 2 for easy testing
        // Limit: 1000, standard-tier fraction: 0.7 (default)
        // Privileged can push global to 1000, standard-tier can push global to 700
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 2));

        // First 2 streams are privileged
        let permit1 = tracker.register_permit();
        let permit2 = tracker.register_permit();
        assert!(permit1.is_privileged());
        assert!(permit2.is_privileged());

        // Third stream is standard-tier (not privileged)
        let permit3 = tracker.register_permit();
        assert!(!permit3.is_privileged());

        // Privileged stream uses some memory
        assert!(permit1.track(300, 0).is_ok());
        assert_eq!(tracker.current(), 300);

        // Standard-tier can add up to 400 (total becomes 700, its effective limit)
        assert!(permit3.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 700);

        // Standard-tier stream cannot push global beyond 700
        let err = permit3.track(100, 400).unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("400B used by this stream"));
        assert!(err_msg.contains("effective limit: 700B (70%)"));
        assert!(err_msg.contains("700B used globally (70%)"));
        assert_eq!(tracker.current(), 700);

        permit1.release(300);
        permit3.release(400);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_promotion() {
        // Privileged slots = 1 for easy testing
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 1));

        // First stream is privileged
        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Second stream is standard-tier (can only push global usage to 700)
        let permit2 = tracker.register_permit();
        assert!(!permit2.is_privileged());

        // Standard-tier can track within its 700-byte effective limit
        assert!(permit2.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 400);

        // Drop first permit to release privileged slot
        drop(permit1);

        // Second stream can now be promoted and use more memory (900 > 700)
        assert!(permit2.track(500, 400).is_ok());
        assert!(permit2.is_privileged());
        assert_eq!(tracker.current(), 900);

        permit2.release(900);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_privileged_hard_limit() {
        // Test that the configured limit is an absolute hard limit for all streams
        // Privileged: can use full limit (1000)
        // Standard-tier: can use 0.7x limit (700 with defaults)
        let tracker = Arc::new(QueryMemoryTracker::new(1000, 0));

        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Privileged can use up to full limit (1000)
        assert!(permit1.track(900, 0).is_ok());
        assert_eq!(tracker.current(), 900);

        // Privileged cannot exceed hard limit (1000)
        assert!(permit1.track(200, 900).is_err());
        assert_eq!(tracker.current(), 900);

        // Can add within hard limit
        assert!(permit1.track(100, 900).is_ok());
        assert_eq!(tracker.current(), 1000);

        // Cannot exceed even by 1 byte
        assert!(permit1.track(1, 1000).is_err());
        assert_eq!(tracker.current(), 1000);

        permit1.release(1000);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_standard_tier_fraction() {
        // Test standard-tier streams use fraction of limit
        // Limit: 1000, default fraction: 0.7, so standard-tier can use 700
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 1));

        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        let permit2 = tracker.register_permit();
        assert!(!permit2.is_privileged());

        // Standard-tier can use up to 700 (1000 * 0.7 default)
        assert!(permit2.track(600, 0).is_ok());
        assert_eq!(tracker.current(), 600);

        // Cannot exceed standard-tier limit (700)
        assert!(permit2.track(200, 600).is_err());
        assert_eq!(tracker.current(), 600);

        // Can add within standard-tier limit
        assert!(permit2.track(100, 600).is_ok());
        assert_eq!(tracker.current(), 700);

        // Cannot exceed standard-tier limit
        assert!(permit2.track(1, 700).is_err());
        assert_eq!(tracker.current(), 700);

        permit2.release(700);
        assert_eq!(tracker.current(), 0);
    }
}