operator/bulk_insert.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;

use ahash::{HashMap, HashMapExt};
use api::v1::flow::DirtyWindowRequest;
use api::v1::region::{
    bulk_insert_request, region_request, BulkInsertRequest, RegionRequest, RegionRequestHeader,
};
use api::v1::ArrowIpc;
use arrow::array::Array;
use arrow::record_batch::RecordBatch;
use common_base::AffectedRows;
use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
use common_grpc::FlightData;
use common_telemetry::error;
use common_telemetry::tracing_context::TracingContext;
use snafu::{ensure, OptionExt, ResultExt};
use store_api::storage::RegionId;
use table::metadata::TableInfoRef;
use table::TableRef;

use crate::insert::Inserter;
use crate::{error, metrics};

impl Inserter {
    /// Handles a bulk insert request: decodes the Arrow Flight payload, splits the rows with
    /// the table's partition rule, and forwards the data to the datanodes that host the target
    /// regions.
    pub async fn handle_bulk_insert(
        &self,
        table: TableRef,
        decoder: &mut FlightDecoder,
        data: FlightData,
    ) -> error::Result<AffectedRows> {
        let table_info = table.table_info();
        let table_id = table_info.table_id();
        let db_name = table_info.get_db_string();
        let decode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["decode_request"])
            .start_timer();
        let body_size = data.data_body.len();
        // Build region server requests
        let message = decoder
            .try_decode(&data)
            .context(error::DecodeFlightDataSnafu)?
            .context(error::NotSupportedSnafu {
                feat: "bulk insert RecordBatch with dictionary arrays",
            })?;
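        // Only a `RecordBatch` message carries rows to insert; any other flight message
        // (e.g. a schema-only message) results in zero affected rows.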
        let FlightMessage::RecordBatch(record_batch) = message else {
            return Ok(0);
        };
        decode_timer.observe_duration();

        if record_batch.num_rows() == 0 {
            return Ok(0);
        }

        // notify flownode to update dirty timestamps if flow is configured.
        self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone());

        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

        // Safety: when we reach here, the schema must be present.
        let schema_bytes = decoder.schema_bytes().unwrap();
        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let partition_rule = self
            .partition_manager
            .find_table_partition_rule(&table_info)
            .await
            .context(error::InvalidPartitionSnafu)?;

        // find partitions for each row in the record batch
        let region_masks = partition_rule
            .split_record_batch(&record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

        // fast path: only one region.
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            // SAFETY: region masks length checked
            let (region_number, _) = region_masks.into_iter().next().unwrap();
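            // With a single region mask the whole batch targets one region, so the original
            // flight payload can be forwarded as-is without filtering or re-encoding.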
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    region_id: region_id.as_u64(),
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        schema: schema_bytes,
                        data_header: data.data_header,
                        payload: data.data_body,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            let result = datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
            if let Ok(rows) = result {
                crate::metrics::DIST_INGEST_ROW_COUNT
                    .with_label_values(&[db_name.as_str()])
                    .inc_by(rows as u64);
            }
            return result;
        }

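        // Slow path: rows span multiple regions. Group the per-region row masks by the
        // datanode that currently leads each region.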
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());

        // raw data header and payload bytes.
        let mut raw_data_bytes = None;
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
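                // Skip regions that receive no rows from this batch.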
                if mask.select_none() {
                    continue;
                }
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
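                // When every row targets this region, keep a clone of the original encoded
                // payload so the spawned task can forward it without re-encoding.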
                let raw_header_and_data = if mask.select_all() {
                    Some(
                        raw_data_bytes
                            .get_or_insert_with(|| {
                                (data.data_header.clone(), data.data_body.clone())
                            })
                            .clone(),
                    )
                } else {
                    None
                };
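                // Spawn one task per (datanode, region) pair so filtering, encoding and the
                // datanode calls run concurrently.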
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let (header, payload) = if mask.select_all() {
                            // SAFETY: raw data must be present, we can avoid re-encoding.
                            raw_header_and_data.unwrap()
                        } else {
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(batch.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let mut iter = FlightEncoder::default()
                                .encode(FlightMessage::RecordBatch(batch))
                                .into_iter();
                            let Some(flight_data) = iter.next() else {
                                // Safety: `iter` is over a `Vec1`, which is guaranteed to have
                                // at least one element.
                                unreachable!()
                            };
                            ensure!(
                                iter.next().is_none(),
                                error::NotSupportedSnafu {
                                    feat: "bulk insert RecordBatch with dictionary arrays",
                                }
                            );
                            encode_timer.observe_duration();
                            (flight_data.data_header, flight_data.data_body)
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                region_id: region_id.as_u64(),
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    schema: schema_bytes,
                                    data_header: header,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

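        // Wait for all per-region tasks; a failed join or a per-region error fails the
        // whole bulk insert.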
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        crate::metrics::DIST_INGEST_ROW_COUNT
            .with_label_values(&[db_name.as_str()])
            .inc_by(rows_inserted as u64);
        Ok(rows_inserted)
    }

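    /// Notifies, on a best-effort basis, the flownodes that host flows on this table that the
    /// timestamps in `record_batch` have become dirty. Runs asynchronously; errors are logged
    /// and otherwise ignored.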
    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
        let table_id = table_info.table_id();
        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
        let node_manager = self.node_manager.clone();
        common_runtime::spawn_global(async move {
            let result = table_flownode_set_cache
                .get(table_id)
                .await
                .context(error::RequestInsertsSnafu);
            let flownodes = match result {
                Ok(flownodes) => flownodes.unwrap_or_default(),
                Err(e) => {
                    error!(e; "Failed to get flownodes for table id: {}", table_id);
                    return;
                }
            };

            let peers: HashSet<_> = flownodes.values().cloned().collect();
            if peers.is_empty() {
                return;
            }

            let Ok(timestamps) = extract_timestamps(
                &record_batch,
                &table_info
                    .meta
                    .schema
                    .timestamp_column()
                    .as_ref()
                    .unwrap()
                    .name,
            )
            .inspect_err(|e| {
                error!(e; "Failed to extract timestamps from record batch");
            }) else {
                return;
            };

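            // Fan out the dirty-window notification to every flownode concurrently.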
            for peer in peers {
                let node_manager = node_manager.clone();
                let timestamps = timestamps.clone();
                common_runtime::spawn_global(async move {
                    if let Err(e) = node_manager
                        .flownode(&peer)
                        .await
                        .handle_mark_window_dirty(DirtyWindowRequest {
                            table_id,
                            timestamps,
                        })
                        .await
                        .context(error::RequestInsertsSnafu)
                    {
                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
                    }
                });
            }
        });
    }
}

/// Extracts the values of the time index column (named `timestamp_index_name`) from the record
/// batch, skipping nulls. Returns an empty vector if the record batch has no rows.
fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
    let ts_col = rb
        .column_by_name(timestamp_index_name)
        .context(error::ColumnNotFoundSnafu {
            msg: timestamp_index_name,
        })?;
    if rb.num_rows() == 0 {
        return Ok(vec![]);
    }
    let (primitive, _) =
        datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
            error::InvalidTimeIndexTypeSnafu {
                ty: ts_col.data_type().clone(),
            }
        })?;
    Ok(primitive.iter().flatten().collect())
}