1use std::sync::Arc;
16
17use api::v1::column::Values;
18use api::v1::greptime_request::Request;
19use api::v1::value::ValueData;
20use api::v1::{
21 Decimal128, InsertRequests, IntervalMonthDayNano, JsonValue, RowInsertRequest,
22 RowInsertRequests, json_value,
23};
24use pipeline::ContextReq;
25use snafu::ResultExt;
26use tokio::sync::{OwnedSemaphorePermit, Semaphore};
27
28use crate::error::{AcquireLimiterSnafu, Result};
29
/// Shared, reference-counted handle to a [`Limiter`].
pub(crate) type LimiterRef = Arc<Limiter>;
31
/// Bounds the total payload bytes of write requests that may be in flight at once.
///
/// Admitted requests hold semaphore permits sized from their estimated payload;
/// the permits return to the pool when the request's permit guard is dropped.
pub(crate) struct Limiter {
    /// Total byte budget; also the number of permits `byte_counter` is created with.
    max_in_flight_write_bytes: usize,
    /// Semaphore in which each available permit stands for one byte of unused budget.
    byte_counter: Arc<Semaphore>,
}
38
39impl Limiter {
40 pub fn new(max_in_flight_write_bytes: usize) -> Self {
41 Self {
42 byte_counter: Arc::new(Semaphore::new(max_in_flight_write_bytes)),
43 max_in_flight_write_bytes,
44 }
45 }
46
47 pub async fn limit_request(&self, request: &Request) -> Result<OwnedSemaphorePermit> {
48 let size = match request {
49 Request::Inserts(requests) => self.insert_requests_data_size(requests),
50 Request::RowInserts(requests) => {
51 self.rows_insert_requests_data_size(requests.inserts.iter())
52 }
53 _ => 0,
54 };
55 self.limit_in_flight_write_bytes(size).await
56 }
57
58 pub async fn limit_row_inserts(
59 &self,
60 requests: &RowInsertRequests,
61 ) -> Result<OwnedSemaphorePermit> {
62 let size = self.rows_insert_requests_data_size(requests.inserts.iter());
63 self.limit_in_flight_write_bytes(size).await
64 }
65
66 pub async fn limit_ctx_req(&self, opt_req: &ContextReq) -> Result<OwnedSemaphorePermit> {
67 let size = self.rows_insert_requests_data_size(opt_req.ref_all_req());
68 self.limit_in_flight_write_bytes(size).await
69 }
70
71 pub async fn limit_in_flight_write_bytes(&self, bytes: usize) -> Result<OwnedSemaphorePermit> {
73 self.byte_counter
74 .clone()
75 .acquire_many_owned(bytes as u32)
76 .await
77 .context(AcquireLimiterSnafu)
78 }
79
80 #[allow(dead_code)]
82 pub fn in_flight_write_bytes(&self) -> usize {
83 self.max_in_flight_write_bytes - self.byte_counter.available_permits()
84 }
85
86 fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
87 let mut size: usize = 0;
88 for insert in &request.inserts {
89 for column in &insert.columns {
90 if let Some(values) = &column.values {
91 size += Self::size_of_column_values(values);
92 }
93 }
94 }
95 size
96 }
97
98 fn rows_insert_requests_data_size<'a>(
99 &self,
100 inserts: impl Iterator<Item = &'a RowInsertRequest>,
101 ) -> usize {
102 let mut size: usize = 0;
103 for insert in inserts {
104 if let Some(rows) = &insert.rows {
105 for row in &rows.rows {
106 for value in &row.values {
107 if let Some(value) = &value.value_data {
108 size += Self::size_of_value_data(value);
109 }
110 }
111 }
112 }
113 }
114 size
115 }
116
117 fn size_of_column_values(values: &Values) -> usize {
118 let mut size: usize = 0;
119 size += values.i8_values.len() * size_of::<i32>();
120 size += values.i16_values.len() * size_of::<i32>();
121 size += values.i32_values.len() * size_of::<i32>();
122 size += values.i64_values.len() * size_of::<i64>();
123 size += values.u8_values.len() * size_of::<u32>();
124 size += values.u16_values.len() * size_of::<u32>();
125 size += values.u32_values.len() * size_of::<u32>();
126 size += values.u64_values.len() * size_of::<u64>();
127 size += values.f32_values.len() * size_of::<f32>();
128 size += values.f64_values.len() * size_of::<f64>();
129 size += values.bool_values.len() * size_of::<bool>();
130 size += values
131 .binary_values
132 .iter()
133 .map(|v| v.len() * size_of::<u8>())
134 .sum::<usize>();
135 size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
136 size += values.date_values.len() * size_of::<i32>();
137 size += values.datetime_values.len() * size_of::<i64>();
138 size += values.timestamp_second_values.len() * size_of::<i64>();
139 size += values.timestamp_millisecond_values.len() * size_of::<i64>();
140 size += values.timestamp_microsecond_values.len() * size_of::<i64>();
141 size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
142 size += values.time_second_values.len() * size_of::<i64>();
143 size += values.time_millisecond_values.len() * size_of::<i64>();
144 size += values.time_microsecond_values.len() * size_of::<i64>();
145 size += values.time_nanosecond_values.len() * size_of::<i64>();
146 size += values.interval_year_month_values.len() * size_of::<i64>();
147 size += values.interval_day_time_values.len() * size_of::<i64>();
148 size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
149 size += values.decimal128_values.len() * size_of::<Decimal128>();
150 size += values
151 .list_values
152 .iter()
153 .map(|v| {
154 v.items
155 .iter()
156 .map(|item| {
157 item.value_data
158 .as_ref()
159 .map(Self::size_of_value_data)
160 .unwrap_or(0)
161 })
162 .sum::<usize>()
163 })
164 .sum::<usize>();
165 size += values
166 .struct_values
167 .iter()
168 .map(|v| {
169 v.items
170 .iter()
171 .map(|item| {
172 item.value_data
173 .as_ref()
174 .map(Self::size_of_value_data)
175 .unwrap_or(0)
176 })
177 .sum::<usize>()
178 })
179 .sum::<usize>();
180
181 size
182 }
183
184 fn size_of_value_data(value: &ValueData) -> usize {
185 match value {
186 ValueData::I8Value(_) => size_of::<i32>(),
187 ValueData::I16Value(_) => size_of::<i32>(),
188 ValueData::I32Value(_) => size_of::<i32>(),
189 ValueData::I64Value(_) => size_of::<i64>(),
190 ValueData::U8Value(_) => size_of::<u32>(),
191 ValueData::U16Value(_) => size_of::<u32>(),
192 ValueData::U32Value(_) => size_of::<u32>(),
193 ValueData::U64Value(_) => size_of::<u64>(),
194 ValueData::F32Value(_) => size_of::<f32>(),
195 ValueData::F64Value(_) => size_of::<f64>(),
196 ValueData::BoolValue(_) => size_of::<bool>(),
197 ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
198 ValueData::StringValue(v) => v.len(),
199 ValueData::DateValue(_) => size_of::<i32>(),
200 ValueData::DatetimeValue(_) => size_of::<i64>(),
201 ValueData::TimestampSecondValue(_) => size_of::<i64>(),
202 ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
203 ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
204 ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
205 ValueData::TimeSecondValue(_) => size_of::<i64>(),
206 ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
207 ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
208 ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
209 ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
210 ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
211 ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
212 ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
213 ValueData::ListValue(list_values) => list_values
214 .items
215 .iter()
216 .map(|item| {
217 item.value_data
218 .as_ref()
219 .map(Self::size_of_value_data)
220 .unwrap_or(0)
221 })
222 .sum(),
223 ValueData::StructValue(struct_values) => struct_values
224 .items
225 .iter()
226 .map(|item| {
227 item.value_data
228 .as_ref()
229 .map(Self::size_of_value_data)
230 .unwrap_or(0)
231 })
232 .sum(),
233 ValueData::JsonValue(v) => {
234 fn calc(v: &JsonValue) -> usize {
235 let Some(value) = v.value.as_ref() else {
236 return 0;
237 };
238 match value {
239 json_value::Value::Boolean(_) => size_of::<bool>(),
240 json_value::Value::Int(_) => size_of::<i64>(),
241 json_value::Value::Uint(_) => size_of::<u64>(),
242 json_value::Value::Float(_) => size_of::<f64>(),
243 json_value::Value::Str(s) => s.len(),
244 json_value::Value::Array(array) => array.items.iter().map(calc).sum(),
245 json_value::Value::Object(object) => object
246 .entries
247 .iter()
248 .flat_map(|entry| {
249 entry.value.as_ref().map(|v| entry.key.len() + calc(v))
250 })
251 .sum(),
252 }
253 }
254 calc(v)
255 }
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use api::v1::column::Values;
    use api::v1::greptime_request::Request;
    use api::v1::{Column, InsertRequest};

    use super::*;

    /// Builds a column-oriented insert request holding `size / 4` entries in
    /// `i8_values`. The limiter counts each such entry as 4 bytes, so the
    /// request measures `size` bytes when `size` is a multiple of 4.
    fn generate_request(size: usize) -> Request {
        let values = Values {
            i8_values: vec![0; size / 4],
            ..Default::default()
        };
        let column = Column {
            values: Some(values),
            ..Default::default()
        };
        let insert = InsertRequest {
            columns: vec![column],
            ..Default::default()
        };
        Request::Inserts(InsertRequests {
            inserts: vec![insert],
        })
    }

    /// Ten concurrent 100-byte requests fit in a 1024-byte budget, so every
    /// task must obtain its permit.
    #[tokio::test]
    async fn test_limiter() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let tasks_count = 10;
        let request_data_size = 100;

        // Spawn all tasks up front; each one gets its own handle to the limiter.
        let handles: Vec<_> = (0..tasks_count)
            .map(|_| {
                let limiter = limiter_ref.clone();
                tokio::spawn(async move {
                    let request = generate_request(request_data_size);
                    let result = limiter.limit_request(&request).await;
                    assert!(result.is_ok());
                })
            })
            .collect();

        // Surface any panic raised inside a task.
        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// The in-flight counter rises as permits are taken and falls as they are dropped.
    #[tokio::test]
    async fn test_in_flight_write_bytes() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));

        let first_request = generate_request(100);
        let first_permit = limiter_ref
            .limit_request(&first_request)
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter_ref.in_flight_write_bytes(), 100);

        let second_request = generate_request(200);
        let second_permit = limiter_ref
            .limit_request(&second_request)
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter_ref.in_flight_write_bytes(), 300);

        drop(first_permit);
        assert_eq!(limiter_ref.in_flight_write_bytes(), 200);

        drop(second_permit);
        assert_eq!(limiter_ref.in_flight_write_bytes(), 0);
    }
}