frontend/
limiter.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17
18use api::v1::column::Values;
19use api::v1::greptime_request::Request;
20use api::v1::value::ValueData;
21use api::v1::{
22    Decimal128, InsertRequests, IntervalMonthDayNano, RowInsertRequest, RowInsertRequests,
23};
24use common_telemetry::{debug, warn};
25use pipeline::ContextReq;
26
27pub(crate) type LimiterRef = Arc<Limiter>;
28
/// A frontend request limiter that controls the total size of in-flight write requests.
///
/// All admitted requests share one atomic byte counter. Each admission is represented by an
/// [`InFlightWriteBytesCounter`], which subtracts its bytes from that counter when dropped.
pub(crate) struct Limiter {
    /// Upper bound, in bytes, on the sum of all write requests currently being processed.
    max_in_flight_write_bytes: u64,
    /// Running total of bytes reserved by requests that have been admitted and not yet dropped.
    in_flight_write_bytes: Arc<AtomicU64>,
}
37
/// Guard representing one request's reservation of in-flight write bytes.
///
/// The reserved amount is subtracted from the shared total when the guard is dropped.
pub(crate) struct InFlightWriteBytesCounter {
    /// Shared total of in-flight write bytes, owned jointly with the `Limiter`.
    in_flight_write_bytes: Arc<AtomicU64>,
    /// Number of bytes this guard reserved and will release on drop.
    processing_write_bytes: u64,
}
46
47impl InFlightWriteBytesCounter {
48    /// Creates a new InFlightWriteBytesCounter. It will decrease the in-flight write bytes when dropped.
49    pub fn new(in_flight_write_bytes: Arc<AtomicU64>, processing_write_bytes: u64) -> Self {
50        debug!(
51            "processing write bytes: {}, current in-flight write bytes: {}",
52            processing_write_bytes,
53            in_flight_write_bytes.load(Ordering::Relaxed)
54        );
55        Self {
56            in_flight_write_bytes,
57            processing_write_bytes,
58        }
59    }
60}
61
62impl Drop for InFlightWriteBytesCounter {
63    // When the request is finished, the in-flight write bytes should be decreased.
64    fn drop(&mut self) {
65        self.in_flight_write_bytes
66            .fetch_sub(self.processing_write_bytes, Ordering::Relaxed);
67    }
68}
69
70impl Limiter {
71    pub fn new(max_in_flight_write_bytes: u64) -> Self {
72        Self {
73            max_in_flight_write_bytes,
74            in_flight_write_bytes: Arc::new(AtomicU64::new(0)),
75        }
76    }
77
78    pub fn limit_request(&self, request: &Request) -> Option<InFlightWriteBytesCounter> {
79        let size = match request {
80            Request::Inserts(requests) => self.insert_requests_data_size(requests),
81            Request::RowInserts(requests) => {
82                self.rows_insert_requests_data_size(requests.inserts.iter())
83            }
84            _ => 0,
85        };
86        self.limit_in_flight_write_bytes(size as u64)
87    }
88
89    pub fn limit_row_inserts(
90        &self,
91        requests: &RowInsertRequests,
92    ) -> Option<InFlightWriteBytesCounter> {
93        let size = self.rows_insert_requests_data_size(requests.inserts.iter());
94        self.limit_in_flight_write_bytes(size as u64)
95    }
96
97    pub fn limit_ctx_req(&self, opt_req: &ContextReq) -> Option<InFlightWriteBytesCounter> {
98        let size = self.rows_insert_requests_data_size(opt_req.ref_all_req());
99        self.limit_in_flight_write_bytes(size as u64)
100    }
101
102    /// Returns None if the in-flight write bytes exceed the maximum limit.
103    /// Otherwise, returns Some(InFlightWriteBytesCounter) and the in-flight write bytes will be increased.
104    pub fn limit_in_flight_write_bytes(&self, bytes: u64) -> Option<InFlightWriteBytesCounter> {
105        let result = self.in_flight_write_bytes.fetch_update(
106            Ordering::Relaxed,
107            Ordering::Relaxed,
108            |current| {
109                if current + bytes > self.max_in_flight_write_bytes {
110                    warn!(
111                        "in-flight write bytes exceed the maximum limit {}, request with {} bytes will be limited",
112                        self.max_in_flight_write_bytes,
113                        bytes
114                    );
115                    return None;
116                }
117                Some(current + bytes)
118            },
119        );
120
121        match result {
122            // Update the in-flight write bytes successfully.
123            Ok(_) => Some(InFlightWriteBytesCounter::new(
124                self.in_flight_write_bytes.clone(),
125                bytes,
126            )),
127            // It means the in-flight write bytes exceed the maximum limit.
128            Err(_) => None,
129        }
130    }
131
132    /// Returns the current in-flight write bytes.
133    #[allow(dead_code)]
134    pub fn in_flight_write_bytes(&self) -> u64 {
135        self.in_flight_write_bytes.load(Ordering::Relaxed)
136    }
137
138    fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
139        let mut size: usize = 0;
140        for insert in &request.inserts {
141            for column in &insert.columns {
142                if let Some(values) = &column.values {
143                    size += self.size_of_column_values(values);
144                }
145            }
146        }
147        size
148    }
149
150    fn rows_insert_requests_data_size<'a>(
151        &self,
152        inserts: impl Iterator<Item = &'a RowInsertRequest>,
153    ) -> usize {
154        let mut size: usize = 0;
155        for insert in inserts {
156            if let Some(rows) = &insert.rows {
157                for row in &rows.rows {
158                    for value in &row.values {
159                        if let Some(value) = &value.value_data {
160                            size += self.size_of_value_data(value);
161                        }
162                    }
163                }
164            }
165        }
166        size
167    }
168
169    fn size_of_column_values(&self, values: &Values) -> usize {
170        let mut size: usize = 0;
171        size += values.i8_values.len() * size_of::<i32>();
172        size += values.i16_values.len() * size_of::<i32>();
173        size += values.i32_values.len() * size_of::<i32>();
174        size += values.i64_values.len() * size_of::<i64>();
175        size += values.u8_values.len() * size_of::<u32>();
176        size += values.u16_values.len() * size_of::<u32>();
177        size += values.u32_values.len() * size_of::<u32>();
178        size += values.u64_values.len() * size_of::<u64>();
179        size += values.f32_values.len() * size_of::<f32>();
180        size += values.f64_values.len() * size_of::<f64>();
181        size += values.bool_values.len() * size_of::<bool>();
182        size += values
183            .binary_values
184            .iter()
185            .map(|v| v.len() * size_of::<u8>())
186            .sum::<usize>();
187        size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
188        size += values.date_values.len() * size_of::<i32>();
189        size += values.datetime_values.len() * size_of::<i64>();
190        size += values.timestamp_second_values.len() * size_of::<i64>();
191        size += values.timestamp_millisecond_values.len() * size_of::<i64>();
192        size += values.timestamp_microsecond_values.len() * size_of::<i64>();
193        size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
194        size += values.time_second_values.len() * size_of::<i64>();
195        size += values.time_millisecond_values.len() * size_of::<i64>();
196        size += values.time_microsecond_values.len() * size_of::<i64>();
197        size += values.time_nanosecond_values.len() * size_of::<i64>();
198        size += values.interval_year_month_values.len() * size_of::<i64>();
199        size += values.interval_day_time_values.len() * size_of::<i64>();
200        size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
201        size += values.decimal128_values.len() * size_of::<Decimal128>();
202        size
203    }
204
205    fn size_of_value_data(&self, value: &ValueData) -> usize {
206        match value {
207            ValueData::I8Value(_) => size_of::<i32>(),
208            ValueData::I16Value(_) => size_of::<i32>(),
209            ValueData::I32Value(_) => size_of::<i32>(),
210            ValueData::I64Value(_) => size_of::<i64>(),
211            ValueData::U8Value(_) => size_of::<u32>(),
212            ValueData::U16Value(_) => size_of::<u32>(),
213            ValueData::U32Value(_) => size_of::<u32>(),
214            ValueData::U64Value(_) => size_of::<u64>(),
215            ValueData::F32Value(_) => size_of::<f32>(),
216            ValueData::F64Value(_) => size_of::<f64>(),
217            ValueData::BoolValue(_) => size_of::<bool>(),
218            ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
219            ValueData::StringValue(v) => v.len(),
220            ValueData::DateValue(_) => size_of::<i32>(),
221            ValueData::DatetimeValue(_) => size_of::<i64>(),
222            ValueData::TimestampSecondValue(_) => size_of::<i64>(),
223            ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
224            ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
225            ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
226            ValueData::TimeSecondValue(_) => size_of::<i64>(),
227            ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
228            ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
229            ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
230            ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
231            ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
232            ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
233            ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
234        }
235    }
236}
237
#[cfg(test)]
mod tests {
    use api::v1::column::Values;
    use api::v1::greptime_request::Request;
    use api::v1::{Column, InsertRequest};

    use super::*;

    /// Builds an `Inserts` request whose estimated payload is `size` bytes.
    ///
    /// The request has a single `i8` column with `size / 4` values. The limiter counts each
    /// such value as 4 bytes, so `size` should be a multiple of 4 for an exact total.
    fn generate_request(size: usize) -> Request {
        let i8_values = vec![0; size / 4];
        Request::Inserts(InsertRequests {
            inserts: vec![InsertRequest {
                columns: vec![Column {
                    values: Some(Values {
                        i8_values,
                        ..Default::default()
                    }),
                    ..Default::default()
                }],
                ..Default::default()
            }],
        })
    }

    #[tokio::test]
    async fn test_limiter() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let tasks_count = 10;
        let request_data_size = 100;
        let mut handles = vec![];

        // Generate multiple requests to test the limiter.
        for _ in 0..tasks_count {
            let limiter = limiter_ref.clone();
            let handle = tokio::spawn(async move {
                let result = limiter.limit_request(&generate_request(request_data_size));
                assert!(result.is_some());
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete.
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[test]
    fn test_in_flight_write_bytes() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let req1 = generate_request(100);
        let result1 = limiter_ref.limit_request(&req1);
        assert!(result1.is_some());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 100);

        let req2 = generate_request(200);
        let result2 = limiter_ref.limit_request(&req2);
        assert!(result2.is_some());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 300);

        drop(result1.unwrap());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 200);

        drop(result2.unwrap());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 0);
    }

    #[test]
    fn test_request_exactly_at_limit_is_admitted() {
        let limiter = Limiter::new(100);
        let admitted = limiter.limit_request(&generate_request(100));
        assert!(admitted.is_some());
        assert_eq!(limiter.in_flight_write_bytes(), 100);

        drop(admitted);
        assert_eq!(limiter.in_flight_write_bytes(), 0);
    }

    #[test]
    fn test_reject_when_limit_exceeded() {
        let limiter = Limiter::new(100);

        let admitted = limiter.limit_request(&generate_request(80));
        assert!(admitted.is_some());
        assert_eq!(limiter.in_flight_write_bytes(), 80);

        // 80 + 40 > 100: the request is rejected and nothing is reserved for it.
        assert!(limiter.limit_request(&generate_request(40)).is_none());
        assert_eq!(limiter.in_flight_write_bytes(), 80);

        // Releasing the first reservation frees enough room for the second request.
        drop(admitted);
        assert_eq!(limiter.in_flight_write_bytes(), 0);
        let retried = limiter.limit_request(&generate_request(40));
        assert!(retried.is_some());
        assert_eq!(limiter.in_flight_write_bytes(), 40);
    }
}
304}