1use std::pin::Pin;
16use std::sync::Arc;
17
18use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
19use common_error::ext::BoxedError;
20use common_recordbatch::adapter::RecordBatchMetrics;
21use common_recordbatch::error::Result as RecordBatchResult;
22use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream};
23use datatypes::prelude::*;
24use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
25use datatypes::vectors::UInt32Vector;
26use futures::Stream;
27use futures::task::{Context, Poll};
28use snafu::prelude::*;
29use store_api::data_source::DataSource;
30use store_api::storage::{RegionNumber, ScanRequest};
31
32use crate::error::{SchemaConversionSnafu, TableProjectionSnafu};
33use crate::metadata::{
34 FilterPushDownType, TableId, TableInfoBuilder, TableMetaBuilder, TableType, TableVersion,
35};
36use crate::{Table, TableRef};
37
/// Namespace type whose associated functions build in-memory, single-batch
/// tables (as [`TableRef`]s) — handy fixtures for scan tests.
pub struct MemTable;
39
40impl MemTable {
41 pub fn table(table_name: impl Into<String>, recordbatch: RecordBatch) -> TableRef {
42 Self::new_with_region(table_name, recordbatch, vec![0])
43 }
44
45 pub fn new_with_region(
46 table_name: impl Into<String>,
47 recordbatch: RecordBatch,
48 regions: Vec<RegionNumber>,
49 ) -> TableRef {
50 Self::new_with_catalog(
51 table_name,
52 recordbatch,
53 1,
54 DEFAULT_CATALOG_NAME.to_string(),
55 DEFAULT_SCHEMA_NAME.to_string(),
56 regions,
57 )
58 }
59
60 pub fn new_with_catalog(
61 table_name: impl Into<String>,
62 recordbatch: RecordBatch,
63 table_id: TableId,
64 catalog_name: String,
65 schema_name: String,
66 regions: Vec<RegionNumber>,
67 ) -> TableRef {
68 let schema = recordbatch.schema.clone();
69
70 let meta = TableMetaBuilder::empty()
71 .schema(schema)
72 .primary_key_indices(vec![])
73 .value_indices(vec![])
74 .engine("mito".to_string())
75 .next_column_id(0)
76 .options(Default::default())
77 .created_on(Default::default())
78 .region_numbers(regions)
79 .build()
80 .unwrap();
81
82 let info = Arc::new(
83 TableInfoBuilder::default()
84 .table_id(table_id)
85 .table_version(0 as TableVersion)
86 .name(table_name.into())
87 .schema_name(schema_name)
88 .catalog_name(catalog_name)
89 .desc(None)
90 .table_type(TableType::Base)
91 .meta(meta)
92 .build()
93 .unwrap(),
94 );
95
96 let data_source = Arc::new(MemtableDataSource { recordbatch });
97 let table = Table::new(info, FilterPushDownType::Unsupported, data_source);
98 Arc::new(table)
99 }
100
101 pub fn default_numbers_table() -> TableRef {
104 Self::specified_numbers_table(100)
105 }
106
107 pub fn specified_numbers_table(rows: u32) -> TableRef {
108 let column_schemas = vec![ColumnSchema::new(
109 "uint32s",
110 ConcreteDataType::uint32_datatype(),
111 true,
112 )];
113 let schema = Arc::new(Schema::new(column_schemas));
114 let columns: Vec<VectorRef> = vec![Arc::new(UInt32Vector::from_slice(
115 (0..rows).collect::<Vec<_>>(),
116 ))];
117 let recordbatch = RecordBatch::new(schema, columns).unwrap();
118 MemTable::table("numbers", recordbatch)
119 }
120}
121
/// [`DataSource`] that serves every scan from one in-memory record batch.
struct MemtableDataSource {
    recordbatch: RecordBatch,
}
125
126impl DataSource for MemtableDataSource {
127 fn get_stream(
128 &self,
129 request: ScanRequest,
130 ) -> std::result::Result<SendableRecordBatchStream, BoxedError> {
131 let df_recordbatch = if let Some(indices) = request.projection {
132 self.recordbatch
133 .df_record_batch()
134 .project(&indices)
135 .context(TableProjectionSnafu)
136 .map_err(BoxedError::new)?
137 } else {
138 self.recordbatch.df_record_batch().clone()
139 };
140
141 let rows = df_recordbatch.num_rows();
142 let limit = if let Some(limit) = request.limit {
143 limit.min(rows)
144 } else {
145 rows
146 };
147 let df_recordbatch = df_recordbatch.slice(0, limit);
148
149 let recordbatch = RecordBatch::from_df_record_batch(
150 Arc::new(
151 Schema::try_from(df_recordbatch.schema())
152 .context(SchemaConversionSnafu)
153 .map_err(BoxedError::new)?,
154 ),
155 df_recordbatch,
156 );
157
158 Ok(Box::pin(MemtableStream {
159 schema: recordbatch.schema.clone(),
160 recordbatch: Some(recordbatch),
161 }))
162 }
163}
164
165impl RecordBatchStream for MemtableStream {
166 fn schema(&self) -> SchemaRef {
167 self.schema.clone()
168 }
169
170 fn output_ordering(&self) -> Option<&[OrderOption]> {
171 None
172 }
173
174 fn metrics(&self) -> Option<RecordBatchMetrics> {
175 None
176 }
177}
178
/// Stream that yields its single record batch once, then ends.
struct MemtableStream {
    schema: SchemaRef,
    // `Some` until the batch has been yielded; `None` afterwards.
    recordbatch: Option<RecordBatch>,
}
183
184impl Stream for MemtableStream {
185 type Item = RecordBatchResult<RecordBatch>;
186
187 fn poll_next(mut self: Pin<&mut Self>, _ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
188 match self.recordbatch.take() {
189 Some(records) => Poll::Ready(Some(Ok(records))),
190 None => Poll::Ready(None),
191 }
192 }
193}
194
#[cfg(test)]
mod test {
    use common_recordbatch::util;
    use datatypes::prelude::*;
    use datatypes::schema::ColumnSchema;
    use datatypes::vectors::{Helper, Int32Vector, StringVector};

    use super::*;

    /// Projecting onto column 1 must return only the string column.
    #[tokio::test]
    async fn test_scan_with_projection() {
        let table = build_testing_table();

        let req = ScanRequest {
            projection: Some(vec![1]),
            ..Default::default()
        };
        let stream = table.scan_to_stream(req).await.unwrap();
        let batches = util::collect(stream).await.unwrap();
        assert_eq!(batches.len(), 1);

        let columns = batches[0].df_record_batch().columns();
        assert_eq!(columns.len(), 1);

        // `flatten` drops the null slots, leaving only present strings.
        let strings = Helper::try_into_vector(&columns[0]).unwrap();
        let strings = strings.as_any().downcast_ref::<StringVector>().unwrap();
        let strings: Vec<&str> = strings.iter_data().flatten().collect();
        assert_eq!(strings, vec!["hello", "greptime"]);
    }

    /// A limit of 2 must truncate the scan to the first two rows.
    #[tokio::test]
    async fn test_scan_with_limit() {
        let table = build_testing_table();

        let req = ScanRequest {
            limit: Some(2),
            ..Default::default()
        };
        let stream = table.scan_to_stream(req).await.unwrap();
        let batches = util::collect(stream).await.unwrap();
        assert_eq!(batches.len(), 1);

        let columns = batches[0].df_record_batch().columns();
        assert_eq!(columns.len(), 2);

        // Surviving rows: (-100, "hello") and (None, None); the nulls are
        // discarded by `flatten` below.
        let ints = Helper::try_into_vector(&columns[0]).unwrap();
        let ints = ints.as_any().downcast_ref::<Int32Vector>().unwrap();
        let ints: Vec<i32> = ints.iter_data().flatten().collect();
        assert_eq!(ints, vec![-100]);

        let strings = Helper::try_into_vector(&columns[1]).unwrap();
        let strings = strings.as_any().downcast_ref::<StringVector>().unwrap();
        let strings: Vec<&str> = strings.iter_data().flatten().collect();
        assert_eq!(strings, vec!["hello"]);
    }

    /// Two-column (nullable i32, nullable string) fixture with interleaved
    /// nulls, so projection and limit tests can observe null handling.
    fn build_testing_table() -> TableRef {
        let schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("i32_numbers", ConcreteDataType::int32_datatype(), true),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
        ]));
        let columns: Vec<VectorRef> = vec![
            Arc::new(Int32Vector::from(vec![
                Some(-100),
                None,
                Some(1),
                Some(100),
            ])),
            Arc::new(StringVector::from(vec![
                Some("hello"),
                None,
                Some("greptime"),
                None,
            ])),
        ];
        let recordbatch = RecordBatch::new(schema, columns).unwrap();
        MemTable::table("", recordbatch)
    }
}