1use std::collections::HashMap;
16
17use api::helper;
18use api::v1::{ColumnSchema, Row, Rows};
19use datatypes::value::Value;
20use store_api::storage::RegionNumber;
21
22use crate::PartitionRuleRef;
23use crate::error::Result;
24
25pub struct RowSplitter {
26 partition_rule: PartitionRuleRef,
27}
28
29impl RowSplitter {
30 pub fn new(partition_rule: PartitionRuleRef) -> Self {
31 Self { partition_rule }
32 }
33
34 pub fn split(&self, rows: Rows) -> Result<HashMap<RegionNumber, Rows>> {
35 if rows.rows.is_empty() {
37 return Ok(HashMap::new());
38 }
39
40 let partition_columns = self.partition_rule.partition_columns();
42 if partition_columns.is_empty() {
43 return Ok(HashMap::from([(0, rows)]));
44 }
45
46 let splitter = SplitReadRowHelper::new(rows, &self.partition_rule);
47 splitter.split_rows()
48 }
49}
50
51struct SplitReadRowHelper<'a> {
52 schema: Vec<ColumnSchema>,
53 rows: Vec<Row>,
54 partition_rule: &'a PartitionRuleRef,
55 partition_cols_indexes: Vec<Option<usize>>,
57}
58
59impl<'a> SplitReadRowHelper<'a> {
60 fn new(rows: Rows, partition_rule: &'a PartitionRuleRef) -> Self {
61 let col_name_to_idx = rows
62 .schema
63 .iter()
64 .enumerate()
65 .map(|(idx, col)| (&col.column_name, idx))
66 .collect::<HashMap<_, _>>();
67 let partition_cols = partition_rule.partition_columns();
68 let partition_cols_indexes = partition_cols
69 .iter()
70 .map(|col_name| col_name_to_idx.get(&col_name).cloned())
71 .collect::<Vec<_>>();
72
73 Self {
74 schema: rows.schema,
75 rows: rows.rows,
76 partition_rule,
77 partition_cols_indexes,
78 }
79 }
80
81 fn split_rows(mut self) -> Result<HashMap<RegionNumber, Rows>> {
82 let regions = self.split_to_regions()?;
83 let request_splits = regions
84 .into_iter()
85 .map(|(region_number, row_indexes)| {
86 let rows = row_indexes
87 .into_iter()
88 .map(|row_idx| std::mem::take(&mut self.rows[row_idx]))
89 .collect();
90 let rows = Rows {
91 schema: self.schema.clone(),
92 rows,
93 };
94 (region_number, rows)
95 })
96 .collect::<HashMap<_, _>>();
97
98 Ok(request_splits)
99 }
100
101 fn split_to_regions(&self) -> Result<HashMap<RegionNumber, Vec<usize>>> {
102 let mut regions_row_indexes: HashMap<RegionNumber, Vec<usize>> = HashMap::new();
103 for (row_idx, values) in self.iter_partition_values().enumerate() {
104 let region_number = self.partition_rule.find_region(&values)?;
105 regions_row_indexes
106 .entry(region_number)
107 .or_default()
108 .push(row_idx);
109 }
110
111 Ok(regions_row_indexes)
112 }
113
114 fn iter_partition_values(&'a self) -> impl Iterator<Item = Vec<Value>> + 'a {
115 self.rows.iter().map(|row| {
116 self.partition_cols_indexes
117 .iter()
118 .map(|idx| {
119 idx.as_ref().map_or(Value::Null, |idx| {
120 helper::pb_value_to_value_ref(
121 &row.values[*idx],
122 self.schema[*idx].datatype_extension.as_ref(),
123 )
124 .into()
125 })
126 })
127 .collect()
128 })
129 }
130}
131
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::sync::Arc;

    use api::v1::ColumnDataType;
    use api::v1::helper::{field_column_schema, tag_column_schema};
    use api::v1::value::ValueData;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::PartitionRule;
    use crate::partition::RegionMask;

    /// Builds three rows whose string ids (first column) are "1", "2" and "3".
    fn mock_rows() -> Rows {
        let schema = vec![
            tag_column_schema("id", ColumnDataType::String),
            tag_column_schema("name", ColumnDataType::String),
            field_column_schema("age", ColumnDataType::Uint32),
        ];
        let rows = vec![
            Row {
                values: vec![
                    ValueData::StringValue("1".to_string()).into(),
                    ValueData::StringValue("Smith".to_string()).into(),
                    ValueData::U32Value(20).into(),
                ],
            },
            Row {
                values: vec![
                    ValueData::StringValue("2".to_string()).into(),
                    ValueData::StringValue("Johnson".to_string()).into(),
                    ValueData::U32Value(21).into(),
                ],
            },
            Row {
                values: vec![
                    ValueData::StringValue("3".to_string()).into(),
                    ValueData::StringValue("Williams".to_string()).into(),
                    ValueData::U32Value(22).into(),
                ],
            },
        ];
        Rows { schema, rows }
    }

    /// Extracts the string id (first column) of each row, preserving order.
    fn row_ids(rows: &[Row]) -> Vec<String> {
        rows.iter()
            .map(|row| match &row.values[0].value_data {
                Some(ValueData::StringValue(id)) => id.clone(),
                other => panic!("unexpected id value: {other:?}"),
            })
            .collect()
    }

    /// Routes a row to region `id % 2`, parsing the string `id` as `u32`.
    #[derive(Debug, Serialize, Deserialize)]
    struct MockPartitionRule {
        partition_columns: Vec<String>,
    }

    impl Default for MockPartitionRule {
        fn default() -> Self {
            Self {
                partition_columns: vec!["id".to_string()],
            }
        }
    }

    impl PartitionRule for MockPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> &[String] {
            &self.partition_columns
        }

        fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
            let val = values.first().unwrap().clone();
            let val = match val {
                Value::String(v) => v.as_utf8().to_string(),
                _ => unreachable!(),
            };

            Ok(val.parse::<u32>().unwrap() % 2)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, RegionMask>> {
            unimplemented!()
        }
    }

    /// Partitions on a column absent from the schema: `Null` goes to region 1,
    /// anything else to region 0.
    #[derive(Debug, Serialize, Deserialize)]
    struct MockMissedColPartitionRule {
        partition_columns: Vec<String>,
    }

    impl Default for MockMissedColPartitionRule {
        fn default() -> Self {
            Self {
                partition_columns: vec!["missed_col".to_string()],
            }
        }
    }

    impl PartitionRule for MockMissedColPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> &[String] {
            &self.partition_columns
        }

        fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
            let val = values.first().unwrap().clone();
            let val = match val {
                Value::Null => 1,
                _ => 0,
            };

            Ok(val)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, RegionMask>> {
            unimplemented!()
        }
    }

    /// A rule with no partition columns.
    #[derive(Debug, Serialize, Deserialize)]
    struct EmptyPartitionRule;

    impl PartitionRule for EmptyPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> &[String] {
            &[]
        }

        fn find_region(&self, _values: &[Value]) -> Result<RegionNumber> {
            Ok(0)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, RegionMask>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_writer_splitter() {
        let rows = mock_rows();
        let rule = Arc::new(MockPartitionRule::default()) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let mut splits = splitter.split(rows).unwrap();
        assert_eq!(splits.len(), 2);

        let split0 = splits.remove(&0).unwrap();
        let split1 = splits.remove(&1).unwrap();
        assert_eq!(split0.schema.len(), 3);
        assert_eq!(split1.schema.len(), 3);

        let rows0 = split0.rows;
        let rows1 = split1.rows;
        assert_eq!(rows0.len(), 1);
        assert_eq!(rows1.len(), 2);

        // Rows land in the expected region and keep their relative order.
        assert_eq!(row_ids(&rows0), vec!["2"]);
        assert_eq!(row_ids(&rows1), vec!["1", "3"]);
    }

    #[test]
    fn test_missed_col_writer_splitter() {
        let rows = mock_rows();
        let rule = Arc::new(MockMissedColPartitionRule::default()) as PartitionRuleRef;

        let splitter = RowSplitter::new(rule);
        let mut splits = splitter.split(rows).unwrap();
        assert_eq!(splits.len(), 1);

        let split = splits.remove(&1).unwrap();
        assert_eq!(split.schema.len(), 3);
        let rows = split.rows;
        assert_eq!(rows.len(), 3);
        assert_eq!(row_ids(&rows), vec!["1", "2", "3"]);
    }

    #[test]
    fn test_empty_partition_rule_writer_splitter() {
        let rows = mock_rows();
        let rule = Arc::new(EmptyPartitionRule) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let mut splits = splitter.split(rows).unwrap();
        assert_eq!(splits.len(), 1);

        let rows = splits.remove(&0).unwrap().rows;
        assert_eq!(rows.len(), 3);
        assert_eq!(row_ids(&rows), vec!["1", "2", "3"]);
    }

    #[test]
    fn test_empty_rows_writer_splitter() {
        let rows = Rows {
            schema: mock_rows().schema,
            rows: vec![],
        };
        let rule = Arc::new(MockPartitionRule::default()) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let splits = splitter.split(rows).unwrap();
        assert!(splits.is_empty());
    }
}