1use std::sync::Arc;
16
17use api::v1::column::Values;
18use api::v1::greptime_request::Request;
19use api::v1::value::ValueData;
20use api::v1::{
21 Decimal128, InsertRequests, IntervalMonthDayNano, RowInsertRequest, RowInsertRequests,
22};
23use pipeline::ContextReq;
24use snafu::ResultExt;
25use tokio::sync::{OwnedSemaphorePermit, Semaphore};
26
27use crate::error::{AcquireLimiterSnafu, Result};
28
29pub(crate) type LimiterRef = Arc<Limiter>;
30
31pub(crate) struct Limiter {
34 max_in_flight_write_bytes: usize,
35 byte_counter: Arc<Semaphore>,
36}
37
38impl Limiter {
39 pub fn new(max_in_flight_write_bytes: usize) -> Self {
40 Self {
41 byte_counter: Arc::new(Semaphore::new(max_in_flight_write_bytes)),
42 max_in_flight_write_bytes,
43 }
44 }
45
46 pub async fn limit_request(&self, request: &Request) -> Result<OwnedSemaphorePermit> {
47 let size = match request {
48 Request::Inserts(requests) => self.insert_requests_data_size(requests),
49 Request::RowInserts(requests) => {
50 self.rows_insert_requests_data_size(requests.inserts.iter())
51 }
52 _ => 0,
53 };
54 self.limit_in_flight_write_bytes(size).await
55 }
56
57 pub async fn limit_row_inserts(
58 &self,
59 requests: &RowInsertRequests,
60 ) -> Result<OwnedSemaphorePermit> {
61 let size = self.rows_insert_requests_data_size(requests.inserts.iter());
62 self.limit_in_flight_write_bytes(size).await
63 }
64
65 pub async fn limit_ctx_req(&self, opt_req: &ContextReq) -> Result<OwnedSemaphorePermit> {
66 let size = self.rows_insert_requests_data_size(opt_req.ref_all_req());
67 self.limit_in_flight_write_bytes(size).await
68 }
69
70 pub async fn limit_in_flight_write_bytes(&self, bytes: usize) -> Result<OwnedSemaphorePermit> {
72 self.byte_counter
73 .clone()
74 .acquire_many_owned(bytes as u32)
75 .await
76 .context(AcquireLimiterSnafu)
77 }
78
79 #[allow(dead_code)]
81 pub fn in_flight_write_bytes(&self) -> usize {
82 self.max_in_flight_write_bytes - self.byte_counter.available_permits()
83 }
84
85 fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
86 let mut size: usize = 0;
87 for insert in &request.inserts {
88 for column in &insert.columns {
89 if let Some(values) = &column.values {
90 size += Self::size_of_column_values(values);
91 }
92 }
93 }
94 size
95 }
96
97 fn rows_insert_requests_data_size<'a>(
98 &self,
99 inserts: impl Iterator<Item = &'a RowInsertRequest>,
100 ) -> usize {
101 let mut size: usize = 0;
102 for insert in inserts {
103 if let Some(rows) = &insert.rows {
104 for row in &rows.rows {
105 for value in &row.values {
106 if let Some(value) = &value.value_data {
107 size += Self::size_of_value_data(value);
108 }
109 }
110 }
111 }
112 }
113 size
114 }
115
116 fn size_of_column_values(values: &Values) -> usize {
117 let mut size: usize = 0;
118 size += values.i8_values.len() * size_of::<i32>();
119 size += values.i16_values.len() * size_of::<i32>();
120 size += values.i32_values.len() * size_of::<i32>();
121 size += values.i64_values.len() * size_of::<i64>();
122 size += values.u8_values.len() * size_of::<u32>();
123 size += values.u16_values.len() * size_of::<u32>();
124 size += values.u32_values.len() * size_of::<u32>();
125 size += values.u64_values.len() * size_of::<u64>();
126 size += values.f32_values.len() * size_of::<f32>();
127 size += values.f64_values.len() * size_of::<f64>();
128 size += values.bool_values.len() * size_of::<bool>();
129 size += values
130 .binary_values
131 .iter()
132 .map(|v| v.len() * size_of::<u8>())
133 .sum::<usize>();
134 size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
135 size += values.date_values.len() * size_of::<i32>();
136 size += values.datetime_values.len() * size_of::<i64>();
137 size += values.timestamp_second_values.len() * size_of::<i64>();
138 size += values.timestamp_millisecond_values.len() * size_of::<i64>();
139 size += values.timestamp_microsecond_values.len() * size_of::<i64>();
140 size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
141 size += values.time_second_values.len() * size_of::<i64>();
142 size += values.time_millisecond_values.len() * size_of::<i64>();
143 size += values.time_microsecond_values.len() * size_of::<i64>();
144 size += values.time_nanosecond_values.len() * size_of::<i64>();
145 size += values.interval_year_month_values.len() * size_of::<i64>();
146 size += values.interval_day_time_values.len() * size_of::<i64>();
147 size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
148 size += values.decimal128_values.len() * size_of::<Decimal128>();
149 size += values
150 .list_values
151 .iter()
152 .map(|v| {
153 v.items
154 .iter()
155 .map(|item| {
156 item.value_data
157 .as_ref()
158 .map(Self::size_of_value_data)
159 .unwrap_or(0)
160 })
161 .sum::<usize>()
162 })
163 .sum::<usize>();
164 size += values
165 .struct_values
166 .iter()
167 .map(|v| {
168 v.items
169 .iter()
170 .map(|item| {
171 item.value_data
172 .as_ref()
173 .map(Self::size_of_value_data)
174 .unwrap_or(0)
175 })
176 .sum::<usize>()
177 })
178 .sum::<usize>();
179
180 size
181 }
182
183 fn size_of_value_data(value: &ValueData) -> usize {
184 match value {
185 ValueData::I8Value(_) => size_of::<i32>(),
186 ValueData::I16Value(_) => size_of::<i32>(),
187 ValueData::I32Value(_) => size_of::<i32>(),
188 ValueData::I64Value(_) => size_of::<i64>(),
189 ValueData::U8Value(_) => size_of::<u32>(),
190 ValueData::U16Value(_) => size_of::<u32>(),
191 ValueData::U32Value(_) => size_of::<u32>(),
192 ValueData::U64Value(_) => size_of::<u64>(),
193 ValueData::F32Value(_) => size_of::<f32>(),
194 ValueData::F64Value(_) => size_of::<f64>(),
195 ValueData::BoolValue(_) => size_of::<bool>(),
196 ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
197 ValueData::StringValue(v) => v.len(),
198 ValueData::DateValue(_) => size_of::<i32>(),
199 ValueData::DatetimeValue(_) => size_of::<i64>(),
200 ValueData::TimestampSecondValue(_) => size_of::<i64>(),
201 ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
202 ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
203 ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
204 ValueData::TimeSecondValue(_) => size_of::<i64>(),
205 ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
206 ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
207 ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
208 ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
209 ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
210 ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
211 ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
212 ValueData::ListValue(list_values) => list_values
213 .items
214 .iter()
215 .map(|item| {
216 item.value_data
217 .as_ref()
218 .map(Self::size_of_value_data)
219 .unwrap_or(0)
220 })
221 .sum(),
222 ValueData::StructValue(struct_values) => struct_values
223 .items
224 .iter()
225 .map(|item| {
226 item.value_data
227 .as_ref()
228 .map(Self::size_of_value_data)
229 .unwrap_or(0)
230 })
231 .sum(),
232 ValueData::JsonValue(inner) => inner
233 .as_ref()
234 .value_data
235 .as_ref()
236 .map(Self::size_of_value_data)
237 .unwrap_or(0),
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use api::v1::column::Values;
    use api::v1::greptime_request::Request;
    use api::v1::{Column, InsertRequest};

    use super::*;

    /// Builds an `Inserts` request whose estimated payload is `size` bytes,
    /// rounded down to a multiple of 4.
    ///
    /// The limiter charges 4 bytes per `i8` value, so `size / 4` values are
    /// placed in a single column.
    fn generate_request(size: usize) -> Request {
        let values = Values {
            i8_values: vec![0; size / 4],
            ..Default::default()
        };
        let column = Column {
            values: Some(values),
            ..Default::default()
        };
        let insert = InsertRequest {
            columns: vec![column],
            ..Default::default()
        };
        Request::Inserts(InsertRequests {
            inserts: vec![insert],
        })
    }

    #[tokio::test]
    async fn test_limiter() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let bytes_per_request = 100;

        // Ten concurrent 100-byte requests together stay within the 1024-byte budget.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let limiter = Arc::clone(&limiter_ref);
                tokio::spawn(async move {
                    let permit = limiter
                        .limit_request(&generate_request(bytes_per_request))
                        .await;
                    assert!(permit.is_ok());
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_in_flight_write_bytes() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));

        let first_permit = limiter_ref
            .limit_request(&generate_request(100))
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter_ref.in_flight_write_bytes(), 100);

        let second_permit = limiter_ref
            .limit_request(&generate_request(200))
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter_ref.in_flight_write_bytes(), 300);

        // Dropping a permit hands its bytes back to the limiter.
        drop(first_permit);
        assert_eq!(limiter_ref.in_flight_write_bytes(), 200);

        drop(second_permit);
        assert_eq!(limiter_ref.in_flight_write_bytes(), 0);
    }
}