operator/bulk_insert.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;

use ahash::{HashMap, HashMapExt};
use api::v1::flow::DirtyWindowRequest;
use api::v1::region::{
    bulk_insert_request, region_request, BulkInsertRequest, RegionRequest, RegionRequestHeader,
};
use api::v1::ArrowIpc;
use arrow::array::{
    Array, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
    TimestampSecondArray,
};
use arrow::datatypes::{DataType, Int64Type, TimeUnit};
use arrow::record_batch::RecordBatch;
use common_base::AffectedRows;
use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
use common_grpc::FlightData;
use common_telemetry::error;
use common_telemetry::tracing_context::TracingContext;
use snafu::{OptionExt, ResultExt};
use store_api::storage::RegionId;
use table::metadata::TableInfoRef;
use table::TableRef;

use crate::insert::Inserter;
use crate::{error, metrics};

impl Inserter {
    /// Handle bulk insert request.
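    ///
    /// Decodes the Arrow Flight payload into a record batch, splits rows by
    /// the table's partition rule, and forwards each region's slice to the
    /// datanode hosting that region's leader.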
    pub async fn handle_bulk_insert(
        &self,
        table: TableRef,
        decoder: &mut FlightDecoder,
        data: FlightData,
    ) -> error::Result<AffectedRows> {
        let table_info = table.table_info();
        let table_id = table_info.table_id();
        let decode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["decode_request"])
            .start_timer();
        let body_size = data.data_body.len();
        // Build region server requests.
        let message = decoder
            .try_decode(&data)
            .context(error::DecodeFlightDataSnafu)?;
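        // Only record batch messages carry rows; any other Flight message
        // (e.g. a schema-only message) affects no rows.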
        let FlightMessage::RecordBatch(record_batch) = message else {
            return Ok(0);
        };
        decode_timer.observe_duration();

        // Notify flownodes to mark time windows dirty if any flow is
        // configured on this table.
        self.maybe_update_flow_dirty_window(table_info, record_batch.clone());

        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
        metrics::BULK_REQUEST_ROWS
            .with_label_values(&["raw"])
            .observe(record_batch.num_rows() as f64);

        // Safety: a record batch has been decoded, so the decoder must hold
        // the schema bytes by now.
        let schema_bytes = decoder.schema_bytes().unwrap();
        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["partition"])
            .start_timer();
        let partition_rule = self
            .partition_manager
            .find_table_partition_rule(table_id)
            .await
            .context(error::InvalidPartitionSnafu)?;

        // Find the destination region for each row in the record batch.
        let region_masks = partition_rule
            .split_record_batch(&record_batch)
            .context(error::SplitInsertSnafu)?;
        partition_timer.observe_duration();

        // Fast path: all rows fall into a single region, so the raw payload
        // can be forwarded as-is without filtering or re-encoding.
        if region_masks.len() == 1 {
            metrics::BULK_REQUEST_ROWS
                .with_label_values(&["rows_per_region"])
                .observe(record_batch.num_rows() as f64);

            // Safety: `region_masks` contains exactly one entry, checked above.
            let (region_number, _) = region_masks.into_iter().next().unwrap();
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            let request = RegionRequest {
                header: Some(RegionRequestHeader {
                    tracing_context: TracingContext::from_current_span().to_w3c(),
                    ..Default::default()
                }),
                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                    region_id: region_id.as_u64(),
                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                        schema: schema_bytes,
                        data_header: data.data_header,
                        payload: data.data_body,
                    })),
                })),
            };

            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                .with_label_values(&["datanode_handle"])
                .start_timer();
            let datanode = self.node_manager.datanode(&datanode).await;
            let result = datanode
                .handle(request)
                .await
                .context(error::RequestRegionSnafu)
                .map(|r| r.affected_rows);
            if let Ok(rows) = result {
                crate::metrics::DIST_INGEST_ROW_COUNT.inc_by(rows as u64);
            }
            return result;
        }

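        // Slow path: rows span multiple regions. Group each region's mask by
        // the datanode hosting that region's leader.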
        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
        for (region_number, mask) in region_masks {
            let region_id = RegionId::new(table_id, region_number);
            let datanode = self
                .partition_manager
                .find_region_leader(region_id)
                .await
                .context(error::FindRegionLeaderSnafu)?;
            mask_per_datanode
                .entry(datanode)
                .or_insert_with(Vec::new)
                .push((region_id, mask));
        }

        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
            .with_label_values(&["wait_all_datanode"])
            .start_timer();

        let mut handles = Vec::with_capacity(mask_per_datanode.len());

        // Raw data header and payload bytes, lazily cloned for masks that
        // select all rows of the batch.
        let mut raw_data_bytes = None;
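        // Fan out one task per (region, mask) pair; each task filters (or
        // reuses) the batch and sends it to the region's datanode.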
        for (peer, masks) in mask_per_datanode {
            for (region_id, mask) in masks {
                let rb = record_batch.clone();
                let schema_bytes = schema_bytes.clone();
                let node_manager = self.node_manager.clone();
                let peer = peer.clone();
                let raw_header_and_data = if mask.select_all() {
                    Some(
                        raw_data_bytes
                            .get_or_insert_with(|| {
                                (data.data_header.clone(), data.data_body.clone())
                            })
                            .clone(),
                    )
                } else {
                    None
                };
                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
                    common_runtime::spawn_global(async move {
                        let (header, payload) = if mask.select_all() {
                            // Safety: `raw_header_and_data` was populated above
                            // because `mask.select_all()` is true; reuse the raw
                            // bytes to avoid re-encoding.
                            raw_header_and_data.unwrap()
                        } else {
                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["filter"])
                                .start_timer();
                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
                                .context(error::ComputeArrowSnafu)?;
                            filter_timer.observe_duration();
                            metrics::BULK_REQUEST_ROWS
                                .with_label_values(&["rows_per_region"])
                                .observe(batch.num_rows() as f64);

                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                                .with_label_values(&["encode"])
                                .start_timer();
                            let flight_data =
                                FlightEncoder::default().encode(FlightMessage::RecordBatch(batch));
                            encode_timer.observe_duration();
                            (flight_data.data_header, flight_data.data_body)
                        };
                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
                            .with_label_values(&["datanode_handle"])
                            .start_timer();
                        let request = RegionRequest {
                            header: Some(RegionRequestHeader {
                                tracing_context: TracingContext::from_current_span().to_w3c(),
                                ..Default::default()
                            }),
                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
                                region_id: region_id.as_u64(),
                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
                                    schema: schema_bytes,
                                    data_header: header,
                                    payload,
                                })),
                            })),
                        };

                        let datanode = node_manager.datanode(&peer).await;
                        datanode
                            .handle(request)
                            .await
                            .context(error::RequestRegionSnafu)
                    });
                handles.push(handle);
            }
        }

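        // Wait for all per-region tasks. `try_join_all` fails here only on a
        // task join error (panic or cancellation); per-region request errors
        // surface below when the affected row counts are summed.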
        let region_responses = futures::future::try_join_all(handles)
            .await
            .context(error::JoinTaskSnafu)?;
        wait_all_datanode_timer.observe_duration();
        let mut rows_inserted: usize = 0;
        for res in region_responses {
            rows_inserted += res?.affected_rows;
        }
        crate::metrics::DIST_INGEST_ROW_COUNT.inc_by(rows_inserted as u64);
        Ok(rows_inserted)
    }

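    /// Asynchronously notifies all flownodes mapped to this table that the
    /// time windows covered by `record_batch` are dirty. The notification is
    /// fire-and-forget: errors are logged and do not fail the bulk insert.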
    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
        let table_id = table_info.table_id();
        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
        let node_manager = self.node_manager.clone();
        common_runtime::spawn_global(async move {
            let result = table_flownode_set_cache
                .get(table_id)
                .await
                .context(error::RequestInsertsSnafu);
            let flownodes = match result {
                Ok(flownodes) => flownodes.unwrap_or_default(),
                Err(e) => {
                    error!(e; "Failed to get flownodes for table id: {}", table_id);
                    return;
                }
            };

            let peers: HashSet<_> = flownodes.values().cloned().collect();
            if peers.is_empty() {
                return;
            }

            let Ok(timestamps) = extract_timestamps(
                &record_batch,
                &table_info
                    .meta
                    .schema
                    .timestamp_column()
                    .as_ref()
                    .unwrap()
                    .name,
            )
            .inspect_err(|e| {
                error!(e; "Failed to extract timestamps from record batch");
            }) else {
                return;
            };

            for peer in peers {
                let node_manager = node_manager.clone();
                let timestamps = timestamps.clone();
                common_runtime::spawn_global(async move {
                    if let Err(e) = node_manager
                        .flownode(&peer)
                        .await
                        .handle_mark_window_dirty(DirtyWindowRequest {
                            table_id,
                            timestamps,
                        })
                        .await
                        .context(error::RequestInsertsSnafu)
                    {
                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
                    }
                });
            }
        });
    }
}

/// Extracts all non-null timestamp values (as `i64` in the column's time unit)
/// from the time index column of the record batch. Returns an empty vector if
/// the record batch has no rows.
fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
    let ts_col = rb
        .column_by_name(timestamp_index_name)
        .context(error::ColumnNotFoundSnafu {
            msg: timestamp_index_name,
        })?;
    if rb.num_rows() == 0 {
        return Ok(vec![]);
    }
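    // All four Arrow timestamp array types are backed by `i64` buffers, so
    // `reinterpret_cast::<Int64Type>` converts them without copying data.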
    let primitive = match ts_col.data_type() {
        DataType::Timestamp(unit, _) => match unit {
            TimeUnit::Second => ts_col
                .as_any()
                .downcast_ref::<TimestampSecondArray>()
                .unwrap()
                .reinterpret_cast::<Int64Type>(),
            TimeUnit::Millisecond => ts_col
                .as_any()
                .downcast_ref::<TimestampMillisecondArray>()
                .unwrap()
                .reinterpret_cast::<Int64Type>(),
            TimeUnit::Microsecond => ts_col
                .as_any()
                .downcast_ref::<TimestampMicrosecondArray>()
                .unwrap()
                .reinterpret_cast::<Int64Type>(),
            TimeUnit::Nanosecond => ts_col
                .as_any()
                .downcast_ref::<TimestampNanosecondArray>()
                .unwrap()
                .reinterpret_cast::<Int64Type>(),
        },
        t => {
            return error::InvalidTimeIndexTypeSnafu { ty: t.clone() }.fail();
        }
    };
    Ok(primitive.iter().flatten().collect())
}