common_recordbatch/
lib.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#![feature(never_type)]

pub mod adapter;
pub mod cursor;
pub mod error;
pub mod filter;
mod recordbatch;
pub mod util;

use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

use adapter::RecordBatchMetrics;
use arc_swap::ArcSwapOption;
use common_base::readable_size::ReadableSize;
pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
use datatypes::arrow::compute::SortOptions;
pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow::util::pretty;
use datatypes::prelude::{ConcreteDataType, VectorRef};
use datatypes::scalars::{ScalarVector, ScalarVectorBuilder};
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::types::{JsonFormat, jsonb_to_string};
use datatypes::vectors::{BinaryVector, StringVectorBuilder};
use error::Result;
use futures::task::{Context, Poll};
use futures::{Stream, TryStreamExt};
pub use recordbatch::RecordBatch;
use snafu::{OptionExt, ResultExt, ensure};

pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
    fn name(&self) -> &str {
        "RecordBatchStream"
    }

    fn schema(&self) -> SchemaRef;

    fn output_ordering(&self) -> Option<&[OrderOption]>;

    fn metrics(&self) -> Option<RecordBatchMetrics>;
}

pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send>>;

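/// Desired output ordering for a single column, pairing the column name with
/// arrow's [SortOptions].
///
/// A minimal sketch:
///
/// ```ignore
/// // Order by "ts" descending, with nulls last.
/// let order = OrderOption {
///     name: "ts".to_string(),
///     options: SortOptions {
///         descending: true,
///         nulls_first: false,
///     },
/// };
/// ```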
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderOption {
    pub name: String,
    pub options: SortOptions,
}

/// A wrapper that transforms a [RecordBatchStream] into a new [RecordBatchStream] by applying a mapper function to each [RecordBatch].
///
/// The schema of the new stream is the inner stream's schema after applying the schema mapper function,
/// while the output ordering and metrics are forwarded from the inner stream unchanged.
pub struct SendableRecordBatchMapper {
    inner: SendableRecordBatchStream,
    /// Maps each [RecordBatch] in the stream; receives both the original schema and the mapped schema.
    mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
    /// The mapped schema, i.e. the inner stream's schema after applying the schema mapper function.
    schema: SchemaRef,
    /// Whether the mapper function needs to be applied; when false, batches pass through untouched.
    apply_mapper: bool,
}

/// Maps json columns to string columns in the batch.
///
/// Each JSONB-encoded column is decoded and rebuilt as a string column; all other
/// columns (including json columns in other formats) are passed through unchanged,
/// so the output batch has the same number of columns as the input.
pub fn map_json_type_to_string(
    batch: RecordBatch,
    original_schema: &SchemaRef,
    mapped_schema: &SchemaRef,
) -> Result<RecordBatch> {
    let mut vectors = Vec::with_capacity(original_schema.column_schemas().len());
    for (vector, schema) in batch.columns.iter().zip(original_schema.column_schemas()) {
        if let ConcreteDataType::Json(j) = &schema.data_type {
            if matches!(&j.format, JsonFormat::Jsonb) {
                let mut string_vector_builder = StringVectorBuilder::with_capacity(vector.len());
                let binary_vector = vector
                    .as_any()
                    .downcast_ref::<BinaryVector>()
                    .with_context(|| error::DowncastVectorSnafu {
                        from_type: schema.data_type.clone(),
                        to_type: ConcreteDataType::binary_datatype(),
                    })?;
                for value in binary_vector.iter_data() {
                    let Some(value) = value else {
                        string_vector_builder.push(None);
                        continue;
                    };
                    let string_value =
                        jsonb_to_string(value).with_context(|_| error::CastVectorSnafu {
                            from_type: schema.data_type.clone(),
                            to_type: ConcreteDataType::string_datatype(),
                        })?;
                    string_vector_builder.push(Some(string_value.as_str()));
                }

                let string_vector = string_vector_builder.finish();
                vectors.push(Arc::new(string_vector) as VectorRef);
            } else {
                vectors.push(vector.clone());
            }
        } else {
            vectors.push(vector.clone());
        }
    }

    RecordBatch::new(mapped_schema.clone(), vectors)
}

/// Maps json columns to string columns in the schema.
///
/// Json columns are replaced by string columns with the same name and nullability;
/// all other columns are kept as-is, so the output schema has the same number of
/// columns as the input.
///
/// Returns the new schema and whether any column was rewritten (i.e. whether the
/// batch mapper needs to be applied).
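///
/// # Example
///
/// A minimal sketch of the schema half of the mapping (the
/// `ConcreteDataType::json_datatype()` constructor is assumed here for
/// illustration):
///
/// ```ignore
/// let schema = Arc::new(Schema::new(vec![
///     ColumnSchema::new("payload", ConcreteDataType::json_datatype(), true),
///     ColumnSchema::new("id", ConcreteDataType::int32_datatype(), false),
/// ]));
/// let (mapped, apply_mapper) = map_json_type_to_string_schema(schema);
/// // "payload" is now a string column; "id" is untouched.
/// assert!(apply_mapper);
/// ```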
pub fn map_json_type_to_string_schema(schema: SchemaRef) -> (SchemaRef, bool) {
    let mut new_columns = Vec::with_capacity(schema.column_schemas().len());
    let mut apply_mapper = false;
    for column in schema.column_schemas() {
        if matches!(column.data_type, ConcreteDataType::Json(_)) {
            new_columns.push(ColumnSchema::new(
                column.name.clone(),
                ConcreteDataType::string_datatype(),
                column.is_nullable(),
            ));
            apply_mapper = true;
        } else {
            new_columns.push(column.clone());
        }
    }
    (Arc::new(Schema::new(new_columns)), apply_mapper)
}

impl SendableRecordBatchMapper {
    /// Creates a new [SendableRecordBatchMapper] with the given inner [RecordBatchStream], mapper function, and schema mapper function.
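    ///
    /// # Example
    ///
    /// A sketch of wiring up the json-to-string mapping defined in this module
    /// (`inner` stands for any existing [SendableRecordBatchStream]):
    ///
    /// ```ignore
    /// let mapped: SendableRecordBatchStream = Box::pin(SendableRecordBatchMapper::new(
    ///     inner,
    ///     map_json_type_to_string,
    ///     map_json_type_to_string_schema,
    /// ));
    /// ```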
    pub fn new(
        inner: SendableRecordBatchStream,
        mapper: fn(RecordBatch, &SchemaRef, &SchemaRef) -> Result<RecordBatch>,
        schema_mapper: fn(SchemaRef) -> (SchemaRef, bool),
    ) -> Self {
        let (mapped_schema, apply_mapper) = schema_mapper(inner.schema());
        Self {
            inner,
            mapper,
            schema: mapped_schema,
            apply_mapper,
        }
    }
}

impl RecordBatchStream for SendableRecordBatchMapper {
    fn name(&self) -> &str {
        "SendableRecordBatchMapper"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}

impl Stream for SendableRecordBatchMapper {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if self.apply_mapper {
            Pin::new(&mut self.inner).poll_next(cx).map(|opt| {
                opt.map(|result| {
                    result
                        .and_then(|batch| (self.mapper)(batch, &self.inner.schema(), &self.schema))
                })
            })
        } else {
            Pin::new(&mut self.inner).poll_next(cx)
        }
    }
}

/// EmptyRecordBatchStream can be used to create a RecordBatchStream
/// that will produce no results.
pub struct EmptyRecordBatchStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,
}

impl EmptyRecordBatchStream {
    /// Create an empty RecordBatchStream
    pub fn new(schema: SchemaRef) -> Self {
        Self { schema }
    }
}

impl RecordBatchStream for EmptyRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for EmptyRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(None)
    }
}

#[derive(Debug, PartialEq)]
pub struct RecordBatches {
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
}

impl RecordBatches {
    pub fn try_from_columns<I: IntoIterator<Item = VectorRef>>(
        schema: SchemaRef,
        columns: I,
    ) -> Result<Self> {
        let batches = vec![RecordBatch::new(schema.clone(), columns)?];
        Ok(Self { schema, batches })
    }

    pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
        let schema = stream.schema();
        let batches = stream.try_collect::<Vec<_>>().await?;
        Ok(Self { schema, batches })
    }

    #[inline]
    pub fn empty() -> Self {
        Self {
            schema: Arc::new(Schema::new(vec![])),
            batches: vec![],
        }
    }

    pub fn iter(&self) -> impl Iterator<Item = &RecordBatch> {
        self.batches.iter()
    }

    pub fn pretty_print(&self) -> Result<String> {
        let df_batches = &self
            .iter()
            .map(|x| x.df_record_batch().clone())
            .collect::<Vec<_>>();
        let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;

        Ok(result.to_string())
    }

    pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
        for batch in &batches {
            ensure!(
                batch.schema == schema,
                error::CreateRecordBatchesSnafu {
                    reason: format!(
                        "expect RecordBatch schema equals {:?}, actual: {:?}",
                        schema, batch.schema
                    )
                }
            )
        }
        Ok(Self { schema, batches })
    }

    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    pub fn take(self) -> Vec<RecordBatch> {
        self.batches
    }

    pub fn as_stream(&self) -> SendableRecordBatchStream {
        Box::pin(SimpleRecordBatchStream {
            inner: RecordBatches {
                schema: self.schema(),
                batches: self.batches.clone(),
            },
            index: 0,
        })
    }
}

impl IntoIterator for RecordBatches {
    type Item = RecordBatch;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.batches.into_iter()
    }
}

pub struct SimpleRecordBatchStream {
    inner: RecordBatches,
    index: usize,
}

impl RecordBatchStream for SimpleRecordBatchStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}

impl Stream for SimpleRecordBatchStream {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Poll::Ready(if self.index < self.inner.batches.len() {
            let batch = self.inner.batches[self.index].clone();
            self.index += 1;
            Some(Ok(batch))
        } else {
            None
        })
    }
}

/// Adapt a [Stream] of [RecordBatch] to a [RecordBatchStream].
pub struct RecordBatchStreamWrapper<S> {
    pub schema: SchemaRef,
    pub stream: S,
    pub output_ordering: Option<Vec<OrderOption>>,
    pub metrics: Arc<ArcSwapOption<RecordBatchMetrics>>,
}

impl<S> RecordBatchStreamWrapper<S> {
    /// Creates a [RecordBatchStreamWrapper] without output ordering requirement.
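    ///
    /// # Example
    ///
    /// A minimal sketch wrapping an in-memory stream of batches (`schema` and
    /// `batches` stand for values built elsewhere):
    ///
    /// ```ignore
    /// let inner = futures::stream::iter(batches.into_iter().map(Ok));
    /// let stream: SendableRecordBatchStream =
    ///     Box::pin(RecordBatchStreamWrapper::new(schema, inner));
    /// ```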
    pub fn new(schema: SchemaRef, stream: S) -> RecordBatchStreamWrapper<S> {
        RecordBatchStreamWrapper {
            schema,
            stream,
            output_ordering: None,
            metrics: Default::default(),
        }
    }
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> RecordBatchStream
    for RecordBatchStreamWrapper<S>
{
    fn name(&self) -> &str {
        "RecordBatchStreamWrapper"
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.output_ordering.as_deref()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.metrics.load().as_ref().map(|s| s.as_ref().clone())
    }
}

impl<S: Stream<Item = Result<RecordBatch>> + Unpin> Stream for RecordBatchStreamWrapper<S> {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(ctx)
    }
}

/// Memory permit for a stream, providing privileged access or rate limiting.
///
/// The permit tracks whether this stream holds one of the first-come privileged slots.
/// When dropped, it automatically releases any privileged slot it holds.
pub struct MemoryPermit {
    tracker: QueryMemoryTracker,
    is_privileged: AtomicBool,
}

impl MemoryPermit {
    /// Check if this permit currently has privileged status.
    pub fn is_privileged(&self) -> bool {
        self.is_privileged.load(Ordering::Acquire)
    }

    /// Ensure this permit has privileged status by acquiring a slot if available.
    /// Returns true if privileged (either already privileged or just acquired privilege).
    fn ensure_privileged(&self) -> bool {
        if self.is_privileged.load(Ordering::Acquire) {
            return true;
        }

        // Try to claim a privileged slot.
        self.tracker
            .privileged_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
                if count < self.tracker.privileged_slots {
                    Some(count + 1)
                } else {
                    None
                }
            })
            .map(|_| {
                self.is_privileged.store(true, Ordering::Release);
                true
            })
            .unwrap_or(false)
    }

    /// Track additional memory usage with this permit.
    /// Returns an error if the limit is exceeded.
    ///
    /// # Arguments
    /// * `additional` - Additional memory size to track in bytes
    /// * `stream_tracked` - Total memory already tracked by this stream
    ///
    /// # Behavior
    /// - Privileged streams: can push global memory usage up to the full limit
    /// - Standard-tier streams: can push global memory usage up to limit * standard_tier_memory_fraction (default: 0.7)
    /// - Standard-tier streams automatically attempt to acquire privilege if slots become available
    /// - The configured limit is an absolute hard limit: no stream can exceed it
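    ///
    /// # Example
    ///
    /// A sketch of the track/release pairing (sizes are illustrative):
    ///
    /// ```ignore
    /// let tracker = QueryMemoryTracker::new(1000, 0);
    /// let permit = tracker.register_permit();
    /// permit.track(500, 0).unwrap();   // now tracking 500 bytes
    /// permit.track(100, 500).unwrap(); // pass the running total as `stream_tracked`
    /// permit.release(600);             // release everything tracked by this stream
    /// ```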
    pub fn track(&self, additional: usize, stream_tracked: usize) -> Result<()> {
        // Ensure privileged status if possible.
        let is_privileged = self.ensure_privileged();

        self.tracker
            .track_internal(additional, is_privileged, stream_tracked)
    }

    /// Release tracked memory.
    ///
    /// # Arguments
    /// * `amount` - Amount of memory to release in bytes
    pub fn release(&self, amount: usize) {
        self.tracker.release(amount);
    }
}

impl Drop for MemoryPermit {
    fn drop(&mut self) {
        // Release the privileged slot if we had one.
        if self.is_privileged.load(Ordering::Acquire) {
            self.tracker
                .privileged_count
                .fetch_sub(1, Ordering::Release);
        }
    }
}

/// Memory tracker for RecordBatch streams. Clone to share the same limit across queries.
///
/// Implements a two-tier memory allocation strategy:
/// - **Privileged tier**: the first N streams (default: 20) can use up to the full memory limit
/// - **Standard tier**: remaining streams are restricted to a fraction of the limit (default: 70%)
/// - Privilege is granted on a first-come-first-served basis
/// - The configured limit is an absolute hard cap: no stream can exceed it
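///
/// # Example
///
/// A minimal sketch of the two-tier behavior (the numbers mirror the unit tests
/// at the bottom of this file):
///
/// ```ignore
/// // 1000-byte hard limit, a single privileged slot.
/// let tracker = QueryMemoryTracker::with_privileged_slots(1000, 1);
///
/// let privileged = tracker.register_permit(); // first permit: privileged
/// let standard = tracker.register_permit();   // second permit: standard tier
///
/// privileged.track(300, 0).unwrap();          // global usage: 300
/// standard.track(400, 0).unwrap();            // 700, exactly the standard cap
/// assert!(standard.track(100, 400).is_err()); // standard tier cannot pass 700
/// ```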
#[derive(Clone)]
pub struct QueryMemoryTracker {
    current: Arc<AtomicUsize>,
    limit: usize,
    standard_tier_memory_fraction: f64,
    privileged_count: Arc<AtomicUsize>,
    privileged_slots: usize,
    on_update: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    on_reject: Option<Arc<dyn Fn() + Send + Sync>>,
}

impl fmt::Debug for QueryMemoryTracker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("QueryMemoryTracker")
            .field("current", &self.current.load(Ordering::Acquire))
            .field("limit", &self.limit)
            .field(
                "standard_tier_memory_fraction",
                &self.standard_tier_memory_fraction,
            )
            .field(
                "privileged_count",
                &self.privileged_count.load(Ordering::Acquire),
            )
            .field("privileged_slots", &self.privileged_slots)
            .field("on_update", &self.on_update.is_some())
            .field("on_reject", &self.on_reject.is_some())
            .finish()
    }
}

impl QueryMemoryTracker {
    // Default number of privileged slots, used when max_concurrent_queries is 0.
    const DEFAULT_PRIVILEGED_SLOTS: usize = 20;
    // Ratio shared by both tiers: 70% of concurrent queries get privileged access,
    // and the standard tier is capped at 70% of the memory limit.
    const DEFAULT_PRIVILEGED_TIER_RATIO: f64 = 0.7;

    /// Create a new memory tracker with the given limit and max_concurrent_queries.
    /// Calculates privileged slots as 70% of max_concurrent_queries (or 20 if max_concurrent_queries is 0).
    ///
    /// # Arguments
    /// * `limit` - Maximum memory usage in bytes (hard limit for all streams). 0 means unlimited.
    /// * `max_concurrent_queries` - Maximum number of concurrent queries (0 = unlimited).
    pub fn new(limit: usize, max_concurrent_queries: usize) -> Self {
        let privileged_slots = Self::calculate_privileged_slots(max_concurrent_queries);
        Self::with_privileged_slots(limit, privileged_slots)
    }

    /// Create a new memory tracker with custom privileged slots limit.
    pub fn with_privileged_slots(limit: usize, privileged_slots: usize) -> Self {
        Self::with_config(limit, privileged_slots, Self::DEFAULT_PRIVILEGED_TIER_RATIO)
    }

    /// Create a new memory tracker with full configuration.
    ///
    /// # Arguments
    /// * `limit` - Maximum memory usage in bytes (hard limit for all streams). 0 means unlimited.
    /// * `privileged_slots` - Maximum number of streams that can get privileged status.
    /// * `standard_tier_memory_fraction` - Memory fraction for standard-tier streams (range: [0.0, 1.0]).
    ///
    /// # Panics
    /// Panics if `standard_tier_memory_fraction` is not in the range [0.0, 1.0].
    pub fn with_config(
        limit: usize,
        privileged_slots: usize,
        standard_tier_memory_fraction: f64,
    ) -> Self {
        assert!(
            (0.0..=1.0).contains(&standard_tier_memory_fraction),
            "standard_tier_memory_fraction must be in [0.0, 1.0], got {}",
            standard_tier_memory_fraction
        );

        Self {
            current: Arc::new(AtomicUsize::new(0)),
            limit,
            standard_tier_memory_fraction,
            privileged_count: Arc::new(AtomicUsize::new(0)),
            privileged_slots,
            on_update: None,
            on_reject: None,
        }
    }

    /// Register a new permit for memory tracking.
    /// The first `privileged_slots` permits get privileged status automatically.
    /// The returned permit can be shared across multiple streams of the same query.
    pub fn register_permit(&self) -> MemoryPermit {
        // Try to claim a privileged slot.
        let is_privileged = self
            .privileged_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
                if count < self.privileged_slots {
                    Some(count + 1)
                } else {
                    None
                }
            })
            .is_ok();

        MemoryPermit {
            tracker: self.clone(),
            is_privileged: AtomicBool::new(is_privileged),
        }
    }

    /// Set a callback to be called whenever the usage changes successfully.
    /// The callback receives the new total usage in bytes.
    ///
    /// # Note
    /// The callback is called after both successful `track()` and `release()` operations.
    /// It is called even when `limit == 0` (unlimited mode) to track actual usage.
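    ///
    /// # Example
    ///
    /// A sketch of exporting usage to metrics (`USAGE_GAUGE` and `REJECT_COUNTER`
    /// are hypothetical metric handles):
    ///
    /// ```ignore
    /// let tracker = QueryMemoryTracker::new(limit, max_concurrent_queries)
    ///     .with_on_update(|bytes| USAGE_GAUGE.set(bytes as i64))
    ///     .with_on_reject(|| REJECT_COUNTER.inc());
    /// ```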
    pub fn with_on_update<F>(mut self, on_update: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.on_update = Some(Arc::new(on_update));
        self
    }

    /// Set a callback to be called when memory allocation is rejected.
    ///
    /// # Note
    /// This is only called when `track()` fails due to exceeding the limit.
    /// It is never called when `limit == 0` (unlimited mode).
    pub fn with_on_reject<F>(mut self, on_reject: F) -> Self
    where
        F: Fn() + Send + Sync + 'static,
    {
        self.on_reject = Some(Arc::new(on_reject));
        self
    }

    /// Get the current memory usage in bytes.
    pub fn current(&self) -> usize {
        self.current.load(Ordering::Acquire)
    }

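    // For example, with the default 0.7 ratio, max_concurrent_queries = 10 yields
    // 7 privileged slots; the result is clamped to at least 1 slot.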
    fn calculate_privileged_slots(max_concurrent_queries: usize) -> usize {
        if max_concurrent_queries == 0 {
            Self::DEFAULT_PRIVILEGED_SLOTS
        } else {
            ((max_concurrent_queries as f64 * Self::DEFAULT_PRIVILEGED_TIER_RATIO) as usize).max(1)
        }
    }

    /// Internal method to track additional memory usage.
    ///
    /// Called by `MemoryPermit::track()`. Use `MemoryPermit::track()` instead of calling this directly.
    fn track_internal(
        &self,
        additional: usize,
        is_privileged: bool,
        stream_tracked: usize,
    ) -> Result<()> {
        // Calculate the effective global limit based on stream privilege:
        // privileged streams can push global usage up to the full limit, while
        // standard-tier streams can only push it up to a fraction of the limit.
        let effective_limit = if is_privileged {
            self.limit
        } else {
            (self.limit as f64 * self.standard_tier_memory_fraction) as usize
        };

        let mut new_total = 0;
        let result = self
            .current
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                new_total = current.saturating_add(additional);

                if self.limit == 0 {
                    // Unlimited mode.
                    return Some(new_total);
                }

                // Check if the new global total exceeds the effective limit.
                // The configured limit is an absolute hard limit; no stream can exceed it.
                if new_total <= effective_limit {
                    Some(new_total)
                } else {
                    None
                }
            });

        match result {
            Ok(_) => {
                if let Some(callback) = &self.on_update {
                    callback(new_total);
                }
                Ok(())
            }
            Err(current) => {
                if let Some(callback) = &self.on_reject {
                    callback();
                }
                let msg = format!(
                    "{} requested, {} used globally ({}%), {} used by this stream (privileged: {}), effective limit: {} ({}%), hard limit: {}",
                    ReadableSize(additional as u64),
                    ReadableSize(current as u64),
                    if self.limit > 0 {
                        current * 100 / self.limit
                    } else {
                        0
                    },
                    ReadableSize(stream_tracked as u64),
                    is_privileged,
                    ReadableSize(effective_limit as u64),
                    if self.limit > 0 {
                        effective_limit * 100 / self.limit
                    } else {
                        0
                    },
                    ReadableSize(self.limit as u64)
                );
                error::ExceedMemoryLimitSnafu { msg }.fail()
            }
        }
    }

    /// Release tracked memory.
    ///
    /// # Arguments
    /// * `amount` - Amount of memory to release in bytes
    pub fn release(&self, amount: usize) {
        if let Ok(old_value) =
            self.current
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                    Some(current.saturating_sub(amount))
                })
            && let Some(callback) = &self.on_update
        {
            callback(old_value.saturating_sub(amount));
        }
    }
}

/// A wrapper stream that tracks memory usage of RecordBatches.
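///
/// A sketch of putting the pieces together (`tracker` is a [QueryMemoryTracker]
/// and `inner` any existing [SendableRecordBatchStream]):
///
/// ```ignore
/// let permit = Arc::new(tracker.register_permit());
/// let tracked: SendableRecordBatchStream =
///     Box::pin(MemoryTrackedStream::new(inner, permit));
/// ```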
pub struct MemoryTrackedStream {
    inner: SendableRecordBatchStream,
    permit: Arc<MemoryPermit>,
    // Total tracked size, released when the stream drops.
    total_tracked: usize,
}

impl MemoryTrackedStream {
    pub fn new(inner: SendableRecordBatchStream, permit: Arc<MemoryPermit>) -> Self {
        Self {
            inner,
            permit,
            total_tracked: 0,
        }
    }
}

impl Stream for MemoryTrackedStream {
    type Item = Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner).poll_next(cx) {
            Poll::Ready(Some(Ok(batch))) => {
                let additional = batch
                    .columns()
                    .iter()
                    .map(|c| c.memory_size())
                    .sum::<usize>();

                if let Err(e) = self.permit.track(additional, self.total_tracked) {
                    return Poll::Ready(Some(Err(e)));
                }

                self.total_tracked += additional;

                Poll::Ready(Some(Ok(batch)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}

impl Drop for MemoryTrackedStream {
    fn drop(&mut self) {
        if self.total_tracked > 0 {
            self.permit.release(self.total_tracked);
        }
    }
}

impl RecordBatchStream for MemoryTrackedStream {
    fn schema(&self) -> SchemaRef {
        self.inner.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.inner.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        self.inner.metrics()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{BooleanVector, Int32Vector, StringVector};

    use super::*;

    #[test]
    fn test_recordbatches_try_from_columns() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "a",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let result = RecordBatches::try_from_columns(
            schema.clone(),
            vec![Arc::new(StringVector::from(vec!["hello", "world"])) as _],
        );
        assert!(result.is_err());

        let v: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let expected = vec![RecordBatch::new(schema.clone(), vec![v.clone()]).unwrap()];
        let r = RecordBatches::try_from_columns(schema, vec![v]).unwrap();
        assert_eq!(r.take(), expected);
    }

    #[test]
    fn test_recordbatches_try_new() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let column_c = ColumnSchema::new("c", ConcreteDataType::boolean_datatype(), false);

        let va: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb: VectorRef = Arc::new(StringVector::from(vec!["hello", "world"]));
        let vc: VectorRef = Arc::new(BooleanVector::from(vec![true, false]));

        let schema1 = Arc::new(Schema::new(vec![column_a.clone(), column_b]));
        let batch1 = RecordBatch::new(schema1.clone(), vec![va.clone(), vb]).unwrap();

        let schema2 = Arc::new(Schema::new(vec![column_a, column_c]));
        let batch2 = RecordBatch::new(schema2.clone(), vec![va, vc]).unwrap();

        let result = RecordBatches::try_new(schema1.clone(), vec![batch1.clone(), batch2]);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            format!(
                "Failed to create RecordBatches, reason: expect RecordBatch schema equals {schema1:?}, actual: {schema2:?}",
            )
        );

        let batches = RecordBatches::try_new(schema1.clone(), vec![batch1.clone()]).unwrap();
        let expected = "\
+---+-------+
| a | b     |
+---+-------+
| 1 | hello |
| 2 | world |
+---+-------+";
        assert_eq!(batches.pretty_print().unwrap(), expected);

        assert_eq!(schema1, batches.schema());
        assert_eq!(vec![batch1], batches.take());
    }

    #[tokio::test]
    async fn test_simple_recordbatch_stream() {
        let column_a = ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false);
        let column_b = ColumnSchema::new("b", ConcreteDataType::string_datatype(), false);
        let schema = Arc::new(Schema::new(vec![column_a, column_b]));

        let va1: VectorRef = Arc::new(Int32Vector::from_slice([1, 2]));
        let vb1: VectorRef = Arc::new(StringVector::from(vec!["a", "b"]));
        let batch1 = RecordBatch::new(schema.clone(), vec![va1, vb1]).unwrap();

        let va2: VectorRef = Arc::new(Int32Vector::from_slice([3, 4, 5]));
        let vb2: VectorRef = Arc::new(StringVector::from(vec!["c", "d", "e"]));
        let batch2 = RecordBatch::new(schema.clone(), vec![va2, vb2]).unwrap();

        let recordbatches =
            RecordBatches::try_new(schema.clone(), vec![batch1.clone(), batch2.clone()]).unwrap();
        let stream = recordbatches.as_stream();
        let collected = util::collect(stream).await.unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0], batch1);
        assert_eq!(collected[1], batch2);
    }

    #[test]
    fn test_query_memory_tracker_basic() {
        let tracker = Arc::new(QueryMemoryTracker::new(1000, 0));

        // Register first stream - should get privileged status
        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Privileged stream can use up to limit
        assert!(permit1.track(500, 0).is_ok());
        assert_eq!(tracker.current(), 500);

        // Register second stream - also privileged
        let permit2 = tracker.register_permit();
        assert!(permit2.is_privileged());
        // Can add more but cannot exceed hard limit (1000)
        assert!(permit2.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 900);

        permit1.release(500);
        permit2.release(400);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_privileged_limit() {
        // Privileged slots = 2 for easy testing
        // Limit: 1000, standard-tier fraction: 0.7 (default)
        // Privileged can push global to 1000, standard-tier can push global to 700
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 2));

        // First 2 streams are privileged
        let permit1 = tracker.register_permit();
        let permit2 = tracker.register_permit();
        assert!(permit1.is_privileged());
        assert!(permit2.is_privileged());

        // Third stream is standard-tier (not privileged)
        let permit3 = tracker.register_permit();
        assert!(!permit3.is_privileged());

        // Privileged stream uses some memory
        assert!(permit1.track(300, 0).is_ok());
        assert_eq!(tracker.current(), 300);

        // Standard-tier can add up to 400 (total becomes 700, its effective limit)
        assert!(permit3.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 700);

        // Standard-tier stream cannot push global beyond 700
        let err = permit3.track(100, 400).unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("400B used by this stream"));
        assert!(err_msg.contains("effective limit: 700B (70%)"));
        assert!(err_msg.contains("700B used globally (70%)"));
        assert_eq!(tracker.current(), 700);

        permit1.release(300);
        permit3.release(400);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_promotion() {
        // Privileged slots = 1 for easy testing
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 1));

        // First stream is privileged
        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Second stream is standard-tier (capped at 700 until promoted)
        let permit2 = tracker.register_permit();
        assert!(!permit2.is_privileged());

        // Standard-tier can push global usage up to 700 at most
        assert!(permit2.track(400, 0).is_ok());
        assert_eq!(tracker.current(), 400);

        // Drop first permit to release privileged slot
        drop(permit1);

        // Second stream can now be promoted and use more memory
        assert!(permit2.track(500, 400).is_ok());
        assert!(permit2.is_privileged());
        assert_eq!(tracker.current(), 900);

        permit2.release(900);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_privileged_hard_limit() {
        // Test that the configured limit is an absolute hard limit for all streams
        // Privileged: can use full limit (1000)
        // Standard-tier: can use 0.7x limit (700 with defaults)
        let tracker = Arc::new(QueryMemoryTracker::new(1000, 0));

        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        // Privileged can use up to full limit (1000)
        assert!(permit1.track(900, 0).is_ok());
        assert_eq!(tracker.current(), 900);

        // Privileged cannot exceed hard limit (1000)
        assert!(permit1.track(200, 900).is_err());
        assert_eq!(tracker.current(), 900);

        // Can add within hard limit
        assert!(permit1.track(100, 900).is_ok());
        assert_eq!(tracker.current(), 1000);

        // Cannot exceed even by 1 byte
        assert!(permit1.track(1, 1000).is_err());
        assert_eq!(tracker.current(), 1000);

        permit1.release(1000);
        assert_eq!(tracker.current(), 0);
    }

    #[test]
    fn test_query_memory_tracker_standard_tier_fraction() {
        // Test standard-tier streams use a fraction of the limit
        // Limit: 1000, default fraction: 0.7, so standard-tier can use 700
        let tracker = Arc::new(QueryMemoryTracker::with_privileged_slots(1000, 1));

        let permit1 = tracker.register_permit();
        assert!(permit1.is_privileged());

        let permit2 = tracker.register_permit();
        assert!(!permit2.is_privileged());

        // Standard-tier can use up to 700 (1000 * 0.7 default)
        assert!(permit2.track(600, 0).is_ok());
        assert_eq!(tracker.current(), 600);

        // Cannot exceed standard-tier limit (700)
        assert!(permit2.track(200, 600).is_err());
        assert_eq!(tracker.current(), 600);

        // Can add within standard-tier limit
        assert!(permit2.track(100, 600).is_ok());
        assert_eq!(tracker.current(), 700);

        // Cannot exceed standard-tier limit
        assert!(permit2.track(1, 700).is_err());
        assert_eq!(tracker.current(), 700);

        permit2.release(700);
        assert_eq!(tracker.current(), 0);
    }
}