1use std::collections::HashMap;
16
17use api::helper;
18use api::v1::{ColumnSchema, Row, Rows};
19use datatypes::value::Value;
20use store_api::storage::RegionNumber;
21
22use crate::error::Result;
23use crate::PartitionRuleRef;
24
25pub struct RowSplitter {
26 partition_rule: PartitionRuleRef,
27}
28
29impl RowSplitter {
30 pub fn new(partition_rule: PartitionRuleRef) -> Self {
31 Self { partition_rule }
32 }
33
34 pub fn split(&self, rows: Rows) -> Result<HashMap<RegionNumber, Rows>> {
35 if rows.rows.is_empty() {
37 return Ok(HashMap::new());
38 }
39
40 let partition_columns = self.partition_rule.partition_columns();
42 if partition_columns.is_empty() {
43 return Ok(HashMap::from([(0, rows)]));
44 }
45
46 let splitter = SplitReadRowHelper::new(rows, &self.partition_rule);
47 splitter.split_rows()
48 }
49}
50
51struct SplitReadRowHelper<'a> {
52 schema: Vec<ColumnSchema>,
53 rows: Vec<Row>,
54 partition_rule: &'a PartitionRuleRef,
55 partition_cols_indexes: Vec<Option<usize>>,
57}
58
59impl<'a> SplitReadRowHelper<'a> {
60 fn new(rows: Rows, partition_rule: &'a PartitionRuleRef) -> Self {
61 let col_name_to_idx = rows
62 .schema
63 .iter()
64 .enumerate()
65 .map(|(idx, col)| (&col.column_name, idx))
66 .collect::<HashMap<_, _>>();
67 let partition_cols = partition_rule.partition_columns();
68 let partition_cols_indexes = partition_cols
69 .into_iter()
70 .map(|col_name| col_name_to_idx.get(&col_name).cloned())
71 .collect::<Vec<_>>();
72
73 Self {
74 schema: rows.schema,
75 rows: rows.rows,
76 partition_rule,
77 partition_cols_indexes,
78 }
79 }
80
81 fn split_rows(mut self) -> Result<HashMap<RegionNumber, Rows>> {
82 let regions = self.split_to_regions()?;
83 let request_splits = regions
84 .into_iter()
85 .map(|(region_number, row_indexes)| {
86 let rows = row_indexes
87 .into_iter()
88 .map(|row_idx| std::mem::take(&mut self.rows[row_idx]))
89 .collect();
90 let rows = Rows {
91 schema: self.schema.clone(),
92 rows,
93 };
94 (region_number, rows)
95 })
96 .collect::<HashMap<_, _>>();
97
98 Ok(request_splits)
99 }
100
101 fn split_to_regions(&self) -> Result<HashMap<RegionNumber, Vec<usize>>> {
102 let mut regions_row_indexes: HashMap<RegionNumber, Vec<usize>> = HashMap::new();
103 for (row_idx, values) in self.iter_partition_values().enumerate() {
104 let region_number = self.partition_rule.find_region(&values)?;
105 regions_row_indexes
106 .entry(region_number)
107 .or_default()
108 .push(row_idx);
109 }
110
111 Ok(regions_row_indexes)
112 }
113
114 fn iter_partition_values(&'a self) -> impl Iterator<Item = Vec<Value>> + 'a {
115 self.rows.iter().map(|row| {
116 self.partition_cols_indexes
117 .iter()
118 .map(|idx| {
119 idx.as_ref().map_or(Value::Null, |idx| {
120 helper::pb_value_to_value_ref(
121 &row.values[*idx],
122 &self.schema[*idx].datatype_extension,
123 )
124 .into()
125 })
126 })
127 .collect()
128 })
129 }
130}
131
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::sync::Arc;

    use api::v1::value::ValueData;
    use api::v1::{ColumnDataType, SemanticType};
    use datatypes::arrow::array::BooleanArray;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::PartitionRule;

    /// Three rows with schema `(id: String tag, name: String tag, age: Uint32 field)`
    /// whose ids are "1", "2" and "3", in that order.
    fn mock_rows() -> Rows {
        let schema = vec![
            ColumnSchema {
                column_name: "id".to_string(),
                datatype: ColumnDataType::String as i32,
                semantic_type: SemanticType::Tag as i32,
                ..Default::default()
            },
            ColumnSchema {
                column_name: "name".to_string(),
                datatype: ColumnDataType::String as i32,
                semantic_type: SemanticType::Tag as i32,
                ..Default::default()
            },
            ColumnSchema {
                column_name: "age".to_string(),
                datatype: ColumnDataType::Uint32 as i32,
                semantic_type: SemanticType::Field as i32,
                ..Default::default()
            },
        ];
        let rows = vec![
            Row {
                values: vec![
                    ValueData::StringValue("1".to_string()).into(),
                    ValueData::StringValue("Smith".to_string()).into(),
                    ValueData::U32Value(20).into(),
                ],
            },
            Row {
                values: vec![
                    ValueData::StringValue("2".to_string()).into(),
                    ValueData::StringValue("Johnson".to_string()).into(),
                    ValueData::U32Value(21).into(),
                ],
            },
            Row {
                values: vec![
                    ValueData::StringValue("3".to_string()).into(),
                    ValueData::StringValue("Williams".to_string()).into(),
                    ValueData::U32Value(22).into(),
                ],
            },
        ];
        Rows { schema, rows }
    }

    /// Partitions on `id` and routes each row to region `id % 2`.
    #[derive(Debug, Serialize, Deserialize)]
    struct MockPartitionRule;

    impl PartitionRule for MockPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> Vec<String> {
            vec!["id".to_string()]
        }

        fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
            let val = values.first().unwrap().clone();
            let val = match val {
                Value::String(v) => v.as_utf8().to_string(),
                _ => unreachable!(),
            };

            Ok(val.parse::<u32>().unwrap() % 2)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, BooleanArray>> {
            unimplemented!()
        }
    }

    /// Partitions on a column that `mock_rows` does not contain, so every row
    /// sees `Value::Null` and is routed to region 1.
    #[derive(Debug, Serialize, Deserialize)]
    struct MockMissedColPartitionRule;

    impl PartitionRule for MockMissedColPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> Vec<String> {
            vec!["missed_col".to_string()]
        }

        fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
            let val = values.first().unwrap().clone();
            let val = match val {
                Value::Null => 1,
                _ => 0,
            };

            Ok(val)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, BooleanArray>> {
            unimplemented!()
        }
    }

    /// Declares no partition columns at all.
    #[derive(Debug, Serialize, Deserialize)]
    struct EmptyPartitionRule;

    impl PartitionRule for EmptyPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> Vec<String> {
            vec![]
        }

        fn find_region(&self, _values: &[Value]) -> Result<RegionNumber> {
            Ok(0)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, BooleanArray>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_writer_splitter() {
        let expected = mock_rows();
        let rule = Arc::new(MockPartitionRule) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let mut splits = splitter.split(mock_rows()).unwrap();
        assert_eq!(splits.len(), 2);

        // Ids "1" and "3" are odd -> region 1; id "2" is even -> region 0.
        let split0 = splits.remove(&0).unwrap();
        let split1 = splits.remove(&1).unwrap();
        assert_eq!(split0.schema, expected.schema);
        assert_eq!(split1.schema, expected.schema);
        assert_eq!(split0.rows, vec![expected.rows[1].clone()]);
        // Rows keep their original relative order within a region.
        assert_eq!(
            split1.rows,
            vec![expected.rows[0].clone(), expected.rows[2].clone()]
        );
    }

    #[test]
    fn test_missed_col_writer_splitter() {
        let expected = mock_rows();
        let rule = Arc::new(MockMissedColPartitionRule) as PartitionRuleRef;

        let splitter = RowSplitter::new(rule);
        let mut splits = splitter.split(mock_rows()).unwrap();
        assert_eq!(splits.len(), 1);

        let split = splits.remove(&1).unwrap();
        assert_eq!(split.schema, expected.schema);
        assert_eq!(split.rows, expected.rows);
    }

    #[test]
    fn test_empty_partition_rule_writer_splitter() {
        let expected = mock_rows();
        let rule = Arc::new(EmptyPartitionRule) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let mut splits = splitter.split(mock_rows()).unwrap();
        assert_eq!(splits.len(), 1);

        let split = splits.remove(&0).unwrap();
        assert_eq!(split.schema, expected.schema);
        assert_eq!(split.rows, expected.rows);
    }

    #[test]
    fn test_empty_rows_writer_splitter() {
        // A request with a schema but no rows must produce no region batches.
        let rows = Rows {
            schema: mock_rows().schema,
            rows: vec![],
        };
        let rule = Arc::new(MockPartitionRule) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let splits = splitter.split(rows).unwrap();
        assert!(splits.is_empty());
    }
}