1use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17
18use api::v1::column::Values;
19use api::v1::greptime_request::Request;
20use api::v1::value::ValueData;
21use api::v1::{Decimal128, InsertRequests, IntervalMonthDayNano, RowInsertRequests};
22use common_telemetry::{debug, warn};
23
/// Shared, reference-counted handle to a [`Limiter`].
pub(crate) type LimiterRef = Arc<Limiter>;
25
/// Caps the total number of write-request bytes that may be in flight at once.
///
/// Reservations are made through the `limit_*` methods, which hand back an
/// [`InFlightWriteBytesCounter`] guard on success.
pub(crate) struct Limiter {
    /// Upper bound for the in-flight total; a reservation that would push the
    /// total above this value is refused.
    max_in_flight_write_bytes: u64,

    /// Running total of currently reserved bytes. The same atomic is shared
    /// with every guard handed out, which subtracts its share when dropped.
    in_flight_write_bytes: Arc<AtomicU64>,
}
34
/// Guard representing one reservation against the shared in-flight byte total.
///
/// Dropping the guard subtracts `processing_write_bytes` from the shared total
/// (see the `Drop` implementation).
pub(crate) struct InFlightWriteBytesCounter {
    /// The shared running total this guard's reservation is counted in.
    in_flight_write_bytes: Arc<AtomicU64>,

    /// Number of bytes this guard accounts for; subtracted from the total on drop.
    processing_write_bytes: u64,
}
43
44impl InFlightWriteBytesCounter {
45 pub fn new(in_flight_write_bytes: Arc<AtomicU64>, processing_write_bytes: u64) -> Self {
47 debug!(
48 "processing write bytes: {}, current in-flight write bytes: {}",
49 processing_write_bytes,
50 in_flight_write_bytes.load(Ordering::Relaxed)
51 );
52 Self {
53 in_flight_write_bytes,
54 processing_write_bytes,
55 }
56 }
57}
58
59impl Drop for InFlightWriteBytesCounter {
60 fn drop(&mut self) {
62 self.in_flight_write_bytes
63 .fetch_sub(self.processing_write_bytes, Ordering::Relaxed);
64 }
65}
66
67impl Limiter {
68 pub fn new(max_in_flight_write_bytes: u64) -> Self {
69 Self {
70 max_in_flight_write_bytes,
71 in_flight_write_bytes: Arc::new(AtomicU64::new(0)),
72 }
73 }
74
75 pub fn limit_request(&self, request: &Request) -> Option<InFlightWriteBytesCounter> {
76 let size = match request {
77 Request::Inserts(requests) => self.insert_requests_data_size(requests),
78 Request::RowInserts(requests) => self.rows_insert_requests_data_size(requests),
79 _ => 0,
80 };
81 self.limit_in_flight_write_bytes(size as u64)
82 }
83
84 pub fn limit_row_inserts(
85 &self,
86 requests: &RowInsertRequests,
87 ) -> Option<InFlightWriteBytesCounter> {
88 let size = self.rows_insert_requests_data_size(requests);
89 self.limit_in_flight_write_bytes(size as u64)
90 }
91
92 pub fn limit_in_flight_write_bytes(&self, bytes: u64) -> Option<InFlightWriteBytesCounter> {
95 let result = self.in_flight_write_bytes.fetch_update(
96 Ordering::Relaxed,
97 Ordering::Relaxed,
98 |current| {
99 if current + bytes > self.max_in_flight_write_bytes {
100 warn!(
101 "in-flight write bytes exceed the maximum limit {}, request with {} bytes will be limited",
102 self.max_in_flight_write_bytes,
103 bytes
104 );
105 return None;
106 }
107 Some(current + bytes)
108 },
109 );
110
111 match result {
112 Ok(_) => Some(InFlightWriteBytesCounter::new(
114 self.in_flight_write_bytes.clone(),
115 bytes,
116 )),
117 Err(_) => None,
119 }
120 }
121
122 #[allow(dead_code)]
124 pub fn in_flight_write_bytes(&self) -> u64 {
125 self.in_flight_write_bytes.load(Ordering::Relaxed)
126 }
127
128 fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
129 let mut size: usize = 0;
130 for insert in &request.inserts {
131 for column in &insert.columns {
132 if let Some(values) = &column.values {
133 size += self.size_of_column_values(values);
134 }
135 }
136 }
137 size
138 }
139
140 fn rows_insert_requests_data_size(&self, request: &RowInsertRequests) -> usize {
141 let mut size: usize = 0;
142 for insert in &request.inserts {
143 if let Some(rows) = &insert.rows {
144 for row in &rows.rows {
145 for value in &row.values {
146 if let Some(value) = &value.value_data {
147 size += self.size_of_value_data(value);
148 }
149 }
150 }
151 }
152 }
153 size
154 }
155
156 fn size_of_column_values(&self, values: &Values) -> usize {
157 let mut size: usize = 0;
158 size += values.i8_values.len() * size_of::<i32>();
159 size += values.i16_values.len() * size_of::<i32>();
160 size += values.i32_values.len() * size_of::<i32>();
161 size += values.i64_values.len() * size_of::<i64>();
162 size += values.u8_values.len() * size_of::<u32>();
163 size += values.u16_values.len() * size_of::<u32>();
164 size += values.u32_values.len() * size_of::<u32>();
165 size += values.u64_values.len() * size_of::<u64>();
166 size += values.f32_values.len() * size_of::<f32>();
167 size += values.f64_values.len() * size_of::<f64>();
168 size += values.bool_values.len() * size_of::<bool>();
169 size += values
170 .binary_values
171 .iter()
172 .map(|v| v.len() * size_of::<u8>())
173 .sum::<usize>();
174 size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
175 size += values.date_values.len() * size_of::<i32>();
176 size += values.datetime_values.len() * size_of::<i64>();
177 size += values.timestamp_second_values.len() * size_of::<i64>();
178 size += values.timestamp_millisecond_values.len() * size_of::<i64>();
179 size += values.timestamp_microsecond_values.len() * size_of::<i64>();
180 size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
181 size += values.time_second_values.len() * size_of::<i64>();
182 size += values.time_millisecond_values.len() * size_of::<i64>();
183 size += values.time_microsecond_values.len() * size_of::<i64>();
184 size += values.time_nanosecond_values.len() * size_of::<i64>();
185 size += values.interval_year_month_values.len() * size_of::<i64>();
186 size += values.interval_day_time_values.len() * size_of::<i64>();
187 size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
188 size += values.decimal128_values.len() * size_of::<Decimal128>();
189 size
190 }
191
192 fn size_of_value_data(&self, value: &ValueData) -> usize {
193 match value {
194 ValueData::I8Value(_) => size_of::<i32>(),
195 ValueData::I16Value(_) => size_of::<i32>(),
196 ValueData::I32Value(_) => size_of::<i32>(),
197 ValueData::I64Value(_) => size_of::<i64>(),
198 ValueData::U8Value(_) => size_of::<u32>(),
199 ValueData::U16Value(_) => size_of::<u32>(),
200 ValueData::U32Value(_) => size_of::<u32>(),
201 ValueData::U64Value(_) => size_of::<u64>(),
202 ValueData::F32Value(_) => size_of::<f32>(),
203 ValueData::F64Value(_) => size_of::<f64>(),
204 ValueData::BoolValue(_) => size_of::<bool>(),
205 ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
206 ValueData::StringValue(v) => v.len(),
207 ValueData::DateValue(_) => size_of::<i32>(),
208 ValueData::DatetimeValue(_) => size_of::<i64>(),
209 ValueData::TimestampSecondValue(_) => size_of::<i64>(),
210 ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
211 ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
212 ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
213 ValueData::TimeSecondValue(_) => size_of::<i64>(),
214 ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
215 ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
216 ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
217 ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
218 ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
219 ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
220 ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
221 }
222 }
223}
224
225#[cfg(test)]
226mod tests {
227 use api::v1::column::Values;
228 use api::v1::greptime_request::Request;
229 use api::v1::{Column, InsertRequest};
230
231 use super::*;
232
233 fn generate_request(size: usize) -> Request {
234 let i8_values = vec![0; size / 4];
235 Request::Inserts(InsertRequests {
236 inserts: vec![InsertRequest {
237 columns: vec![Column {
238 values: Some(Values {
239 i8_values,
240 ..Default::default()
241 }),
242 ..Default::default()
243 }],
244 ..Default::default()
245 }],
246 })
247 }
248
249 #[tokio::test]
250 async fn test_limiter() {
251 let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
252 let tasks_count = 10;
253 let request_data_size = 100;
254 let mut handles = vec![];
255
256 for _ in 0..tasks_count {
258 let limiter = limiter_ref.clone();
259 let handle = tokio::spawn(async move {
260 let result = limiter.limit_request(&generate_request(request_data_size));
261 assert!(result.is_some());
262 });
263 handles.push(handle);
264 }
265
266 for handle in handles {
268 handle.await.unwrap();
269 }
270 }
271
/// The in-flight total grows with each admitted request and shrinks again
/// as each reservation guard is dropped.
#[test]
fn test_in_flight_write_bytes() {
    let limiter: LimiterRef = Arc::new(Limiter::new(1024));

    let first_guard = limiter.limit_request(&generate_request(100));
    assert!(first_guard.is_some());
    assert_eq!(limiter.in_flight_write_bytes(), 100);

    let second_guard = limiter.limit_request(&generate_request(200));
    assert!(second_guard.is_some());
    assert_eq!(limiter.in_flight_write_bytes(), 300);

    // Dropping the `Option` drops the guard it holds.
    drop(first_guard);
    assert_eq!(limiter.in_flight_write_bytes(), 200);

    drop(second_guard);
    assert_eq!(limiter.in_flight_write_bytes(), 0);
}
291}