1use std::sync::Arc;
16
17use api::v1::column::Values;
18use api::v1::greptime_request::Request;
19use api::v1::value::ValueData;
20use api::v1::{
21 Decimal128, InsertRequests, IntervalMonthDayNano, RowInsertRequest, RowInsertRequests,
22};
23use pipeline::ContextReq;
24use snafu::ResultExt;
25use tokio::sync::{OwnedSemaphorePermit, Semaphore};
26
27use crate::error::{AcquireLimiterSnafu, Result};
28
/// Shared handle to a [`Limiter`].
pub(crate) type LimiterRef = Arc<Limiter>;

/// Bounds the total estimated size of write requests that are in flight at once.
///
/// Each write acquires semaphore permits according to its estimated payload size in bytes.
/// The permits are given back when the returned `OwnedSemaphorePermit` is dropped.
pub(crate) struct Limiter {
    /// Total number of permits (bytes) the semaphore was created with.
    max_in_flight_write_bytes: usize,
    /// Semaphore holding one permit per byte of write budget still available.
    byte_counter: Arc<Semaphore>,
}
37
38impl Limiter {
39 pub fn new(max_in_flight_write_bytes: usize) -> Self {
40 Self {
41 byte_counter: Arc::new(Semaphore::new(max_in_flight_write_bytes)),
42 max_in_flight_write_bytes,
43 }
44 }
45
46 pub async fn limit_request(&self, request: &Request) -> Result<OwnedSemaphorePermit> {
47 let size = match request {
48 Request::Inserts(requests) => self.insert_requests_data_size(requests),
49 Request::RowInserts(requests) => {
50 self.rows_insert_requests_data_size(requests.inserts.iter())
51 }
52 _ => 0,
53 };
54 self.limit_in_flight_write_bytes(size).await
55 }
56
57 pub async fn limit_row_inserts(
58 &self,
59 requests: &RowInsertRequests,
60 ) -> Result<OwnedSemaphorePermit> {
61 let size = self.rows_insert_requests_data_size(requests.inserts.iter());
62 self.limit_in_flight_write_bytes(size).await
63 }
64
65 pub async fn limit_ctx_req(&self, opt_req: &ContextReq) -> Result<OwnedSemaphorePermit> {
66 let size = self.rows_insert_requests_data_size(opt_req.ref_all_req());
67 self.limit_in_flight_write_bytes(size).await
68 }
69
70 pub async fn limit_in_flight_write_bytes(&self, bytes: usize) -> Result<OwnedSemaphorePermit> {
72 self.byte_counter
73 .clone()
74 .acquire_many_owned(bytes as u32)
75 .await
76 .context(AcquireLimiterSnafu)
77 }
78
79 #[allow(dead_code)]
81 pub fn in_flight_write_bytes(&self) -> usize {
82 self.max_in_flight_write_bytes - self.byte_counter.available_permits()
83 }
84
85 fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
86 let mut size: usize = 0;
87 for insert in &request.inserts {
88 for column in &insert.columns {
89 if let Some(values) = &column.values {
90 size += self.size_of_column_values(values);
91 }
92 }
93 }
94 size
95 }
96
97 fn rows_insert_requests_data_size<'a>(
98 &self,
99 inserts: impl Iterator<Item = &'a RowInsertRequest>,
100 ) -> usize {
101 let mut size: usize = 0;
102 for insert in inserts {
103 if let Some(rows) = &insert.rows {
104 for row in &rows.rows {
105 for value in &row.values {
106 if let Some(value) = &value.value_data {
107 size += self.size_of_value_data(value);
108 }
109 }
110 }
111 }
112 }
113 size
114 }
115
116 fn size_of_column_values(&self, values: &Values) -> usize {
117 let mut size: usize = 0;
118 size += values.i8_values.len() * size_of::<i32>();
119 size += values.i16_values.len() * size_of::<i32>();
120 size += values.i32_values.len() * size_of::<i32>();
121 size += values.i64_values.len() * size_of::<i64>();
122 size += values.u8_values.len() * size_of::<u32>();
123 size += values.u16_values.len() * size_of::<u32>();
124 size += values.u32_values.len() * size_of::<u32>();
125 size += values.u64_values.len() * size_of::<u64>();
126 size += values.f32_values.len() * size_of::<f32>();
127 size += values.f64_values.len() * size_of::<f64>();
128 size += values.bool_values.len() * size_of::<bool>();
129 size += values
130 .binary_values
131 .iter()
132 .map(|v| v.len() * size_of::<u8>())
133 .sum::<usize>();
134 size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
135 size += values.date_values.len() * size_of::<i32>();
136 size += values.datetime_values.len() * size_of::<i64>();
137 size += values.timestamp_second_values.len() * size_of::<i64>();
138 size += values.timestamp_millisecond_values.len() * size_of::<i64>();
139 size += values.timestamp_microsecond_values.len() * size_of::<i64>();
140 size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
141 size += values.time_second_values.len() * size_of::<i64>();
142 size += values.time_millisecond_values.len() * size_of::<i64>();
143 size += values.time_microsecond_values.len() * size_of::<i64>();
144 size += values.time_nanosecond_values.len() * size_of::<i64>();
145 size += values.interval_year_month_values.len() * size_of::<i64>();
146 size += values.interval_day_time_values.len() * size_of::<i64>();
147 size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
148 size += values.decimal128_values.len() * size_of::<Decimal128>();
149 size
150 }
151
152 fn size_of_value_data(&self, value: &ValueData) -> usize {
153 match value {
154 ValueData::I8Value(_) => size_of::<i32>(),
155 ValueData::I16Value(_) => size_of::<i32>(),
156 ValueData::I32Value(_) => size_of::<i32>(),
157 ValueData::I64Value(_) => size_of::<i64>(),
158 ValueData::U8Value(_) => size_of::<u32>(),
159 ValueData::U16Value(_) => size_of::<u32>(),
160 ValueData::U32Value(_) => size_of::<u32>(),
161 ValueData::U64Value(_) => size_of::<u64>(),
162 ValueData::F32Value(_) => size_of::<f32>(),
163 ValueData::F64Value(_) => size_of::<f64>(),
164 ValueData::BoolValue(_) => size_of::<bool>(),
165 ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
166 ValueData::StringValue(v) => v.len(),
167 ValueData::DateValue(_) => size_of::<i32>(),
168 ValueData::DatetimeValue(_) => size_of::<i64>(),
169 ValueData::TimestampSecondValue(_) => size_of::<i64>(),
170 ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
171 ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
172 ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
173 ValueData::TimeSecondValue(_) => size_of::<i64>(),
174 ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
175 ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
176 ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
177 ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
178 ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
179 ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
180 ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use api::v1::column::Values;
    use api::v1::greptime_request::Request;
    use api::v1::{Column, InsertRequest};

    use super::*;

    /// Builds an `Inserts` request with a single `i8` column whose estimated payload is
    /// `size` bytes, rounded down to a multiple of 4 because each value is counted as an `i32`.
    fn generate_request(size: usize) -> Request {
        let column = Column {
            values: Some(Values {
                i8_values: vec![0; size / 4],
                ..Default::default()
            }),
            ..Default::default()
        };
        let insert = InsertRequest {
            columns: vec![column],
            ..Default::default()
        };
        Request::Inserts(InsertRequests {
            inserts: vec![insert],
        })
    }

    #[tokio::test]
    async fn test_limiter() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let tasks_count = 10;
        let request_data_size = 100;

        // Every task requests the same small payload concurrently; all must succeed.
        let handles: Vec<_> = (0..tasks_count)
            .map(|_| {
                let limiter = Arc::clone(&limiter_ref);
                tokio::spawn(async move {
                    let permit = limiter
                        .limit_request(&generate_request(request_data_size))
                        .await;
                    assert!(permit.is_ok());
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_in_flight_write_bytes() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));

        let small_request = generate_request(100);
        let small_permit = limiter_ref
            .limit_request(&small_request)
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter_ref.in_flight_write_bytes(), 100);

        let large_request = generate_request(200);
        let large_permit = limiter_ref
            .limit_request(&large_request)
            .await
            .expect("failed to acquire permits");
        assert_eq!(limiter_ref.in_flight_write_bytes(), 300);

        // Dropping a permit hands its bytes back to the budget.
        drop(small_permit);
        assert_eq!(limiter_ref.in_flight_write_bytes(), 200);

        drop(large_permit);
        assert_eq!(limiter_ref.in_flight_write_bytes(), 0);
    }
}