// operator/bulk_insert.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use ahash::{HashMap, HashMapExt};
use api::v1::region::{
    bulk_insert_request, region_request, ArrowIpc, BulkInsertRequest, RegionRequest,
    RegionRequestHeader,
};
use bytes::Bytes;
use common_base::AffectedRows;
use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
use common_grpc::FlightData;
use common_recordbatch::RecordBatch;
use common_telemetry::tracing_context::TracingContext;
use datatypes::schema::Schema;
use prost::Message;
use snafu::ResultExt;
use store_api::storage::RegionId;
use table::metadata::TableId;

use crate::insert::Inserter;
use crate::{error, metrics};

impl Inserter {
    /// Handle bulk insert request.
    ///
    /// Decodes the Arrow Flight `data` frame into a record batch, splits its rows
    /// across the table's regions using the table's partition rule, and forwards a
    /// per-region payload to the datanode currently leading each region. Returns
    /// the sum of affected rows reported by all contacted datanodes.
    ///
    /// A frame that does not decode to a record batch (e.g. a schema-only Flight
    /// message) is treated as a no-op and yields `Ok(0)`.
    ///
    /// # Errors
    /// - the Flight payload cannot be decoded;
    /// - the table's partition rule cannot be resolved, or the batch cannot be split;
    /// - a region leader cannot be located;
    /// - re-encoding a filtered batch fails, or a datanode rejects the request.
    pub async fn handle_bulk_insert(
        &self,
        table_id: TableId,
        decoder: &mut FlightDecoder,
        data: FlightData,
    ) -> error::Result<AffectedRows> {
        let decode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["decode_request"])
            .start_timer();
        // Captured before decoding so the raw wire size can be reported below.
        let body_size = data.data_body.len();
        // Build region server requests
        let message = decoder
            .try_decode(&data)
            .context(error::DecodeFlightDataSnafu)?;
        let FlightMessage::Recordbatch(rb) = message else {
            // Only record-batch frames carry rows; anything else inserts nothing.
            return Ok(0);
        };
        let record_batch = rb.df_record_batch();
        decode_timer.observe_duration();
        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

        // todo(hl): find a way to embed raw FlightData messages in greptimedb proto files so we don't have to encode here.

        // Encode the schema once up front; the same bytes are attached to every
        // per-region request built below.
        // safety: when reach here schema must be present.
        let schema_message = FlightEncoder::default()
            .encode(FlightMessage::Schema(decoder.schema().unwrap().clone()));
        let schema_bytes = Bytes::from(schema_message.encode_to_vec());

        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let partition_rule = self
            .partition_manager
            .find_table_partition_rule(table_id)
            .await
            .context(error::InvalidPartitionSnafu)?;

        // find partitions for each row in the record batch
        let region_masks = partition_rule
            .split_record_batch(record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

        // fast path: only one region. Every row maps to that region, so the
        // original FlightData frame can be forwarded as-is, skipping the
        // filter + re-encode work of the multi-region path below.
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            // SAFETY: region masks length checked
            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            // Re-serialize the original Flight frame as the request payload.
            let payload = {
                let _encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                    .with_label_values(&["encode"])
                    .start_timer();
                Bytes::from(data.encode_to_vec())
            };
            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    // Propagate the current span so the datanode can join the trace.
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        region_id: region_id.as_u64(),
                        schema: schema_bytes,
                        payload,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            return datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
        }

        // Slow path: group the region masks by the datanode that leads each
        // region, so requests to the same peer can be dispatched together.
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());
        // GreptimeDB schema wrapper needed to rebuild RecordBatches from the
        // filtered DataFusion batches inside the spawned tasks.
        let record_batch_schema =
            Arc::new(Schema::try_from(record_batch.schema()).context(error::ConvertSchemaSnafu)?);

        // Lazily-encoded copy of the original frame, created at most once and
        // shared (via cheap Bytes clones) by every mask that selects all rows.
        let mut raw_data_bytes = None;
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
                // Clone everything the spawned task captures by move.
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let record_batch_schema = record_batch_schema.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
                // If the mask selects the entire batch, reuse the original
                // encoded frame instead of filtering and re-encoding.
                let raw_data = if mask.select_all() {
                    Some(
                        raw_data_bytes
                            .get_or_insert_with(|| Bytes::from(data.encode_to_vec()))
                            .clone(),
                    )
                } else {
                    None
                };
                // One task per (region, mask): filter/encode if needed, then
                // send the request to the region's leader datanode.
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let payload = if mask.select_all() {
                            // SAFETY: raw data must be present, we can avoid re-encoding.
                            raw_data.unwrap()
                        } else {
                            // Keep only this region's rows, then re-encode the
                            // filtered batch as a Flight record-batch message.
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let rb = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(rb.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let batch =
                                RecordBatch::try_from_df_record_batch(record_batch_schema, rb)
                                    .context(error::BuildRecordBatchSnafu)?;
                            let payload = Bytes::from(
                                FlightEncoder::default()
                                    .encode(FlightMessage::Recordbatch(batch))
                                    .encode_to_vec(),
                            );
                            encode_timer.observe_duration();
                            payload
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    region_id: region_id.as_u64(),
                                    schema: schema_bytes,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

        // Wait for every spawned task. A task panic/cancellation surfaces as a
        // join error; a per-region failure surfaces via `res?` below.
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        Ok(rows_inserted)
    }
}