// table/test_util/memtable.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::pin::Pin;
16use std::sync::Arc;
17
18use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
19use common_error::ext::BoxedError;
20use common_recordbatch::adapter::RecordBatchMetrics;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
23use datatypes::prelude::*;
24use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
25use datatypes::vectors::UInt32Vector;
26use futures::Stream;
27use futures::task::{Context, Poll};
28use snafu::prelude::*;
29use store_api::data_source::DataSource;
30use store_api::storage::ScanRequest;
31
32use crate::error::{SchemaConversionSnafu, TableProjectionSnafu};
33use crate::metadata::{
34    FilterPushDownType, TableId, TableInfoBuilder, TableMetaBuilder, TableType, TableVersion,
35};
36use crate::{Table, TableRef};
37
38pub struct MemTable;
39
40impl MemTable {
41    pub fn table(table_name: impl Into<String>, recordbatch: RecordBatch) -> TableRef {
42        Self::new_with_region(table_name, recordbatch)
43    }
44
45    pub fn new_with_region(table_name: impl Into<String>, recordbatch: RecordBatch) -> TableRef {
46        Self::new_with_catalog(
47            table_name,
48            recordbatch,
49            1,
50            DEFAULT_CATALOG_NAME.to_string(),
51            DEFAULT_SCHEMA_NAME.to_string(),
52        )
53    }
54
55    pub fn new_with_catalog(
56        table_name: impl Into<String>,
57        recordbatch: RecordBatch,
58        table_id: TableId,
59        catalog_name: String,
60        schema_name: String,
61    ) -> TableRef {
62        let schema = recordbatch.schema.clone();
63
64        let meta = TableMetaBuilder::empty()
65            .schema(schema)
66            .primary_key_indices(vec![])
67            .value_indices(vec![])
68            .engine("mito".to_string())
69            .next_column_id(0)
70            .options(Default::default())
71            .created_on(Default::default())
72            .build()
73            .unwrap();
74
75        let info = Arc::new(
76            TableInfoBuilder::default()
77                .table_id(table_id)
78                .table_version(0 as TableVersion)
79                .name(table_name.into())
80                .schema_name(schema_name)
81                .catalog_name(catalog_name)
82                .desc(None)
83                .table_type(TableType::Base)
84                .meta(meta)
85                .build()
86                .unwrap(),
87        );
88
89        let data_source = Arc::new(MemtableDataSource { recordbatch });
90        let table = Table::new(info, FilterPushDownType::Unsupported, data_source);
91        Arc::new(table)
92    }
93
94    /// Creates a 1 column 100 rows table, with table name "numbers", column name "uint32s" and
95    /// column type "uint32". Column data increased from 0 to 100.
96    pub fn default_numbers_table() -> TableRef {
97        Self::specified_numbers_table(100)
98    }
99
100    pub fn specified_numbers_table(rows: u32) -> TableRef {
101        let column_schemas = vec![ColumnSchema::new(
102            "uint32s",
103            ConcreteDataType::uint32_datatype(),
104            true,
105        )];
106        let schema = Arc::new(Schema::new(column_schemas));
107        let columns: Vec<VectorRef> = vec![Arc::new(UInt32Vector::from_slice(
108            (0..rows).collect::<Vec<_>>(),
109        ))];
110        let recordbatch = RecordBatch::new(schema, columns).unwrap();
111        MemTable::table("numbers", recordbatch)
112    }
113}
114
115struct MemtableDataSource {
116    recordbatch: RecordBatch,
117}
118
119impl DataSource for MemtableDataSource {
120    fn get_stream(
121        &self,
122        request: ScanRequest,
123    ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
124        let df_recordbatch = if let Some(indices) = request.projection_indices() {
125            self.recordbatch
126                .df_record_batch()
127                .project(indices)
128                .context(TableProjectionSnafu)
129                .map_err(BoxedError::new)?
130        } else {
131            self.recordbatch.df_record_batch().clone()
132        };
133
134        let rows = df_recordbatch.num_rows();
135        let limit = if let Some(limit) = request.limit {
136            limit.min(rows)
137        } else {
138            rows
139        };
140        let df_recordbatch = df_recordbatch.slice(0, limit);
141
142        let recordbatch = RecordBatch::from_df_record_batch(
143            Arc::new(
144                Schema::try_from(df_recordbatch.schema())
145                    .context(SchemaConversionSnafu)
146                    .map_err(BoxedError::new)?,
147            ),
148            df_recordbatch,
149        );
150
151        Ok(Box::pin(MemtableStream {
152            schema: recordbatch.schema.clone(),
153            recordbatch: Some(recordbatch),
154        }))
155    }
156}
157
158impl RecordBatchStream for MemtableStream {
159    fn schema(&self) -> SchemaRef {
160        self.schema.clone()
161    }
162
163    fn output_ordering(&self) -> Option<&[OrderOption]> {
164        None
165    }
166
167    fn metrics(&self) -> Option<RecordBatchMetrics> {
168        None
169    }
170}
171
172struct MemtableStream {
173    schema: SchemaRef,
174    recordbatch: Option<RecordBatch>,
175}
176
177impl Stream for MemtableStream {
178    type Item = RecordBatchResult<RecordBatch>;
179
180    fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181        match self.recordbatch.take() {
182            Some(records) => Poll::Ready(Some(Ok(records))),
183            None => Poll::Ready(None),
184        }
185    }
186}
187
#[cfg(test)]
mod test {
    use common_recordbatch::util;
    use datatypes::prelude::*;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::{Helper, Int32Vector, StringVector};

    use super::*;

    /// Projecting onto column index 1 must return only the string column, all rows.
    #[tokio::test]
    async fn test_scan_with_projection() {
        let table = build_testing_table();

        let request = ScanRequest {
            projection_input: Some(vec![1].into()),
            ..Default::default()
        };
        let batches = util::collect(table.scan_to_stream(request).await.unwrap())
            .await
            .unwrap();
        assert_eq!(1, batches.len());

        let columns = batches[0].df_record_batch().columns();
        assert_eq!(1, columns.len());

        let strings = Helper::try_into_vector(&columns[0]).unwrap();
        let strings = strings.as_any().downcast_ref::<StringVector>().unwrap();
        // `flatten` drops the null entries, leaving only the non-null strings.
        let actual: Vec<&str> = strings.iter_data().flatten().collect();
        assert_eq!(vec!["hello", "greptime"], actual);
    }

    /// A limit of 2 must keep both columns but only the first two rows.
    #[tokio::test]
    async fn test_scan_with_limit() {
        let table = build_testing_table();

        let request = ScanRequest {
            limit: Some(2),
            ..Default::default()
        };
        let batches = util::collect(table.scan_to_stream(request).await.unwrap())
            .await
            .unwrap();
        assert_eq!(1, batches.len());

        let columns = batches[0].df_record_batch().columns();
        assert_eq!(2, columns.len());

        // The second of the two kept rows is null in both columns, so one value remains
        // in each after flattening.
        let numbers = Helper::try_into_vector(&columns[0]).unwrap();
        let numbers = numbers.as_any().downcast_ref::<Int32Vector>().unwrap();
        let actual_numbers: Vec<i32> = numbers.iter_data().flatten().collect();
        assert_eq!(vec![-100], actual_numbers);

        let strings = Helper::try_into_vector(&columns[1]).unwrap();
        let strings = strings.as_any().downcast_ref::<StringVector>().unwrap();
        let actual_strings: Vec<&str> = strings.iter_data().flatten().collect();
        assert_eq!(vec!["hello"], actual_strings);
    }

    /// Builds a four-row table with a nullable int32 column ("i32_numbers") and a nullable
    /// string column ("strings"); the table name is empty.
    fn build_testing_table() -> TableRef {
        let schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("i32_numbers", ConcreteDataType::int32_datatype(), true),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ]));

        let i32_numbers: VectorRef = Arc::new(Int32Vector::from(vec![
            Some(-100),
            None,
            Some(1),
            Some(100),
        ]));
        let strings: VectorRef = Arc::new(StringVector::from(vec![
            Some("hello"),
            None,
            Some("greptime"),
            None,
        ]));

        let recordbatch = RecordBatch::new(schema, vec![i32_numbers, strings]).unwrap();
        MemTable::table("", recordbatch)
    }
}