frontend/
limiter.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17
18use api::v1::column::Values;
19use api::v1::greptime_request::Request;
20use api::v1::value::ValueData;
21use api::v1::{Decimal128, InsertRequests, IntervalMonthDayNano, RowInsertRequests};
22use common_telemetry::{debug, warn};
23
/// Shared, reference-counted handle to a [`Limiter`].
pub(crate) type LimiterRef = Arc<Limiter>;

/// Frontend write-request limiter: caps the total number of bytes belonging to
/// write requests that are being processed at the same time.
pub(crate) struct Limiter {
    /// Running total of bytes reserved by admitted, not-yet-finished write requests.
    /// A clone of this `Arc` is handed to every `InFlightWriteBytesCounter`.
    in_flight_write_bytes: Arc<AtomicU64>,

    /// Upper bound that `in_flight_write_bytes` is never allowed to exceed.
    max_in_flight_write_bytes: u64,
}
34
/// Guard tracking the bytes one admitted write request holds in a limiter's
/// shared in-flight total.
pub(crate) struct InFlightWriteBytesCounter {
    /// Size, in bytes, of the request this guard accounts for.
    processing_write_bytes: u64,

    /// Shared in-flight total that `processing_write_bytes` is charged against.
    in_flight_write_bytes: Arc<AtomicU64>,
}
43
44impl InFlightWriteBytesCounter {
45    /// Creates a new InFlightWriteBytesCounter. It will decrease the in-flight write bytes when dropped.
46    pub fn new(in_flight_write_bytes: Arc<AtomicU64>, processing_write_bytes: u64) -> Self {
47        debug!(
48            "processing write bytes: {}, current in-flight write bytes: {}",
49            processing_write_bytes,
50            in_flight_write_bytes.load(Ordering::Relaxed)
51        );
52        Self {
53            in_flight_write_bytes,
54            processing_write_bytes,
55        }
56    }
57}
58
59impl Drop for InFlightWriteBytesCounter {
60    // When the request is finished, the in-flight write bytes should be decreased.
61    fn drop(&mut self) {
62        self.in_flight_write_bytes
63            .fetch_sub(self.processing_write_bytes, Ordering::Relaxed);
64    }
65}
66
67impl Limiter {
68    pub fn new(max_in_flight_write_bytes: u64) -> Self {
69        Self {
70            max_in_flight_write_bytes,
71            in_flight_write_bytes: Arc::new(AtomicU64::new(0)),
72        }
73    }
74
75    pub fn limit_request(&self, request: &Request) -> Option<InFlightWriteBytesCounter> {
76        let size = match request {
77            Request::Inserts(requests) => self.insert_requests_data_size(requests),
78            Request::RowInserts(requests) => self.rows_insert_requests_data_size(requests),
79            _ => 0,
80        };
81        self.limit_in_flight_write_bytes(size as u64)
82    }
83
84    pub fn limit_row_inserts(
85        &self,
86        requests: &RowInsertRequests,
87    ) -> Option<InFlightWriteBytesCounter> {
88        let size = self.rows_insert_requests_data_size(requests);
89        self.limit_in_flight_write_bytes(size as u64)
90    }
91
92    /// Returns None if the in-flight write bytes exceed the maximum limit.
93    /// Otherwise, returns Some(InFlightWriteBytesCounter) and the in-flight write bytes will be increased.
94    pub fn limit_in_flight_write_bytes(&self, bytes: u64) -> Option<InFlightWriteBytesCounter> {
95        let result = self.in_flight_write_bytes.fetch_update(
96            Ordering::Relaxed,
97            Ordering::Relaxed,
98            |current| {
99                if current + bytes > self.max_in_flight_write_bytes {
100                    warn!(
101                        "in-flight write bytes exceed the maximum limit {}, request with {} bytes will be limited",
102                        self.max_in_flight_write_bytes,
103                        bytes
104                    );
105                    return None;
106                }
107                Some(current + bytes)
108            },
109        );
110
111        match result {
112            // Update the in-flight write bytes successfully.
113            Ok(_) => Some(InFlightWriteBytesCounter::new(
114                self.in_flight_write_bytes.clone(),
115                bytes,
116            )),
117            // It means the in-flight write bytes exceed the maximum limit.
118            Err(_) => None,
119        }
120    }
121
122    /// Returns the current in-flight write bytes.
123    #[allow(dead_code)]
124    pub fn in_flight_write_bytes(&self) -> u64 {
125        self.in_flight_write_bytes.load(Ordering::Relaxed)
126    }
127
128    fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
129        let mut size: usize = 0;
130        for insert in &request.inserts {
131            for column in &insert.columns {
132                if let Some(values) = &column.values {
133                    size += self.size_of_column_values(values);
134                }
135            }
136        }
137        size
138    }
139
140    fn rows_insert_requests_data_size(&self, request: &RowInsertRequests) -> usize {
141        let mut size: usize = 0;
142        for insert in &request.inserts {
143            if let Some(rows) = &insert.rows {
144                for row in &rows.rows {
145                    for value in &row.values {
146                        if let Some(value) = &value.value_data {
147                            size += self.size_of_value_data(value);
148                        }
149                    }
150                }
151            }
152        }
153        size
154    }
155
156    fn size_of_column_values(&self, values: &Values) -> usize {
157        let mut size: usize = 0;
158        size += values.i8_values.len() * size_of::<i32>();
159        size += values.i16_values.len() * size_of::<i32>();
160        size += values.i32_values.len() * size_of::<i32>();
161        size += values.i64_values.len() * size_of::<i64>();
162        size += values.u8_values.len() * size_of::<u32>();
163        size += values.u16_values.len() * size_of::<u32>();
164        size += values.u32_values.len() * size_of::<u32>();
165        size += values.u64_values.len() * size_of::<u64>();
166        size += values.f32_values.len() * size_of::<f32>();
167        size += values.f64_values.len() * size_of::<f64>();
168        size += values.bool_values.len() * size_of::<bool>();
169        size += values
170            .binary_values
171            .iter()
172            .map(|v| v.len() * size_of::<u8>())
173            .sum::<usize>();
174        size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
175        size += values.date_values.len() * size_of::<i32>();
176        size += values.datetime_values.len() * size_of::<i64>();
177        size += values.timestamp_second_values.len() * size_of::<i64>();
178        size += values.timestamp_millisecond_values.len() * size_of::<i64>();
179        size += values.timestamp_microsecond_values.len() * size_of::<i64>();
180        size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
181        size += values.time_second_values.len() * size_of::<i64>();
182        size += values.time_millisecond_values.len() * size_of::<i64>();
183        size += values.time_microsecond_values.len() * size_of::<i64>();
184        size += values.time_nanosecond_values.len() * size_of::<i64>();
185        size += values.interval_year_month_values.len() * size_of::<i64>();
186        size += values.interval_day_time_values.len() * size_of::<i64>();
187        size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
188        size += values.decimal128_values.len() * size_of::<Decimal128>();
189        size
190    }
191
192    fn size_of_value_data(&self, value: &ValueData) -> usize {
193        match value {
194            ValueData::I8Value(_) => size_of::<i32>(),
195            ValueData::I16Value(_) => size_of::<i32>(),
196            ValueData::I32Value(_) => size_of::<i32>(),
197            ValueData::I64Value(_) => size_of::<i64>(),
198            ValueData::U8Value(_) => size_of::<u32>(),
199            ValueData::U16Value(_) => size_of::<u32>(),
200            ValueData::U32Value(_) => size_of::<u32>(),
201            ValueData::U64Value(_) => size_of::<u64>(),
202            ValueData::F32Value(_) => size_of::<f32>(),
203            ValueData::F64Value(_) => size_of::<f64>(),
204            ValueData::BoolValue(_) => size_of::<bool>(),
205            ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
206            ValueData::StringValue(v) => v.len(),
207            ValueData::DateValue(_) => size_of::<i32>(),
208            ValueData::DatetimeValue(_) => size_of::<i64>(),
209            ValueData::TimestampSecondValue(_) => size_of::<i64>(),
210            ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
211            ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
212            ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
213            ValueData::TimeSecondValue(_) => size_of::<i64>(),
214            ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
215            ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
216            ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
217            ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
218            ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
219            ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
220            ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
221        }
222    }
223}
224
#[cfg(test)]
mod tests {
    use api::v1::column::Values;
    use api::v1::greptime_request::Request;
    use api::v1::{Column, InsertRequest};

    use super::*;

    /// Builds an `Inserts` request that carries `size / 4` `i8` values.
    ///
    /// The limiter sizes each `i8` value as an `i32` (4 bytes), so when `size` is a
    /// multiple of 4 the request is accounted as exactly `size` bytes.
    fn generate_request(size: usize) -> Request {
        let i8_values = vec![0; size / 4];
        Request::Inserts(InsertRequests {
            inserts: vec![InsertRequest {
                columns: vec![Column {
                    values: Some(Values {
                        i8_values,
                        ..Default::default()
                    }),
                    ..Default::default()
                }],
                ..Default::default()
            }],
        })
    }

    #[tokio::test]
    async fn test_limiter() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let tasks_count = 10;
        let request_data_size = 100;

        // 10 tasks x 100 bytes = 1000 bytes, which fits in the 1024-byte budget even if
        // every task holds its reservation at the same time, so all must be admitted.
        let handles: Vec<_> = (0..tasks_count)
            .map(|_| {
                let limiter = Arc::clone(&limiter_ref);
                tokio::spawn(async move {
                    let result = limiter.limit_request(&generate_request(request_data_size));
                    assert!(result.is_some());
                })
            })
            .collect();

        // Wait for all tasks to complete.
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[test]
    fn test_in_flight_write_bytes() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let req1 = generate_request(100);
        let result1 = limiter_ref.limit_request(&req1);
        assert!(result1.is_some());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 100);

        let req2 = generate_request(200);
        let result2 = limiter_ref.limit_request(&req2);
        assert!(result2.is_some());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 300);

        drop(result1.unwrap());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 200);

        drop(result2.unwrap());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 0);
    }

    #[test]
    fn test_limit_exceeded() {
        let limiter = Limiter::new(1024);

        let first = limiter.limit_request(&generate_request(1000));
        assert!(first.is_some());
        assert_eq!(limiter.in_flight_write_bytes(), 1000);

        // 1000 + 100 > 1024: the request is rejected and must not be counted.
        assert!(limiter.limit_request(&generate_request(100)).is_none());
        assert_eq!(limiter.in_flight_write_bytes(), 1000);

        // Once the first request finishes, the same request fits again. Its guard is
        // a temporary, so its bytes are released at the end of the statement.
        drop(first);
        assert!(limiter.limit_request(&generate_request(100)).is_some());
        assert_eq!(limiter.in_flight_write_bytes(), 0);
    }
}