1use std::collections::HashMap;
16
17use api::helper;
18use api::v1::{ColumnSchema, Row, Rows};
19use datatypes::value::Value;
20use store_api::storage::RegionNumber;
21
22use crate::error::Result;
23use crate::PartitionRuleRef;
24
25pub struct RowSplitter {
26 partition_rule: PartitionRuleRef,
27}
28
29impl RowSplitter {
30 pub fn new(partition_rule: PartitionRuleRef) -> Self {
31 Self { partition_rule }
32 }
33
34 pub fn split(&self, rows: Rows) -> Result<HashMap<RegionNumber, Rows>> {
35 if rows.rows.is_empty() {
37 return Ok(HashMap::new());
38 }
39
40 let partition_columns = self.partition_rule.partition_columns();
42 if partition_columns.is_empty() {
43 return Ok(HashMap::from([(0, rows)]));
44 }
45
46 let splitter = SplitReadRowHelper::new(rows, &self.partition_rule);
47 splitter.split_rows()
48 }
49}
50
51struct SplitReadRowHelper<'a> {
52 schema: Vec<ColumnSchema>,
53 rows: Vec<Row>,
54 partition_rule: &'a PartitionRuleRef,
55 partition_cols_indexes: Vec<Option<usize>>,
57}
58
59impl<'a> SplitReadRowHelper<'a> {
60 fn new(rows: Rows, partition_rule: &'a PartitionRuleRef) -> Self {
61 let col_name_to_idx = rows
62 .schema
63 .iter()
64 .enumerate()
65 .map(|(idx, col)| (&col.column_name, idx))
66 .collect::<HashMap<_, _>>();
67 let partition_cols = partition_rule.partition_columns();
68 let partition_cols_indexes = partition_cols
69 .into_iter()
70 .map(|col_name| col_name_to_idx.get(&col_name).cloned())
71 .collect::<Vec<_>>();
72
73 Self {
74 schema: rows.schema,
75 rows: rows.rows,
76 partition_rule,
77 partition_cols_indexes,
78 }
79 }
80
81 fn split_rows(mut self) -> Result<HashMap<RegionNumber, Rows>> {
82 let regions = self.split_to_regions()?;
83 let request_splits = regions
84 .into_iter()
85 .map(|(region_number, row_indexes)| {
86 let rows = row_indexes
87 .into_iter()
88 .map(|row_idx| std::mem::take(&mut self.rows[row_idx]))
89 .collect();
90 let rows = Rows {
91 schema: self.schema.clone(),
92 rows,
93 };
94 (region_number, rows)
95 })
96 .collect::<HashMap<_, _>>();
97
98 Ok(request_splits)
99 }
100
101 fn split_to_regions(&self) -> Result<HashMap<RegionNumber, Vec<usize>>> {
102 let mut regions_row_indexes: HashMap<RegionNumber, Vec<usize>> = HashMap::new();
103 for (row_idx, values) in self.iter_partition_values().enumerate() {
104 let region_number = self.partition_rule.find_region(&values)?;
105 regions_row_indexes
106 .entry(region_number)
107 .or_default()
108 .push(row_idx);
109 }
110
111 Ok(regions_row_indexes)
112 }
113
114 fn iter_partition_values(&'a self) -> impl Iterator<Item = Vec<Value>> + 'a {
115 self.rows.iter().map(|row| {
116 self.partition_cols_indexes
117 .iter()
118 .map(|idx| {
119 idx.as_ref().map_or(Value::Null, |idx| {
120 helper::pb_value_to_value_ref(
121 &row.values[*idx],
122 &self.schema[*idx].datatype_extension,
123 )
124 .into()
125 })
126 })
127 .collect()
128 })
129 }
130}
131
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::sync::Arc;

    use api::v1::helper::{field_column_schema, tag_column_schema};
    use api::v1::value::ValueData;
    use api::v1::ColumnDataType;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::partition::RegionMask;
    use crate::PartitionRule;

    /// Builds three `(id, name, age)` rows with ids "1", "2" and "3".
    fn mock_rows() -> Rows {
        let schema = vec![
            tag_column_schema("id", ColumnDataType::String),
            tag_column_schema("name", ColumnDataType::String),
            field_column_schema("age", ColumnDataType::Uint32),
        ];
        let rows = vec![
            Row {
                values: vec![
                    ValueData::StringValue("1".to_string()).into(),
                    ValueData::StringValue("Smith".to_string()).into(),
                    ValueData::U32Value(20).into(),
                ],
            },
            Row {
                values: vec![
                    ValueData::StringValue("2".to_string()).into(),
                    ValueData::StringValue("Johnson".to_string()).into(),
                    ValueData::U32Value(21).into(),
                ],
            },
            Row {
                values: vec![
                    ValueData::StringValue("3".to_string()).into(),
                    ValueData::StringValue("Williams".to_string()).into(),
                    ValueData::U32Value(22).into(),
                ],
            },
        ];
        Rows { schema, rows }
    }

    /// Returns the first (`id`) column of each row, for order assertions.
    fn ids(rows: &[Row]) -> Vec<Option<ValueData>> {
        rows.iter()
            .map(|row| row.values[0].value_data.clone())
            .collect()
    }

    /// Shorthand for an expected string cell.
    fn string_value(s: &str) -> Option<ValueData> {
        Some(ValueData::StringValue(s.to_string()))
    }

    /// Partitions on `id`: the id parsed as an integer, modulo 2, is the region.
    #[derive(Debug, Serialize, Deserialize)]
    struct MockPartitionRule;

    impl PartitionRule for MockPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> Vec<String> {
            vec!["id".to_string()]
        }

        fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
            let val = values.first().unwrap().clone();
            let val = match val {
                Value::String(v) => v.as_utf8().to_string(),
                _ => unreachable!(),
            };

            Ok(val.parse::<u32>().unwrap() % 2)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, RegionMask>> {
            unimplemented!()
        }
    }

    /// Partitions on a column absent from the schema; `Null` goes to region 1.
    #[derive(Debug, Serialize, Deserialize)]
    struct MockMissedColPartitionRule;

    impl PartitionRule for MockMissedColPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> Vec<String> {
            vec!["missed_col".to_string()]
        }

        fn find_region(&self, values: &[Value]) -> Result<RegionNumber> {
            let val = values.first().unwrap().clone();
            let val = match val {
                Value::Null => 1,
                _ => 0,
            };

            Ok(val)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, RegionMask>> {
            unimplemented!()
        }
    }

    /// Declares no partition columns at all.
    #[derive(Debug, Serialize, Deserialize)]
    struct EmptyPartitionRule;

    impl PartitionRule for EmptyPartitionRule {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn partition_columns(&self) -> Vec<String> {
            vec![]
        }

        fn find_region(&self, _values: &[Value]) -> Result<RegionNumber> {
            Ok(0)
        }

        fn split_record_batch(
            &self,
            _record_batch: &datatypes::arrow::array::RecordBatch,
        ) -> Result<HashMap<RegionNumber, RegionMask>> {
            unimplemented!()
        }
    }

    #[test]
    fn test_writer_splitter() {
        let rows = mock_rows();
        let rule = Arc::new(MockPartitionRule) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let mut splits = splitter.split(rows).unwrap();
        assert_eq!(splits.len(), 2);

        let split0 = splits.remove(&0).unwrap();
        let split1 = splits.remove(&1).unwrap();
        // Every split carries the full input schema.
        assert_eq!(split0.schema.len(), 3);
        assert_eq!(split1.schema.len(), 3);
        // Rows keep their original relative order within a region.
        assert_eq!(ids(&split0.rows), vec![string_value("2")]);
        assert_eq!(
            ids(&split1.rows),
            vec![string_value("1"), string_value("3")]
        );
    }

    #[test]
    fn test_missed_col_writer_splitter() {
        let rows = mock_rows();
        let rule = Arc::new(MockMissedColPartitionRule) as PartitionRuleRef;

        let splitter = RowSplitter::new(rule);
        let mut splits = splitter.split(rows).unwrap();
        assert_eq!(splits.len(), 1);

        let rows = splits.remove(&1).unwrap().rows;
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_empty_partition_rule_writer_splitter() {
        let rows = mock_rows();
        let rule = Arc::new(EmptyPartitionRule) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let mut splits = splitter.split(rows).unwrap();
        assert_eq!(splits.len(), 1);

        let rows = splits.remove(&0).unwrap().rows;
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_empty_rows_writer_splitter() {
        // A schema with no rows must produce no splits at all.
        let rows = Rows {
            schema: mock_rows().schema,
            rows: vec![],
        };
        let rule = Arc::new(MockPartitionRule) as PartitionRuleRef;
        let splitter = RowSplitter::new(rule);

        let splits = splitter.split(rows).unwrap();
        assert!(splits.is_empty());
    }
}