use std::collections::HashSet;

use ahash::{HashMap, HashMapExt};
use api::v1::flow::DirtyWindowRequest;
use api::v1::region::{
    bulk_insert_request, region_request, BulkInsertRequest, RegionRequest, RegionRequestHeader,
};
use api::v1::ArrowIpc;
use arrow::array::Array;
use arrow::record_batch::RecordBatch;
use common_base::AffectedRows;
use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
use common_grpc::FlightData;
use common_telemetry::error;
use common_telemetry::tracing_context::TracingContext;
use snafu::{ensure, OptionExt, ResultExt};
use store_api::storage::RegionId;
use table::metadata::TableInfoRef;
use table::TableRef;

use crate::insert::Inserter;
use crate::{error, metrics};

impl Inserter {
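    /// Handles a bulk insert request whose payload arrives as Arrow Flight data:
    /// decodes the record batch, splits it by the table's partition rule, and
    /// forwards the data to the datanodes hosting the target regions.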
    pub async fn handle_bulk_insert(
        &self,
        table: TableRef,
        decoder: &mut FlightDecoder,
        data: FlightData,
    ) -> error::Result<AffectedRows> {
        let table_info = table.table_info();
        let table_id = table_info.table_id();
        let db_name = table_info.get_db_string();
        let decode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["decode_request"])
            .start_timer();
        let body_size = data.data_body.len();
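        // Decode the Flight data into a record batch. `try_decode` yields `None`
        // when the batch cannot be materialized from a single message (e.g.
        // dictionary-encoded data), which bulk insert does not support.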
        let message = decoder
            .try_decode(&data)
            .context(error::DecodeFlightDataSnafu)?
            .context(error::NotSupportedSnafu {
                feat: "bulk insert RecordBatch with dictionary arrays",
            })?;
        let FlightMessage::RecordBatch(record_batch) = message else {
            return Ok(0);
        };
        decode_timer.observe_duration();

        if record_batch.num_rows() == 0 {
            return Ok(0);
        }

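        // Asynchronously mark the time windows touched by this batch as dirty on
        // the flownodes mapped to this table.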
        self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone());

        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

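        // The decoder has already produced a record batch above, so the raw schema
        // bytes are expected to be available here.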
        let schema_bytes = decoder.schema_bytes().unwrap();
        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let partition_rule = self
            .partition_manager
            .find_table_partition_rule(&table_info)
            .await
            .context(error::InvalidPartitionSnafu)?;

        let region_masks = partition_rule
            .split_record_batch(&record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

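        // Fast path: every row falls into a single region, so the original Flight
        // payload can be forwarded as-is without filtering or re-encoding.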
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    region_id: region_id.as_u64(),
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        schema: schema_bytes,
                        data_header: data.data_header,
                        payload: data.data_body,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            let result = datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
            if let Ok(rows) = result {
                crate::metrics::DIST_INGEST_ROW_COUNT
                    .with_label_values(&[db_name.as_str()])
                    .inc_by(rows as u64);
            }
            return result;
        }

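        // Slow path: rows span multiple regions. Group the per-region masks by the
        // datanode that currently hosts each region's leader.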
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());

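        // Lazily cache the raw Flight header/body: a region that receives every row
        // of the batch can reuse the original payload instead of re-encoding it.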
        let mut raw_data_bytes = None;
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
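                // Skip regions that receive no rows from this batch.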
                if mask.select_none() {
                    continue;
                }
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
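                // When the mask selects the whole batch, hand the spawned task the
                // original header and body; otherwise it filters and re-encodes below.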
                let raw_header_and_data = if mask.select_all() {
                    Some(
                        raw_data_bytes
                            .get_or_insert_with(|| {
                                (data.data_header.clone(), data.data_body.clone())
                            })
                            .clone(),
                    )
                } else {
                    None
                };
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let (header, payload) = if mask.select_all() {
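                            // `raw_header_and_data` is always `Some` in this branch:
                            // it was populated above whenever `mask.select_all()`.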
                            raw_header_and_data.unwrap()
                        } else {
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(batch.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let mut iter = FlightEncoder::default()
                                .encode(FlightMessage::RecordBatch(batch))
                                .into_iter();
                            let Some(flight_data) = iter.next() else {
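                                // Encoding a non-dictionary record batch always yields
                                // at least one FlightData message.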
                                unreachable!()
                            };
                            ensure!(
                                iter.next().is_none(),
                                error::NotSupportedSnafu {
                                    feat: "bulk insert RecordBatch with dictionary arrays",
                                }
                            );
                            encode_timer.observe_duration();
                            (flight_data.data_header, flight_data.data_body)
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                region_id: region_id.as_u64(),
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    schema: schema_bytes,
                                    data_header: header,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

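        // Wait for every datanode request to finish and sum the affected rows.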
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        crate::metrics::DIST_INGEST_ROW_COUNT
            .with_label_values(&[db_name.as_str()])
            .inc_by(rows_inserted as u64);
        Ok(rows_inserted)
    }

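    /// Spawns a background task that marks the time windows covered by
    /// `record_batch` as dirty on every flownode mapped to the table.
    /// Failures are logged and never fail the insert itself.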
    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
        let table_id = table_info.table_id();
        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
        let node_manager = self.node_manager.clone();
        common_runtime::spawn_global(async move {
            let result = table_flownode_set_cache
                .get(table_id)
                .await
                .context(error::RequestInsertsSnafu);
            let flownodes = match result {
                Ok(flownodes) => flownodes.unwrap_or_default(),
                Err(e) => {
                    error!(e; "Failed to get flownodes for table id: {}", table_id);
                    return;
                }
            };

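            // Deduplicate peers: several flows can be hosted on the same flownode,
            // which only needs a single notification.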
            let peers: HashSet<_> = flownodes.values().cloned().collect();
            if peers.is_empty() {
                return;
            }

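            // A time index column is mandatory for the table, so `timestamp_column()`
            // is expected to be `Some` here.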
            let Ok(timestamps) = extract_timestamps(
                &record_batch,
                &table_info
                    .meta
                    .schema
                    .timestamp_column()
                    .as_ref()
                    .unwrap()
                    .name,
            )
            .inspect_err(|e| {
                error!(e; "Failed to extract timestamps from record batch");
            }) else {
                return;
            };

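            // Fan the notification out to each flownode concurrently; errors are
            // logged and otherwise ignored.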
            for peer in peers {
                let node_manager = node_manager.clone();
                let timestamps = timestamps.clone();
                common_runtime::spawn_global(async move {
                    if let Err(e) = node_manager
                        .flownode(&peer)
                        .await
                        .handle_mark_window_dirty(DirtyWindowRequest {
                            table_id,
                            timestamps,
                        })
                        .await
                        .context(error::RequestInsertsSnafu)
                    {
                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
                    }
                });
            }
        });
    }
}

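/// Extracts the non-null timestamp values (as `i64`) of the time index column
/// named `timestamp_index_name` from `rb`.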
fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
    let ts_col = rb
        .column_by_name(timestamp_index_name)
        .context(error::ColumnNotFoundSnafu {
            msg: timestamp_index_name,
        })?;
    if rb.num_rows() == 0 {
        return Ok(vec![]);
    }
    let (primitive, _) =
        datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
            error::InvalidTimeIndexTypeSnafu {
                ty: ts_col.data_type().clone(),
            }
        })?;
    Ok(primitive.iter().flatten().collect())
}