use std::collections::HashSet;

use ahash::{HashMap, HashMapExt};
use api::v1::ArrowIpc;
use api::v1::flow::{DirtyWindowRequest, DirtyWindowRequests};
use api::v1::region::{
    BulkInsertRequest, RegionRequest, RegionRequestHeader, bulk_insert_request, region_request,
};
use arrow::array::Array;
use arrow::record_batch::RecordBatch;
use bytes::Bytes;
use common_base::AffectedRows;
use common_grpc::FlightData;
use common_grpc::flight::{FlightEncoder, FlightMessage};
use common_telemetry::error;
use common_telemetry::tracing_context::TracingContext;
use snafu::{OptionExt, ResultExt, ensure};
use store_api::storage::RegionId;
use table::TableRef;
use table::metadata::TableInfoRef;

use crate::insert::Inserter;
use crate::{error, metrics};

impl Inserter {
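    /// Handles a bulk insert request.
    ///
    /// Splits the decoded `record_batch` by the table's partition rule and
    /// forwards each partition to its region leader as an Arrow IPC payload.
    /// When every row maps to a single region, the raw Flight frame is
    /// forwarded verbatim to avoid re-encoding.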
    pub async fn handle_bulk_insert(
        &self,
        table: TableRef,
        raw_flight_data: FlightData,
        record_batch: RecordBatch,
        schema_bytes: Bytes,
    ) -> error::Result<AffectedRows> {
        let table_info = table.table_info();
        let table_id = table_info.table_id();
        let db_name = table_info.get_db_string();

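        // An empty batch writes nothing; skip partitioning and RPCs entirely.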
        if record_batch.num_rows() == 0 {
            return Ok(0);
        }

        let body_size = raw_flight_data.data_body.len();
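        // Notify flows reading this table of dirty time windows in the
        // background; this never blocks or fails the write path.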
        self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone());

        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

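        // Compute which rows of the batch belong to which region.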
        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let partition_rule = self
            .partition_manager
            .find_table_partition_rule(&table_info)
            .await
            .context(error::InvalidPartitionSnafu)?;

        let region_masks = partition_rule
            .split_record_batch(&record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

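        // Fast path: all rows fall into a single region, so the original
        // Flight header and body can be forwarded without filtering or
        // re-encoding.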
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;

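            // Forward the original IPC header and body untouched.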
            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    region_id: region_id.as_u64(),
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        schema: schema_bytes.clone(),
                        data_header: raw_flight_data.data_header,
                        payload: raw_flight_data.data_body,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            let result = datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
            if let Ok(rows) = result {
                crate::metrics::DIST_INGEST_ROW_COUNT
                    .with_label_values(&[db_name.as_str()])
                    .inc_by(rows as u64);
            }
            return result;
        }

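        // Slow path: resolve the leader datanode of every target region and
        // group the row masks by node.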
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());

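        // Spawn one task per target region; each task filters out its rows
        // (unless the mask selects the whole batch), re-encodes them, and
        // sends the request to the region's datanode.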
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
                if mask.select_none() {
                    continue;
                }
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
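                // If this region takes the whole batch, hand the task the raw
                // Flight frame so it can skip the filter and encode steps.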
                let raw_header_and_data = if mask.select_all() {
                    Some((
                        raw_flight_data.data_header.clone(),
                        raw_flight_data.data_body.clone(),
                    ))
                } else {
                    None
                };
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let (header, payload) = if mask.select_all() {
                            raw_header_and_data.unwrap()
                        } else {
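                            // Only a subset of rows belongs to this region:
                            // filter the batch, then re-encode it.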
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(batch.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let mut iter = FlightEncoder::default()
                                .encode(FlightMessage::RecordBatch(batch))
                                .into_iter();
                            let Some(flight_data) = iter.next() else {
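                                // The encoder emits at least one message for a
                                // non-empty batch.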
                                unreachable!()
                            };
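                            // Extra messages would be dictionary batches,
                            // which bulk insert does not support.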
                            ensure!(
                                iter.next().is_none(),
                                error::NotSupportedSnafu {
                                    feat: "bulk insert RecordBatch with dictionary arrays",
                                }
                            );
                            encode_timer.observe_duration();
                            (flight_data.data_header, flight_data.data_body)
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                region_id: region_id.as_u64(),
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    schema: schema_bytes,
                                    data_header: header,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        crate::metrics::DIST_INGEST_ROW_COUNT
            .with_label_values(&[db_name.as_str()])
            .inc_by(rows_inserted as u64);
        Ok(rows_inserted)
    }

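    /// Notifies the flownodes that read this table of the time windows
    /// touched by `record_batch`, so the corresponding flow results can be
    /// marked dirty. Runs in a background task; errors are logged but never
    /// fail the insert.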
    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
        let table_id = table_info.table_id();
        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
        let node_manager = self.node_manager.clone();
        common_runtime::spawn_global(async move {
            let result = table_flownode_set_cache
                .get(table_id)
                .await
                .context(error::RequestInsertsSnafu);
            let flownodes = match result {
                Ok(flownodes) => flownodes.unwrap_or_default(),
                Err(e) => {
                    error!(e; "Failed to get flownodes for table id: {}", table_id);
                    return;
                }
            };

            let peers: HashSet<_> = flownodes.values().cloned().collect();
            if peers.is_empty() {
                return;
            }

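            // A table schema always defines a time index column, so the
            // unwrap below is expected not to panic.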
            let Ok(timestamps) = extract_timestamps(
                &record_batch,
                &table_info
                    .meta
                    .schema
                    .timestamp_column()
                    .as_ref()
                    .unwrap()
                    .name,
            )
            .inspect_err(|e| {
                error!(e; "Failed to extract timestamps from record batch");
            }) else {
                return;
            };

            for peer in peers {
                let node_manager = node_manager.clone();
                let timestamps = timestamps.clone();
                common_runtime::spawn_global(async move {
                    if let Err(e) = node_manager
                        .flownode(&peer)
                        .await
                        .handle_mark_window_dirty(DirtyWindowRequests {
                            requests: vec![DirtyWindowRequest {
                                table_id,
                                timestamps,
                            }],
                        })
                        .await
                        .context(error::RequestInsertsSnafu)
                    {
                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
                    }
                });
            }
        });
    }
}

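/// Extracts the values of the timestamp column named `timestamp_index_name`
/// from `rb` as `i64` values, skipping nulls.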
fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
    let ts_col = rb
        .column_by_name(timestamp_index_name)
        .context(error::ColumnNotFoundSnafu {
            msg: timestamp_index_name,
        })?;
    if rb.num_rows() == 0 {
        return Ok(vec![]);
    }
    let (primitive, _) =
        datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
            error::InvalidTimeIndexTypeSnafu {
                ty: ts_col.data_type().clone(),
            }
        })?;
    Ok(primitive.iter().flatten().collect())
}