operator/bulk_insert.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

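//! Bulk insert handling: splits Arrow IPC encoded record batches by the
//! table's partition rule and routes each partition to the datanode region
//! that owns it.
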
use std::collections::HashSet;

use ahash::{HashMap, HashMapExt};
use api::v1::ArrowIpc;
use api::v1::flow::{DirtyWindowRequest, DirtyWindowRequests};
use api::v1::region::{
    BulkInsertRequest, RegionRequest, RegionRequestHeader, bulk_insert_request, region_request,
};
use arrow::array::Array;
use arrow::record_batch::RecordBatch;
use bytes::Bytes;
use common_base::AffectedRows;
use common_grpc::FlightData;
use common_grpc::flight::{FlightEncoder, FlightMessage};
use common_telemetry::error;
use common_telemetry::tracing_context::TracingContext;
use snafu::{OptionExt, ResultExt, ensure};
use store_api::storage::RegionId;
use table::TableRef;
use table::metadata::TableInfoRef;

use crate::insert::Inserter;
use crate::{error, metrics};

impl Inserter {
    /// Handle bulk insert request.
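    ///
    /// Splits `record_batch` by the table's partition rule and sends each
    /// partition to the datanode hosting the owning region. `raw_flight_data`
    /// carries the already-encoded Arrow IPC frame so regions that receive
    /// every row (including the single-region fast path) can forward it
    /// without re-encoding.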
    pub async fn handle_bulk_insert(
        &self,
        table: TableRef,
        raw_flight_data: FlightData,
        record_batch: RecordBatch,
        schema_bytes: Bytes,
    ) -> error::Result<AffectedRows> {
        let table_info = table.table_info();
        let table_id = table_info.table_id();
        let db_name = table_info.get_db_string();

        if record_batch.num_rows() == 0 {
            return Ok(0);
        }

        let body_size = raw_flight_data.data_body.len();
        // TODO(yingwen): Fill impure default values in the record batch. Note that we should
        // override `raw_flight_data` if we have to fill defaults.
        // Notify flownodes to update dirty timestamps if any flow is configured.
        self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone());

        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let partition_rule = self
            .partition_manager
            .find_table_partition_rule(&table_info)
            .await
            .context(error::InvalidPartitionSnafu)?;

        // Find the target partition for each row in the record batch.
        let region_masks = partition_rule
            .split_record_batch(&record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

        // Fast path: only one region is involved.
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            // SAFETY: region_masks is checked above to contain exactly one entry.
            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;

            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    region_id: region_id.as_u64(),
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        schema: schema_bytes.clone(),
                        data_header: raw_flight_data.data_header,
                        payload: raw_flight_data.data_body,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            let result = datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
            if let Ok(rows) = result {
                crate::metrics::DIST_INGEST_ROW_COUNT
                    .with_label_values(&[db_name.as_str()])
                    .inc_by(rows as u64);
            }
            return result;
        }

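        // Slow path: more than one region is involved, so group the region
        // masks by the datanode hosting each region leader.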
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());

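        // Spawn one task per target region: filter the batch with the region's
        // row mask, re-encode it as Arrow IPC (unless the mask selects every
        // row), and send the bulk insert request to the region's datanode.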
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
                if mask.select_none() {
                    continue;
                }
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
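                // Clone the raw encoded frame up front only when the mask selects
                // every row, so the spawned task can forward it without re-encoding.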
                let raw_header_and_data = if mask.select_all() {
                    Some((
                        raw_flight_data.data_header.clone(),
                        raw_flight_data.data_body.clone(),
                    ))
                } else {
                    None
                };
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let (header, payload) = if mask.select_all() {
                            // SAFETY: `raw_header_and_data` is always `Some` when the mask
                            // selects all rows, so the raw frame can be forwarded without
                            // re-encoding.
                            raw_header_and_data.unwrap()
                        } else {
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(batch.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let mut iter = FlightEncoder::default()
                                .encode(FlightMessage::RecordBatch(batch))
                                .into_iter();
                            let Some(flight_data) = iter.next() else {
                                // SAFETY: `encode` returns a `Vec1`, which is guaranteed to
                                // have at least one element.
                                unreachable!()
                            };
                            ensure!(
                                iter.next().is_none(),
                                error::NotSupportedSnafu {
                                    feat: "bulk insert RecordBatch with dictionary arrays",
                                }
                            );
                            encode_timer.observe_duration();
                            (flight_data.data_header, flight_data.data_body)
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                region_id: region_id.as_u64(),
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    schema: schema_bytes,
                                    data_header: header,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

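        // Join all per-region tasks; a panicked task or a failed region request
        // fails the whole bulk insert.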
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        crate::metrics::DIST_INGEST_ROW_COUNT
            .with_label_values(&[db_name.as_str()])
            .inc_by(rows_inserted as u64);
        Ok(rows_inserted)
    }

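    /// Notifies flownodes that rows in `record_batch` have dirtied their time
    /// windows, if any flow is configured on this table. The notification runs
    /// in the background on the global runtime; the insert path does not wait
    /// for it and failures are only logged.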
    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
        let table_id = table_info.table_id();
        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
        let node_manager = self.node_manager.clone();
        common_runtime::spawn_global(async move {
            let result = table_flownode_set_cache
                .get(table_id)
                .await
                .context(error::RequestInsertsSnafu);
            let flownodes = match result {
                Ok(flownodes) => flownodes.unwrap_or_default(),
                Err(e) => {
                    error!(e; "Failed to get flownodes for table id: {}", table_id);
                    return;
                }
            };

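            // Deduplicate peers: several flows may map to the same flownode,
            // which only needs to be notified once.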
            let peers: HashSet<_> = flownodes.values().cloned().collect();
            if peers.is_empty() {
                return;
            }

            let Ok(timestamps) = extract_timestamps(
                &record_batch,
                &table_info
                    .meta
                    .schema
                    .timestamp_column()
                    .as_ref()
                    .unwrap()
                    .name,
            )
            .inspect_err(|e| {
                error!(e; "Failed to extract timestamps from record batch");
            }) else {
                return;
            };

            for peer in peers {
                let node_manager = node_manager.clone();
                let timestamps = timestamps.clone();
                common_runtime::spawn_global(async move {
                    if let Err(e) = node_manager
                        .flownode(&peer)
                        .await
                        .handle_mark_window_dirty(DirtyWindowRequests {
                            requests: vec![DirtyWindowRequest {
                                table_id,
                                timestamps,
                            }],
                        })
                        .await
                        .context(error::RequestInsertsSnafu)
                    {
                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
                    }
                });
            }
        });
    }
}

/// Extracts all non-null timestamp values from the record batch's time index
/// column. Returns an empty vector if the record batch has no rows.
fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
    let ts_col = rb
        .column_by_name(timestamp_index_name)
        .context(error::ColumnNotFoundSnafu {
            msg: timestamp_index_name,
        })?;
    if rb.num_rows() == 0 {
        return Ok(vec![]);
    }
    let (primitive, _) =
        datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
            error::InvalidTimeIndexTypeSnafu {
                ty: ts_col.data_type().clone(),
            }
        })?;
    Ok(primitive.iter().flatten().collect())
}