// operator/bulk_insert.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use ahash::{HashMap, HashMapExt};
18use api::v1::flow::{DirtyWindowRequest, DirtyWindowRequests};
19use api::v1::region::{
20    BulkInsertRequest, RegionRequest, RegionRequestHeader, bulk_insert_request, region_request,
21};
22use api::v1::{ArrowIpc, PartitionExprVersion};
23use arrow::array::Array;
24use arrow::record_batch::RecordBatch;
25use bytes::Bytes;
26use common_base::AffectedRows;
27use common_grpc::FlightData;
28use common_grpc::flight::{FlightEncoder, FlightMessage};
29use common_telemetry::error;
30use common_telemetry::tracing_context::TracingContext;
31use snafu::{OptionExt, ResultExt, ensure};
32use store_api::storage::RegionId;
33use table::TableRef;
34use table::metadata::TableInfoRef;
35
36use crate::insert::Inserter;
37use crate::{error, metrics};
38
impl Inserter {
    /// Handle bulk insert request.
    ///
    /// Splits `record_batch` by the table's partition rule and forwards the
    /// Arrow IPC payload to the datanode that hosts each target region:
    /// - Single-region fast path: forwards `raw_flight_data` bytes as-is,
    ///   avoiding any re-encoding.
    /// - Multi-region path: spawns one task per region; a region selecting
    ///   every row reuses the raw bytes, otherwise rows are filtered out of
    ///   `record_batch` and re-encoded to a Flight message.
    ///
    /// `schema_bytes` is the encoded IPC schema shared by every per-region
    /// request. Flownodes are notified (best-effort, in the background) of
    /// dirty time windows before dispatch.
    ///
    /// Returns the total number of affected rows reported by all datanodes.
    pub async fn handle_bulk_insert(
        &self,
        table: TableRef,
        raw_flight_data: FlightData,
        record_batch: RecordBatch,
        schema_bytes: Bytes,
    ) -> error::Result<AffectedRows> {
        let table_info = table.table_info();
        let table_id = table_info.table_id();
        let db_name = table_info.get_db_string();

        // Nothing to insert; skip partitioning and flow notification entirely.
        if record_batch.num_rows() == 0 {
            return Ok(0);
        }

        let body_size = raw_flight_data.data_body.len();
        // TODO(yingwen): Fill record batch impure default values. Note that we should override `raw_flight_data` if we have to fill defaults.
        // notify flownode to update dirty timestamps if flow is configured.
        self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone());

        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let (partition_rule, partition_versions) = self
            .partition_manager
            .find_table_partition_rule(&table_info)
            .await
            .context(error::InvalidPartitionSnafu)?;

        // find partitions for each row in the record batch
        let region_masks = partition_rule
            .split_record_batch(&record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

        // fast path: only one region.
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            // SAFETY: region masks length checked
            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let partition_expr_version = partition_versions
                .get(&region_number)
                .copied()
                .unwrap_or_default();
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;

            // All rows target the same region, so the already-encoded Flight
            // header/body can be forwarded verbatim — no filtering, no re-encode.
            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    region_id: region_id.as_u64(),
                    partition_expr_version: partition_expr_version
                        .map(|value| PartitionExprVersion { value }),
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        schema: schema_bytes.clone(),
                        data_header: raw_flight_data.data_header,
                        payload: raw_flight_data.data_body,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            let result = datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
            // Only count successfully ingested rows in the metric.
            if let Ok(rows) = result {
                crate::metrics::DIST_INGEST_ROW_COUNT
                    .with_label_values(&[db_name.as_str()])
                    .inc_by(rows as u64);
            }
            return result;
        }

        // Group the per-region row masks by the datanode that leads each region.
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());

        // Spawn one task per (datanode, region) pair so regions are written
        // concurrently; each task builds and sends its own region request.
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
                // No rows for this region; skip spawning a task.
                if mask.select_none() {
                    continue;
                }
                let partition_expr_version = partition_versions
                    .get(&region_id.region_number())
                    .copied()
                    .unwrap_or_default();
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
                // When every row goes to this region, keep the raw encoded
                // bytes around so the task can skip filter + re-encode.
                let raw_header_and_data = if mask.select_all() {
                    Some((
                        raw_flight_data.data_header.clone(),
                        raw_flight_data.data_body.clone(),
                    ))
                } else {
                    None
                };
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let (header, payload) = if mask.select_all() {
                            // SAFETY: raw data must be present, we can avoid re-encoding.
                            raw_header_and_data.unwrap()
                        } else {
                            // Keep only this region's rows, then encode them as
                            // a fresh Flight message.
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(batch.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let mut iter = FlightEncoder::default()
                                .encode(FlightMessage::RecordBatch(batch))
                                .into_iter();
                            let Some(flight_data) = iter.next() else {
                                // Safety: `iter` on a type of `Vec1`, which is guaranteed to have
                                // at least one element.
                                unreachable!()
                            };
                            // Dictionary arrays encode to multiple Flight
                            // messages; a single-message payload is required here.
                            ensure!(
                                iter.next().is_none(),
                                error::NotSupportedSnafu {
                                    feat: "bulk insert RecordBatch with dictionary arrays",
                                }
                            );
                            encode_timer.observe_duration();
                            (flight_data.data_header, flight_data.data_body)
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                region_id: region_id.as_u64(),
                                partition_expr_version: partition_expr_version
                                    .map(|value| PartitionExprVersion { value }),
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    schema: schema_bytes,
                                    data_header: header,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

        // Wait for every region write; the first join failure or region error
        // aborts the whole bulk insert with that error.
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        crate::metrics::DIST_INGEST_ROW_COUNT
            .with_label_values(&[db_name.as_str()])
            .inc_by(rows_inserted as u64);
        Ok(rows_inserted)
    }

    /// Notify flownodes that the time windows covered by `record_batch` are
    /// dirty, so that flows on this table re-compute them.
    ///
    /// Fire-and-forget: all work runs on a spawned background task (one more
    /// per flownode peer), and every failure is only logged — this never
    /// blocks or fails the insert path.
    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
        let table_id = table_info.table_id();
        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
        let node_manager = self.node_manager.clone();
        common_runtime::spawn_global(async move {
            let result = table_flownode_set_cache
                .get(table_id)
                .await
                .context(error::RequestInsertsSnafu);
            let flownodes = match result {
                Ok(flownodes) => flownodes.unwrap_or_default(),
                Err(e) => {
                    error!(e; "Failed to get flownodes for table id: {}", table_id);
                    return;
                }
            };

            // Deduplicate peers: multiple flows may live on the same flownode.
            let peers: HashSet<_> = flownodes.values().cloned().collect();
            if peers.is_empty() {
                // No flow consumes this table; nothing to notify.
                return;
            }

            // NOTE(review): `unwrap()` assumes every table has a time index
            // column — confirm this invariant holds for all callers.
            let Ok(timestamps) = extract_timestamps(
                &record_batch,
                &table_info
                    .meta
                    .schema
                    .timestamp_column()
                    .as_ref()
                    .unwrap()
                    .name,
            )
            .inspect_err(|e| {
                error!(e; "Failed to extract timestamps from record batch");
            }) else {
                return;
            };

            // Notify each flownode concurrently; failures are independent and
            // only logged.
            for peer in peers {
                let node_manager = node_manager.clone();
                let timestamps = timestamps.clone();
                common_runtime::spawn_global(async move {
                    if let Err(e) = node_manager
                        .flownode(&peer)
                        .await
                        .handle_mark_window_dirty(DirtyWindowRequests {
                            requests: vec![DirtyWindowRequest {
                                table_id,
                                timestamps,
                            }],
                        })
                        .await
                        .context(error::RequestInsertsSnafu)
                    {
                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
                    }
                });
            }
        });
    }
}
316
/// Collects all non-null timestamp values (as `i64`) from the column named
/// `timestamp_index_name` in `rb`.
///
/// Returns an empty vector when the record batch has no rows. Errors when the
/// column is missing or is not a timestamp-typed array.
fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
    let ts_col = rb
        .column_by_name(timestamp_index_name)
        .context(error::ColumnNotFoundSnafu {
            msg: timestamp_index_name,
        })?;
    if rb.num_rows() == 0 {
        return Ok(vec![]);
    }
    let (primitive, _) =
        datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
            error::InvalidTimeIndexTypeSnafu {
                ty: ts_col.data_type().clone(),
            }
        })?;
    // `flatten` skips null slots, so only concrete timestamp values are kept.
    Ok(primitive.iter().flatten().collect())
}
334}