use std::collections::HashSet;

use ahash::{HashMap, HashMapExt};
use api::v1::flow::DirtyWindowRequest;
use api::v1::region::{
    bulk_insert_request, region_request, BulkInsertRequest, RegionRequest, RegionRequestHeader,
};
use api::v1::ArrowIpc;
use arrow::array::{
    Array, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
    TimestampSecondArray,
};
use arrow::datatypes::{DataType, Int64Type, TimeUnit};
use arrow::record_batch::RecordBatch;
use common_base::AffectedRows;
use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
use common_grpc::FlightData;
use common_telemetry::error;
use common_telemetry::tracing_context::TracingContext;
use snafu::{OptionExt, ResultExt};
use store_api::storage::RegionId;
use table::metadata::TableInfoRef;
use table::TableRef;

use crate::insert::Inserter;
use crate::{error, metrics};

impl Inserter {
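    /// Handles a bulk insert request: decodes the Arrow Flight data into a record batch,
    /// splits it according to the table's partition rule, and forwards each part to the
    /// leader datanode of the corresponding region.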
    pub async fn handle_bulk_insert(
        &self,
        table: TableRef,
        decoder: &mut FlightDecoder,
        data: FlightData,
    ) -> error::Result<AffectedRows> {
        let table_info = table.table_info();
        let table_id = table_info.table_id();
        let decode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["decode_request"])
            .start_timer();
        let body_size = data.data_body.len();
        let message = decoder
            .try_decode(&data)
            .context(error::DecodeFlightDataSnafu)?;
        let FlightMessage::RecordBatch(record_batch) = message else {
            return Ok(0);
        };
        decode_timer.observe_duration();

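        // Notify interested flownodes about the dirty time windows in the background;
        // this runs on a spawned task and does not block the write path.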
        self.maybe_update_flow_dirty_window(table_info, record_batch.clone());

        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

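        // Raw schema bytes cached by the decoder; reused for every region request below.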
        let schema_bytes = decoder.schema_bytes().unwrap();
        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let partition_rule = self
            .partition_manager
            .find_table_partition_rule(table_id)
            .await
            .context(error::InvalidPartitionSnafu)?;

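        // Split the record batch into per-region row masks according to the partition rule.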
        let region_masks = partition_rule
            .split_record_batch(&record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

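        // Fast path: all rows fall into a single region, so forward the original flight
        // data as-is without filtering or re-encoding.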
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

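            // Safety: region_masks contains exactly one entry here.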
            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    region_id: region_id.as_u64(),
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        schema: schema_bytes,
                        data_header: data.data_header,
                        payload: data.data_body,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            let result = datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
            if let Ok(rows) = result {
                crate::metrics::DIST_INGEST_ROW_COUNT.inc_by(rows as u64);
            }
            return result;
        }

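        // Multiple regions are involved: group the row masks by the datanode that leads
        // each target region so requests can be fanned out per peer.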
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());

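        // Lazily cache the raw flight header and body so that masks selecting every row
        // can reuse the original payload instead of re-encoding the record batch.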
        let mut raw_data_bytes = None;
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
                let raw_header_and_data = if mask.select_all() {
                    Some(
                        raw_data_bytes
                            .get_or_insert_with(|| {
                                (data.data_header.clone(), data.data_body.clone())
                            })
                            .clone(),
                    )
                } else {
                    None
                };
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let (header, payload) = if mask.select_all() {
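                            // Safety: populated before the task was spawned because the
                            // mask selects all rows.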
                            raw_header_and_data.unwrap()
                        } else {
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(batch.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let flight_data =
                                FlightEncoder::default().encode(FlightMessage::RecordBatch(batch));
                            encode_timer.observe_duration();
                            (flight_data.data_header, flight_data.data_body)
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                region_id: region_id.as_u64(),
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    schema: schema_bytes,
                                    data_header: header,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

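        // Wait for all datanode requests to finish and sum up the affected rows.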
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        crate::metrics::DIST_INGEST_ROW_COUNT.inc_by(rows_inserted as u64);
        Ok(rows_inserted)
    }

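    /// Asynchronously notifies the flownodes serving this table about the timestamps
    /// touched by this write so they can mark the affected time windows as dirty.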
    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
        let table_id = table_info.table_id();
        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
        let node_manager = self.node_manager.clone();
        common_runtime::spawn_global(async move {
            let result = table_flownode_set_cache
                .get(table_id)
                .await
                .context(error::RequestInsertsSnafu);
            let flownodes = match result {
                Ok(flownodes) => flownodes.unwrap_or_default(),
                Err(e) => {
                    error!(e; "Failed to get flownodes for table id: {}", table_id);
                    return;
                }
            };

            let peers: HashSet<_> = flownodes.values().cloned().collect();
            if peers.is_empty() {
                return;
            }

            let Ok(timestamps) = extract_timestamps(
                &record_batch,
                &table_info
                    .meta
                    .schema
                    .timestamp_column()
                    .as_ref()
                    .unwrap()
                    .name,
            )
            .inspect_err(|e| {
                error!(e; "Failed to extract timestamps from record batch");
            }) else {
                return;
            };

            for peer in peers {
                let node_manager = node_manager.clone();
                let timestamps = timestamps.clone();
                common_runtime::spawn_global(async move {
                    if let Err(e) = node_manager
                        .flownode(&peer)
                        .await
                        .handle_mark_window_dirty(DirtyWindowRequest {
                            table_id,
                            timestamps,
                        })
                        .await
                        .context(error::RequestInsertsSnafu)
                    {
                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
                    }
                });
            }
        });
    }
}

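/// Extracts the values of the timestamp column named `timestamp_index_name` from `rb`
/// as raw `i64`s, regardless of the column's time unit.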
fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
    let ts_col = rb
        .column_by_name(timestamp_index_name)
        .context(error::ColumnNotFoundSnafu {
            msg: timestamp_index_name,
        })?;
    if rb.num_rows() == 0 {
        return Ok(vec![]);
    }
    let primitive = match ts_col.data_type() {
        DataType::Timestamp(unit, _) => match unit {
            TimeUnit::Second => ts_col
                .as_any()
                .downcast_ref::<TimestampSecondArray>()
                .unwrap()
                .reinterpret_cast::<Int64Type>(),
            TimeUnit::Millisecond => ts_col
                .as_any()
                .downcast_ref::<TimestampMillisecondArray>()
                .unwrap()
                .reinterpret_cast::<Int64Type>(),
            TimeUnit::Microsecond => ts_col
                .as_any()
                .downcast_ref::<TimestampMicrosecondArray>()
                .unwrap()
                .reinterpret_cast::<Int64Type>(),
            TimeUnit::Nanosecond => ts_col
                .as_any()
                .downcast_ref::<TimestampNanosecondArray>()
                .unwrap()
                .reinterpret_cast::<Int64Type>(),
        },
        t => {
            return error::InvalidTimeIndexTypeSnafu { ty: t.clone() }.fail();
        }
    };
    Ok(primitive.iter().flatten().collect())
}