1use std::sync::Arc;
16
17use ahash::{HashMap, HashMapExt};
18use api::v1::region::{
19 bulk_insert_request, region_request, ArrowIpc, BulkInsertRequest, RegionRequest,
20 RegionRequestHeader,
21};
22use bytes::Bytes;
23use common_base::AffectedRows;
24use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
25use common_grpc::FlightData;
26use common_recordbatch::RecordBatch;
27use common_telemetry::tracing_context::TracingContext;
28use datatypes::schema::Schema;
29use prost::Message;
30use snafu::ResultExt;
31use store_api::storage::RegionId;
32use table::metadata::TableId;
33
34use crate::insert::Inserter;
35use crate::{error, metrics};
36
impl Inserter {
    /// Handles a bulk insert whose payload is a raw Arrow Flight data frame.
    ///
    /// Decodes `data` into a record batch, splits its rows across the table's
    /// partitions, and forwards each region's rows to the datanode leading
    /// that region. Returns the total number of affected rows reported by the
    /// datanodes.
    ///
    /// Returns `Ok(0)` without writing anything when the decoded Flight
    /// message is not a record batch (e.g. a schema-only frame).
    ///
    /// # Errors
    /// Fails if the Flight frame cannot be decoded, the partition rule cannot
    /// be found or applied, a region leader cannot be resolved, row filtering
    /// or re-encoding fails, or any datanode rejects its region request.
    pub async fn handle_bulk_insert(
        &self,
        table_id: TableId,
        decoder: &mut FlightDecoder,
        data: FlightData,
    ) -> error::Result<AffectedRows> {
        // Decode the incoming Flight frame into a record batch, timing the
        // decode step separately from the rest of the request handling.
        let decode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["decode_request"])
            .start_timer();
        // Captured before decoding so the size metric reflects the wire body.
        let body_size = data.data_body.len();
        let message = decoder
            .try_decode(&data)
            .context(error::DecodeFlightDataSnafu)?;
        // Only record-batch messages carry rows; any other message kind is a
        // no-op for bulk insert.
        let FlightMessage::Recordbatch(rb) = message else {
            return Ok(0);
        };
        let record_batch = rb.df_record_batch();
        decode_timer.observe_duration();
        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

        // Re-encode the schema once up front; the same bytes are attached to
        // every per-region request built below.
        // NOTE(review): `decoder.schema()` is unwrapped here — this panics if
        // the decoder yielded a record batch without having seen a schema;
        // presumably the decoder guarantees a schema by this point — confirm.
        let schema_message = FlightEncoder::default()
            .encode(FlightMessage::Schema(decoder.schema().unwrap().clone()));
        let schema_bytes = Bytes::from(schema_message.encode_to_vec());

        // Split the batch into per-region row masks using the table's
        // partition rule.
        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let partition_rule = self
            .partition_manager
            .find_table_partition_rule(table_id)
            .await
            .context(error::InvalidPartitionSnafu)?;

        let region_masks = partition_rule
            .split_record_batch(record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

        // Fast path: all rows map to a single region. Forward the original,
        // already-encoded Flight payload as-is — no filtering or re-encoding
        // of the record batch is needed.
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            // Exactly one entry exists, so `next()` cannot return `None`.
            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            // Serialize the untouched Flight frame as the request payload.
            let payload = {
                let _encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                    .with_label_values(&["encode"])
                    .start_timer();
                Bytes::from(data.encode_to_vec())
            };
            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        region_id: region_id.as_u64(),
                        schema: schema_bytes,
                        payload,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            return datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
        }

        // Slow path: rows span multiple regions. Group the per-region masks by
        // the datanode that leads each region so requests can be dispatched
        // per peer.
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());
        // Convert the Arrow schema to the project schema once; each spawned
        // task needs it to rebuild a RecordBatch after filtering.
        let record_batch_schema =
            Arc::new(Schema::try_from(record_batch.schema()).context(error::ConvertSchemaSnafu)?);

        // Lazily-encoded copy of the original Flight frame, shared (as cheap
        // `Bytes` clones) by every mask that selects all rows — so the full
        // payload is serialized at most once.
        let mut raw_data_bytes = None;
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
                // Clone the cheaply-cloneable handles that move into the task.
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let record_batch_schema = record_batch_schema.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
                // Encode the raw frame only for all-row masks, and only once.
                let raw_data = if mask.select_all() {
                    Some(
                        raw_data_bytes
                            .get_or_insert_with(|| Bytes::from(data.encode_to_vec()))
                            .clone(),
                    )
                } else {
                    None
                };
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let payload = if mask.select_all() {
                            // `raw_data` was populated above for this case,
                            // so the unwrap cannot fail.
                            raw_data.unwrap()
                        } else {
                            // Keep only this region's rows, then re-encode
                            // them as a fresh Flight message.
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let rb = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(rb.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let batch =
                                RecordBatch::try_from_df_record_batch(record_batch_schema, rb)
                                    .context(error::BuildRecordBatchSnafu)?;
                            let payload = Bytes::from(
                                FlightEncoder::default()
                                    .encode(FlightMessage::Recordbatch(batch))
                                    .encode_to_vec(),
                            );
                            encode_timer.observe_duration();
                            payload
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    region_id: region_id.as_u64(),
                                    schema: schema_bytes,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

        // Wait for every datanode task; a join failure (panic/cancel) aborts
        // the whole request, and each task's own Result is surfaced below.
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        // Sum affected rows across regions, propagating the first task error.
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        Ok(rows_inserted)
    }
}