1use std::collections::HashSet;
16
17use ahash::{HashMap, HashMapExt};
18use api::v1::flow::{DirtyWindowRequest, DirtyWindowRequests};
19use api::v1::region::{
20 BulkInsertRequest, RegionRequest, RegionRequestHeader, bulk_insert_request, region_request,
21};
22use api::v1::{ArrowIpc, PartitionExprVersion};
23use arrow::array::Array;
24use arrow::record_batch::RecordBatch;
25use bytes::Bytes;
26use common_base::AffectedRows;
27use common_grpc::FlightData;
28use common_grpc::flight::{FlightEncoder, FlightMessage};
29use common_telemetry::error;
30use common_telemetry::tracing_context::TracingContext;
31use snafu::{OptionExt, ResultExt, ensure};
32use store_api::storage::RegionId;
33use table::TableRef;
34use table::metadata::TableInfoRef;
35
36use crate::insert::Inserter;
37use crate::{error, metrics};
38
impl Inserter {
    /// Handles a bulk insert of one Arrow `RecordBatch` into a table.
    ///
    /// Splits the batch by the table's partition rule, then:
    /// - fast path: if every row lands in a single region, forwards the
    ///   already-encoded Flight payload (`raw_flight_data`) to that region's
    ///   leader datanode without filtering or re-encoding;
    /// - slow path: filters the batch per region, re-encodes each slice as a
    ///   Flight message, and sends all region requests concurrently.
    ///
    /// Also spawns a best-effort background notification to flownodes marking
    /// the batch's time windows dirty (see `maybe_update_flow_dirty_window`).
    ///
    /// Returns the total number of affected rows reported by the datanodes.
    pub async fn handle_bulk_insert(
        &self,
        table: TableRef,
        raw_flight_data: FlightData,
        record_batch: RecordBatch,
        schema_bytes: Bytes,
    ) -> error::Result<AffectedRows> {
        let table_info = table.table_info();
        let table_id = table_info.table_id();
        let db_name = table_info.get_db_string();

        // Nothing to do for an empty batch.
        if record_batch.num_rows() == 0 {
            return Ok(0);
        }

        let body_size = raw_flight_data.data_body.len();
        // Fire-and-forget: the dirty-window notification runs on a spawned
        // task and never blocks or fails the insert itself.
        self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone());

        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

        // Compute per-region row masks from the table's partition rule;
        // the timer covers rule lookup + batch splitting.
        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let (partition_rule, partition_versions) = self
            .partition_manager
            .find_table_partition_rule(&table_info)
            .await
            .context(error::InvalidPartitionSnafu)?;

        let region_masks = partition_rule
            .split_record_batch(&record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

        // Fast path: all rows belong to one region, so the raw Flight
        // header/body can be forwarded as-is.
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let partition_expr_version = partition_versions
                .get(&region_number)
                .copied()
                .unwrap_or_default();
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;

            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    region_id: region_id.as_u64(),
                    partition_expr_version: partition_expr_version
                        .map(|value| PartitionExprVersion { value }),
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        schema: schema_bytes.clone(),
                        data_header: raw_flight_data.data_header,
                        payload: raw_flight_data.data_body,
                    })),
                })),
            };

            // Timer observes on drop, covering the datanode round trip.
            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            let result = datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
            // Only count rows toward the ingest metric on success.
            if let Ok(rows) = result {
                crate::metrics::DIST_INGEST_ROW_COUNT
                    .with_label_values(&[db_name.as_str()])
                    .inc_by(rows as u64);
            }
            return result;
        }

        // Slow path: group (region, mask) pairs by region-leader datanode.
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());

        // Spawn one task per non-empty region; filtering/encoding happens
        // inside the task so regions proceed concurrently.
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
                // Regions that received no rows are skipped entirely.
                if mask.select_none() {
                    continue;
                }
                let partition_expr_version = partition_versions
                    .get(&region_id.region_number())
                    .copied()
                    .unwrap_or_default();
                // Clone handles/payloads before moving them into the task.
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
                // If the mask selects every row, reuse the raw Flight
                // header/body and skip the filter + re-encode below.
                let raw_header_and_data = if mask.select_all() {
                    Some((
                        raw_flight_data.data_header.clone(),
                        raw_flight_data.data_body.clone(),
                    ))
                } else {
                    None
                };
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let (header, payload) = if mask.select_all() {
                            // Populated above whenever select_all() holds.
                            raw_header_and_data.unwrap()
                        } else {
                            // Keep only this region's rows...
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(batch.num_rows() as f64);

                            // ...and re-encode them as a Flight message.
                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let mut iter = FlightEncoder::default()
                                .encode(FlightMessage::RecordBatch(batch))
                                .into_iter();
                            // Encoding a RecordBatch always yields at least
                            // one FlightData frame.
                            let Some(flight_data) = iter.next() else {
                                unreachable!()
                            };
                            // Extra frames indicate dictionary arrays, which
                            // bulk insert does not support.
                            ensure!(
                                iter.next().is_none(),
                                error::NotSupportedSnafu {
                                    feat: "bulk insert RecordBatch with dictionary arrays",
                                }
                            );
                            encode_timer.observe_duration();
                            (flight_data.data_header, flight_data.data_body)
                        };
                        // Timer observes on drop, covering the round trip.
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                region_id: region_id.as_u64(),
                                partition_expr_version: partition_expr_version
                                    .map(|value| PartitionExprVersion { value }),
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    schema: schema_bytes,
                                    data_header: header,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

        // Wait for every region request; a join failure or any region error
        // fails the whole bulk insert.
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        crate::metrics::DIST_INGEST_ROW_COUNT
            .with_label_values(&[db_name.as_str()])
            .inc_by(rows_inserted as u64);
        Ok(rows_inserted)
    }

    /// Spawns a background, best-effort task that tells every flownode
    /// consuming this table which timestamps were just written (so flow
    /// time windows can be marked dirty).
    ///
    /// Never awaited by the caller; every failure is logged and swallowed so
    /// ingestion is never blocked by flow bookkeeping.
    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
        let table_id = table_info.table_id();
        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
        let node_manager = self.node_manager.clone();
        common_runtime::spawn_global(async move {
            let result = table_flownode_set_cache
                .get(table_id)
                .await
                .context(error::RequestInsertsSnafu);
            let flownodes = match result {
                Ok(flownodes) => flownodes.unwrap_or_default(),
                Err(e) => {
                    error!(e; "Failed to get flownodes for table id: {}", table_id);
                    return;
                }
            };

            // Deduplicate: multiple flows may live on the same flownode peer.
            let peers: HashSet<_> = flownodes.values().cloned().collect();
            if peers.is_empty() {
                return;
            }

            // NOTE(review): the unwrap assumes every table reaching this path
            // has a time-index column — confirm that invariant for callers.
            let Ok(timestamps) = extract_timestamps(
                &record_batch,
                &table_info
                    .meta
                    .schema
                    .timestamp_column()
                    .as_ref()
                    .unwrap()
                    .name,
            )
            .inspect_err(|e| {
                error!(e; "Failed to extract timestamps from record batch");
            }) else {
                return;
            };

            // Fan out one request per flownode, each on its own task.
            for peer in peers {
                let node_manager = node_manager.clone();
                let timestamps = timestamps.clone();
                common_runtime::spawn_global(async move {
                    if let Err(e) = node_manager
                        .flownode(&peer)
                        .await
                        .handle_mark_window_dirty(DirtyWindowRequests {
                            requests: vec![DirtyWindowRequest {
                                table_id,
                                timestamps,
                            }],
                        })
                        .await
                        .context(error::RequestInsertsSnafu)
                    {
                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
                    }
                });
            }
        });
    }
}
316
317fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
319 let ts_col = rb
320 .column_by_name(timestamp_index_name)
321 .context(error::ColumnNotFoundSnafu {
322 msg: timestamp_index_name,
323 })?;
324 if rb.num_rows() == 0 {
325 return Ok(vec![]);
326 }
327 let (primitive, _) =
328 datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
329 error::InvalidTimeIndexTypeSnafu {
330 ty: ts_col.data_type().clone(),
331 }
332 })?;
333 Ok(primitive.iter().flatten().collect())
334}