frontend/
limiter.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use api::v1::column::Values;
18use api::v1::greptime_request::Request;
19use api::v1::value::ValueData;
20use api::v1::{
21    Decimal128, InsertRequests, IntervalMonthDayNano, RowInsertRequest, RowInsertRequests,
22};
23use pipeline::ContextReq;
24use snafu::ResultExt;
25use tokio::sync::{OwnedSemaphorePermit, Semaphore};
26
27use crate::error::{AcquireLimiterSnafu, Result};
28
29pub(crate) type LimiterRef = Arc<Limiter>;
30
31/// A frontend request limiter that controls the total size of in-flight write
32/// requests.
33pub(crate) struct Limiter {
34    max_in_flight_write_bytes: usize,
35    byte_counter: Arc<Semaphore>,
36}
37
38impl Limiter {
39    pub fn new(max_in_flight_write_bytes: usize) -> Self {
40        Self {
41            byte_counter: Arc::new(Semaphore::new(max_in_flight_write_bytes)),
42            max_in_flight_write_bytes,
43        }
44    }
45
46    pub async fn limit_request(&self, request: &Request) -> Result<OwnedSemaphorePermit> {
47        let size = match request {
48            Request::Inserts(requests) => self.insert_requests_data_size(requests),
49            Request::RowInserts(requests) => {
50                self.rows_insert_requests_data_size(requests.inserts.iter())
51            }
52            _ => 0,
53        };
54        self.limit_in_flight_write_bytes(size).await
55    }
56
57    pub async fn limit_row_inserts(
58        &self,
59        requests: &RowInsertRequests,
60    ) -> Result<OwnedSemaphorePermit> {
61        let size = self.rows_insert_requests_data_size(requests.inserts.iter());
62        self.limit_in_flight_write_bytes(size).await
63    }
64
65    pub async fn limit_ctx_req(&self, opt_req: &ContextReq) -> Result<OwnedSemaphorePermit> {
66        let size = self.rows_insert_requests_data_size(opt_req.ref_all_req());
67        self.limit_in_flight_write_bytes(size).await
68    }
69
70    /// Await until more inflight bytes are available
71    pub async fn limit_in_flight_write_bytes(&self, bytes: usize) -> Result<OwnedSemaphorePermit> {
72        self.byte_counter
73            .clone()
74            .acquire_many_owned(bytes as u32)
75            .await
76            .context(AcquireLimiterSnafu)
77    }
78
79    /// Returns the current in-flight write bytes.
80    #[allow(dead_code)]
81    pub fn in_flight_write_bytes(&self) -> usize {
82        self.max_in_flight_write_bytes - self.byte_counter.available_permits()
83    }
84
85    fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
86        let mut size: usize = 0;
87        for insert in &request.inserts {
88            for column in &insert.columns {
89                if let Some(values) = &column.values {
90                    size += Self::size_of_column_values(values);
91                }
92            }
93        }
94        size
95    }
96
97    fn rows_insert_requests_data_size<'a>(
98        &self,
99        inserts: impl Iterator<Item = &'a RowInsertRequest>,
100    ) -> usize {
101        let mut size: usize = 0;
102        for insert in inserts {
103            if let Some(rows) = &insert.rows {
104                for row in &rows.rows {
105                    for value in &row.values {
106                        if let Some(value) = &value.value_data {
107                            size += Self::size_of_value_data(value);
108                        }
109                    }
110                }
111            }
112        }
113        size
114    }
115
116    fn size_of_column_values(values: &Values) -> usize {
117        let mut size: usize = 0;
118        size += values.i8_values.len() * size_of::<i32>();
119        size += values.i16_values.len() * size_of::<i32>();
120        size += values.i32_values.len() * size_of::<i32>();
121        size += values.i64_values.len() * size_of::<i64>();
122        size += values.u8_values.len() * size_of::<u32>();
123        size += values.u16_values.len() * size_of::<u32>();
124        size += values.u32_values.len() * size_of::<u32>();
125        size += values.u64_values.len() * size_of::<u64>();
126        size += values.f32_values.len() * size_of::<f32>();
127        size += values.f64_values.len() * size_of::<f64>();
128        size += values.bool_values.len() * size_of::<bool>();
129        size += values
130            .binary_values
131            .iter()
132            .map(|v| v.len() * size_of::<u8>())
133            .sum::<usize>();
134        size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
135        size += values.date_values.len() * size_of::<i32>();
136        size += values.datetime_values.len() * size_of::<i64>();
137        size += values.timestamp_second_values.len() * size_of::<i64>();
138        size += values.timestamp_millisecond_values.len() * size_of::<i64>();
139        size += values.timestamp_microsecond_values.len() * size_of::<i64>();
140        size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
141        size += values.time_second_values.len() * size_of::<i64>();
142        size += values.time_millisecond_values.len() * size_of::<i64>();
143        size += values.time_microsecond_values.len() * size_of::<i64>();
144        size += values.time_nanosecond_values.len() * size_of::<i64>();
145        size += values.interval_year_month_values.len() * size_of::<i64>();
146        size += values.interval_day_time_values.len() * size_of::<i64>();
147        size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
148        size += values.decimal128_values.len() * size_of::<Decimal128>();
149        size += values
150            .list_values
151            .iter()
152            .map(|v| {
153                v.items
154                    .iter()
155                    .map(|item| {
156                        item.value_data
157                            .as_ref()
158                            .map(Self::size_of_value_data)
159                            .unwrap_or(0)
160                    })
161                    .sum::<usize>()
162            })
163            .sum::<usize>();
164        size += values
165            .struct_values
166            .iter()
167            .map(|v| {
168                v.items
169                    .iter()
170                    .map(|item| {
171                        item.value_data
172                            .as_ref()
173                            .map(Self::size_of_value_data)
174                            .unwrap_or(0)
175                    })
176                    .sum::<usize>()
177            })
178            .sum::<usize>();
179
180        size
181    }
182
183    fn size_of_value_data(value: &ValueData) -> usize {
184        match value {
185            ValueData::I8Value(_) => size_of::<i32>(),
186            ValueData::I16Value(_) => size_of::<i32>(),
187            ValueData::I32Value(_) => size_of::<i32>(),
188            ValueData::I64Value(_) => size_of::<i64>(),
189            ValueData::U8Value(_) => size_of::<u32>(),
190            ValueData::U16Value(_) => size_of::<u32>(),
191            ValueData::U32Value(_) => size_of::<u32>(),
192            ValueData::U64Value(_) => size_of::<u64>(),
193            ValueData::F32Value(_) => size_of::<f32>(),
194            ValueData::F64Value(_) => size_of::<f64>(),
195            ValueData::BoolValue(_) => size_of::<bool>(),
196            ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
197            ValueData::StringValue(v) => v.len(),
198            ValueData::DateValue(_) => size_of::<i32>(),
199            ValueData::DatetimeValue(_) => size_of::<i64>(),
200            ValueData::TimestampSecondValue(_) => size_of::<i64>(),
201            ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
202            ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
203            ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
204            ValueData::TimeSecondValue(_) => size_of::<i64>(),
205            ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
206            ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
207            ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
208            ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
209            ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
210            ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
211            ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
212            ValueData::ListValue(list_values) => list_values
213                .items
214                .iter()
215                .map(|item| {
216                    item.value_data
217                        .as_ref()
218                        .map(Self::size_of_value_data)
219                        .unwrap_or(0)
220                })
221                .sum(),
222            ValueData::StructValue(struct_values) => struct_values
223                .items
224                .iter()
225                .map(|item| {
226                    item.value_data
227                        .as_ref()
228                        .map(Self::size_of_value_data)
229                        .unwrap_or(0)
230                })
231                .sum(),
232            ValueData::JsonValue(inner) => inner
233                .as_ref()
234                .value_data
235                .as_ref()
236                .map(Self::size_of_value_data)
237                .unwrap_or(0),
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use api::v1::column::Values;
    use api::v1::greptime_request::Request;
    use api::v1::{Column, InsertRequest};

    use super::*;

    /// Builds an `Inserts` request whose estimated payload is `size` bytes.
    ///
    /// The limiter counts each `i8` value as four bytes, so `size / 4` values
    /// are generated in a single column.
    fn generate_request(size: usize) -> Request {
        let column = Column {
            values: Some(Values {
                i8_values: vec![0; size / 4],
                ..Default::default()
            }),
            ..Default::default()
        };
        let insert = InsertRequest {
            columns: vec![column],
            ..Default::default()
        };
        Request::Inserts(InsertRequests {
            inserts: vec![insert],
        })
    }

    #[tokio::test]
    async fn test_limiter() {
        const TASKS: usize = 10;
        const REQUEST_BYTES: usize = 100;
        let limiter: LimiterRef = Arc::new(Limiter::new(1024));

        // Fire several concurrent requests; every one must eventually be admitted.
        let handles: Vec<_> = (0..TASKS)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                tokio::spawn(async move {
                    let permit = limiter
                        .limit_request(&generate_request(REQUEST_BYTES))
                        .await;
                    assert!(permit.is_ok());
                })
            })
            .collect();

        // Join every task so a panic inside one fails the test.
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_in_flight_write_bytes() {
        let limiter: LimiterRef = Arc::new(Limiter::new(1024));

        let first = limiter
            .limit_request(&generate_request(100))
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter.in_flight_write_bytes(), 100);

        let second = limiter
            .limit_request(&generate_request(200))
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter.in_flight_write_bytes(), 300);

        // Dropping a permit hands its bytes back to the limiter.
        drop(first);
        assert_eq!(limiter.in_flight_write_bytes(), 200);

        drop(second);
        assert_eq!(limiter.in_flight_write_bytes(), 0);
    }
}