// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use api::v1::column::Values;
use api::v1::greptime_request::Request;
use api::v1::value::ValueData;
use api::v1::{
    Decimal128, InsertRequests, IntervalMonthDayNano, JsonValue, RowInsertRequest,
    RowInsertRequests, json_value,
};
use pipeline::ContextReq;
use snafu::ResultExt;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

use crate::error::{AcquireLimiterSnafu, Result};

pub(crate) type LimiterRef = Arc<Limiter>;

/// A frontend request limiter that controls the total size of in-flight write
/// requests.
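///
/// A minimal usage sketch (illustrative only, not compiled as a doc test;
/// `incoming_request` is a placeholder for a received [`Request`]):
///
/// ```ignore
/// let limiter: LimiterRef = Arc::new(Limiter::new(1024));
/// // The permit reserves the request's estimated size in bytes; dropping it
/// // releases the capacity back to the limiter.
/// let _permit = limiter.limit_request(&incoming_request).await?;
/// ```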
pub(crate) struct Limiter {
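    /// Maximum total size (in bytes) of write data allowed in flight at once.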
    max_in_flight_write_bytes: usize,
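    /// Semaphore whose available permits represent the remaining write-byte budget.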
    byte_counter: Arc<Semaphore>,
}

impl Limiter {
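    /// Creates a limiter that allows at most `max_in_flight_write_bytes` bytes
    /// of write data to be in flight at the same time.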
    pub fn new(max_in_flight_write_bytes: usize) -> Self {
        Self {
            byte_counter: Arc::new(Semaphore::new(max_in_flight_write_bytes)),
            max_in_flight_write_bytes,
        }
    }

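    /// Estimates the payload size of a gRPC write request and acquires that many
    /// bytes from the limiter. Non-insert requests are counted as 0 bytes.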
    pub async fn limit_request(&self, request: &Request) -> Result<OwnedSemaphorePermit> {
        let size = match request {
            Request::Inserts(requests) => self.insert_requests_data_size(requests),
            Request::RowInserts(requests) => {
                self.rows_insert_requests_data_size(requests.inserts.iter())
            }
            _ => 0,
        };
        self.limit_in_flight_write_bytes(size).await
    }

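    /// Acquires capacity for a batch of row insert requests.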
    pub async fn limit_row_inserts(
        &self,
        requests: &RowInsertRequests,
    ) -> Result<OwnedSemaphorePermit> {
        let size = self.rows_insert_requests_data_size(requests.inserts.iter());
        self.limit_in_flight_write_bytes(size).await
    }

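    /// Acquires capacity for all row insert requests carried by a pipeline [`ContextReq`].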
    pub async fn limit_ctx_req(&self, opt_req: &ContextReq) -> Result<OwnedSemaphorePermit> {
        let size = self.rows_insert_requests_data_size(opt_req.ref_all_req());
        self.limit_in_flight_write_bytes(size).await
    }

    /// Waits until `bytes` of in-flight write capacity is available, returning a
    /// permit that releases the capacity when dropped.
    pub async fn limit_in_flight_write_bytes(&self, bytes: usize) -> Result<OwnedSemaphorePermit> {
        self.byte_counter
            .clone()
            .acquire_many_owned(bytes as u32)
            .await
            .context(AcquireLimiterSnafu)
    }

    /// Returns the current in-flight write bytes.
    #[allow(dead_code)]
    pub fn in_flight_write_bytes(&self) -> usize {
        self.max_in_flight_write_bytes - self.byte_counter.available_permits()
    }

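    /// Sums the estimated value sizes of all columns in column-oriented insert requests.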
    fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
        let mut size: usize = 0;
        for insert in &request.inserts {
            for column in &insert.columns {
                if let Some(values) = &column.values {
                    size += Self::size_of_column_values(values);
                }
            }
        }
        size
    }

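    /// Sums the estimated value sizes of all rows in row-oriented insert requests.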
    fn rows_insert_requests_data_size<'a>(
        &self,
        inserts: impl Iterator<Item = &'a RowInsertRequest>,
    ) -> usize {
        let mut size: usize = 0;
        for insert in inserts {
            if let Some(rows) = &insert.rows {
                for row in &rows.rows {
                    for value in &row.values {
                        if let Some(value) = &value.value_data {
                            size += Self::size_of_value_data(value);
                        }
                    }
                }
            }
        }
        size
    }

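    /// Estimates the in-memory size of a column's values. Small integer types
    /// (`i8`/`i16`, `u8`/`u16`) are counted as 4 bytes each, matching their
    /// 32-bit protobuf representation.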
    fn size_of_column_values(values: &Values) -> usize {
        let mut size: usize = 0;
        size += values.i8_values.len() * size_of::<i32>();
        size += values.i16_values.len() * size_of::<i32>();
        size += values.i32_values.len() * size_of::<i32>();
        size += values.i64_values.len() * size_of::<i64>();
        size += values.u8_values.len() * size_of::<u32>();
        size += values.u16_values.len() * size_of::<u32>();
        size += values.u32_values.len() * size_of::<u32>();
        size += values.u64_values.len() * size_of::<u64>();
        size += values.f32_values.len() * size_of::<f32>();
        size += values.f64_values.len() * size_of::<f64>();
        size += values.bool_values.len() * size_of::<bool>();
        size += values
            .binary_values
            .iter()
            .map(|v| v.len() * size_of::<u8>())
            .sum::<usize>();
        size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
        size += values.date_values.len() * size_of::<i32>();
        size += values.datetime_values.len() * size_of::<i64>();
        size += values.timestamp_second_values.len() * size_of::<i64>();
        size += values.timestamp_millisecond_values.len() * size_of::<i64>();
        size += values.timestamp_microsecond_values.len() * size_of::<i64>();
        size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
        size += values.time_second_values.len() * size_of::<i64>();
        size += values.time_millisecond_values.len() * size_of::<i64>();
        size += values.time_microsecond_values.len() * size_of::<i64>();
        size += values.time_nanosecond_values.len() * size_of::<i64>();
        size += values.interval_year_month_values.len() * size_of::<i64>();
        size += values.interval_day_time_values.len() * size_of::<i64>();
        size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
        size += values.decimal128_values.len() * size_of::<Decimal128>();
        size += values
            .list_values
            .iter()
            .map(|v| {
                v.items
                    .iter()
                    .map(|item| {
                        item.value_data
                            .as_ref()
                            .map(Self::size_of_value_data)
                            .unwrap_or(0)
                    })
                    .sum::<usize>()
            })
            .sum::<usize>();
        size += values
            .struct_values
            .iter()
            .map(|v| {
                v.items
                    .iter()
                    .map(|item| {
                        item.value_data
                            .as_ref()
                            .map(Self::size_of_value_data)
                            .unwrap_or(0)
                    })
                    .sum::<usize>()
            })
            .sum::<usize>();

        size
    }

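    /// Estimates the in-memory size of a single value, recursing into list,
    /// struct, and JSON values.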
    fn size_of_value_data(value: &ValueData) -> usize {
        match value {
            ValueData::I8Value(_) => size_of::<i32>(),
            ValueData::I16Value(_) => size_of::<i32>(),
            ValueData::I32Value(_) => size_of::<i32>(),
            ValueData::I64Value(_) => size_of::<i64>(),
            ValueData::U8Value(_) => size_of::<u32>(),
            ValueData::U16Value(_) => size_of::<u32>(),
            ValueData::U32Value(_) => size_of::<u32>(),
            ValueData::U64Value(_) => size_of::<u64>(),
            ValueData::F32Value(_) => size_of::<f32>(),
            ValueData::F64Value(_) => size_of::<f64>(),
            ValueData::BoolValue(_) => size_of::<bool>(),
            ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
            ValueData::StringValue(v) => v.len(),
            ValueData::DateValue(_) => size_of::<i32>(),
            ValueData::DatetimeValue(_) => size_of::<i64>(),
            ValueData::TimestampSecondValue(_) => size_of::<i64>(),
            ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
            ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
            ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
            ValueData::TimeSecondValue(_) => size_of::<i64>(),
            ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
            ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
            ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
            ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
            ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
            ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
            ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
            ValueData::ListValue(list_values) => list_values
                .items
                .iter()
                .map(|item| {
                    item.value_data
                        .as_ref()
                        .map(Self::size_of_value_data)
                        .unwrap_or(0)
                })
                .sum(),
            ValueData::StructValue(struct_values) => struct_values
                .items
                .iter()
                .map(|item| {
                    item.value_data
                        .as_ref()
                        .map(Self::size_of_value_data)
                        .unwrap_or(0)
                })
                .sum(),
            ValueData::JsonValue(v) => {
                fn calc(v: &JsonValue) -> usize {
                    let Some(value) = v.value.as_ref() else {
                        return 0;
                    };
                    match value {
                        json_value::Value::Boolean(_) => size_of::<bool>(),
                        json_value::Value::Int(_) => size_of::<i64>(),
                        json_value::Value::Uint(_) => size_of::<u64>(),
                        json_value::Value::Float(_) => size_of::<f64>(),
                        json_value::Value::Str(s) => s.len(),
                        json_value::Value::Array(array) => array.items.iter().map(calc).sum(),
                        json_value::Value::Object(object) => object
                            .entries
                            .iter()
                            .flat_map(|entry| {
                                entry.value.as_ref().map(|v| entry.key.len() + calc(v))
                            })
                            .sum(),
                    }
                }
                calc(v)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use api::v1::column::Values;
    use api::v1::greptime_request::Request;
    use api::v1::{Column, InsertRequest};

    use super::*;

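    // Builds an insert request whose estimated payload is `size` bytes: each i8
    // value is counted as 4 bytes (see `size_of_column_values`), so `size / 4`
    // values are generated.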
    fn generate_request(size: usize) -> Request {
        let i8_values = vec![0; size / 4];
        Request::Inserts(InsertRequests {
            inserts: vec![InsertRequest {
                columns: vec![Column {
                    values: Some(Values {
                        i8_values,
                        ..Default::default()
                    }),
                    ..Default::default()
                }],
                ..Default::default()
            }],
        })
    }

    #[tokio::test]
    async fn test_limiter() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let tasks_count = 10;
        let request_data_size = 100;
        let mut handles = vec![];

        // Generate multiple requests to test the limiter.
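        // With 10 tasks of 100 bytes each (1000 bytes total), the 1024-byte budget
        // is large enough for every task to hold its permit at the same time; each
        // permit is released when the task's `result` is dropped.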
        for _ in 0..tasks_count {
            let limiter = limiter_ref.clone();
            let handle = tokio::spawn(async move {
                let result = limiter
                    .limit_request(&generate_request(request_data_size))
                    .await;
                assert!(result.is_ok());
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete.
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_in_flight_write_bytes() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let req1 = generate_request(100);
        let result1 = limiter_ref
            .limit_request(&req1)
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter_ref.in_flight_write_bytes(), 100);

        let req2 = generate_request(200);
        let result2 = limiter_ref
            .limit_request(&req2)
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter_ref.in_flight_write_bytes(), 300);

        drop(result1);
        assert_eq!(limiter_ref.in_flight_write_bytes(), 200);

        drop(result2);
        assert_eq!(limiter_ref.in_flight_write_bytes(), 0);
    }
}