1use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17
18use api::v1::column::Values;
19use api::v1::greptime_request::Request;
20use api::v1::value::ValueData;
21use api::v1::{
22 Decimal128, InsertRequests, IntervalMonthDayNano, RowInsertRequest, RowInsertRequests,
23};
24use common_telemetry::{debug, warn};
25use pipeline::ContextReq;
26
/// Shared, reference-counted handle to a [`Limiter`].
pub(crate) type LimiterRef = Arc<Limiter>;

/// Admission control for write requests, keyed on the total number of (estimated) bytes
/// that are currently reserved by in-flight writes.
pub(crate) struct Limiter {
    /// Upper bound for the sum of all outstanding reservations.
    max_in_flight_write_bytes: u64,
    /// Running sum of outstanding reservations. The `Arc` lets each reservation guard hold
    /// its own handle to the counter so it can give its bytes back independently.
    in_flight_write_bytes: Arc<AtomicU64>,
}
37
/// Guard representing one admitted write's reservation in a shared in-flight byte counter.
pub(crate) struct InFlightWriteBytesCounter {
    /// Handle to the counter this reservation was taken from.
    in_flight_write_bytes: Arc<AtomicU64>,
    /// Number of bytes this guard accounts for.
    processing_write_bytes: u64,
}
46
47impl InFlightWriteBytesCounter {
48 pub fn new(in_flight_write_bytes: Arc<AtomicU64>, processing_write_bytes: u64) -> Self {
50 debug!(
51 "processing write bytes: {}, current in-flight write bytes: {}",
52 processing_write_bytes,
53 in_flight_write_bytes.load(Ordering::Relaxed)
54 );
55 Self {
56 in_flight_write_bytes,
57 processing_write_bytes,
58 }
59 }
60}
61
62impl Drop for InFlightWriteBytesCounter {
63 fn drop(&mut self) {
65 self.in_flight_write_bytes
66 .fetch_sub(self.processing_write_bytes, Ordering::Relaxed);
67 }
68}
69
70impl Limiter {
71 pub fn new(max_in_flight_write_bytes: u64) -> Self {
72 Self {
73 max_in_flight_write_bytes,
74 in_flight_write_bytes: Arc::new(AtomicU64::new(0)),
75 }
76 }
77
78 pub fn limit_request(&self, request: &Request) -> Option<InFlightWriteBytesCounter> {
79 let size = match request {
80 Request::Inserts(requests) => self.insert_requests_data_size(requests),
81 Request::RowInserts(requests) => {
82 self.rows_insert_requests_data_size(requests.inserts.iter())
83 }
84 _ => 0,
85 };
86 self.limit_in_flight_write_bytes(size as u64)
87 }
88
89 pub fn limit_row_inserts(
90 &self,
91 requests: &RowInsertRequests,
92 ) -> Option<InFlightWriteBytesCounter> {
93 let size = self.rows_insert_requests_data_size(requests.inserts.iter());
94 self.limit_in_flight_write_bytes(size as u64)
95 }
96
97 pub fn limit_ctx_req(&self, opt_req: &ContextReq) -> Option<InFlightWriteBytesCounter> {
98 let size = self.rows_insert_requests_data_size(opt_req.ref_all_req());
99 self.limit_in_flight_write_bytes(size as u64)
100 }
101
102 pub fn limit_in_flight_write_bytes(&self, bytes: u64) -> Option<InFlightWriteBytesCounter> {
105 let result = self.in_flight_write_bytes.fetch_update(
106 Ordering::Relaxed,
107 Ordering::Relaxed,
108 |current| {
109 if current + bytes > self.max_in_flight_write_bytes {
110 warn!(
111 "in-flight write bytes exceed the maximum limit {}, request with {} bytes will be limited",
112 self.max_in_flight_write_bytes,
113 bytes
114 );
115 return None;
116 }
117 Some(current + bytes)
118 },
119 );
120
121 match result {
122 Ok(_) => Some(InFlightWriteBytesCounter::new(
124 self.in_flight_write_bytes.clone(),
125 bytes,
126 )),
127 Err(_) => None,
129 }
130 }
131
132 #[allow(dead_code)]
134 pub fn in_flight_write_bytes(&self) -> u64 {
135 self.in_flight_write_bytes.load(Ordering::Relaxed)
136 }
137
138 fn insert_requests_data_size(&self, request: &InsertRequests) -> usize {
139 let mut size: usize = 0;
140 for insert in &request.inserts {
141 for column in &insert.columns {
142 if let Some(values) = &column.values {
143 size += self.size_of_column_values(values);
144 }
145 }
146 }
147 size
148 }
149
150 fn rows_insert_requests_data_size<'a>(
151 &self,
152 inserts: impl Iterator<Item = &'a RowInsertRequest>,
153 ) -> usize {
154 let mut size: usize = 0;
155 for insert in inserts {
156 if let Some(rows) = &insert.rows {
157 for row in &rows.rows {
158 for value in &row.values {
159 if let Some(value) = &value.value_data {
160 size += self.size_of_value_data(value);
161 }
162 }
163 }
164 }
165 }
166 size
167 }
168
169 fn size_of_column_values(&self, values: &Values) -> usize {
170 let mut size: usize = 0;
171 size += values.i8_values.len() * size_of::<i32>();
172 size += values.i16_values.len() * size_of::<i32>();
173 size += values.i32_values.len() * size_of::<i32>();
174 size += values.i64_values.len() * size_of::<i64>();
175 size += values.u8_values.len() * size_of::<u32>();
176 size += values.u16_values.len() * size_of::<u32>();
177 size += values.u32_values.len() * size_of::<u32>();
178 size += values.u64_values.len() * size_of::<u64>();
179 size += values.f32_values.len() * size_of::<f32>();
180 size += values.f64_values.len() * size_of::<f64>();
181 size += values.bool_values.len() * size_of::<bool>();
182 size += values
183 .binary_values
184 .iter()
185 .map(|v| v.len() * size_of::<u8>())
186 .sum::<usize>();
187 size += values.string_values.iter().map(|v| v.len()).sum::<usize>();
188 size += values.date_values.len() * size_of::<i32>();
189 size += values.datetime_values.len() * size_of::<i64>();
190 size += values.timestamp_second_values.len() * size_of::<i64>();
191 size += values.timestamp_millisecond_values.len() * size_of::<i64>();
192 size += values.timestamp_microsecond_values.len() * size_of::<i64>();
193 size += values.timestamp_nanosecond_values.len() * size_of::<i64>();
194 size += values.time_second_values.len() * size_of::<i64>();
195 size += values.time_millisecond_values.len() * size_of::<i64>();
196 size += values.time_microsecond_values.len() * size_of::<i64>();
197 size += values.time_nanosecond_values.len() * size_of::<i64>();
198 size += values.interval_year_month_values.len() * size_of::<i64>();
199 size += values.interval_day_time_values.len() * size_of::<i64>();
200 size += values.interval_month_day_nano_values.len() * size_of::<IntervalMonthDayNano>();
201 size += values.decimal128_values.len() * size_of::<Decimal128>();
202 size
203 }
204
205 fn size_of_value_data(&self, value: &ValueData) -> usize {
206 match value {
207 ValueData::I8Value(_) => size_of::<i32>(),
208 ValueData::I16Value(_) => size_of::<i32>(),
209 ValueData::I32Value(_) => size_of::<i32>(),
210 ValueData::I64Value(_) => size_of::<i64>(),
211 ValueData::U8Value(_) => size_of::<u32>(),
212 ValueData::U16Value(_) => size_of::<u32>(),
213 ValueData::U32Value(_) => size_of::<u32>(),
214 ValueData::U64Value(_) => size_of::<u64>(),
215 ValueData::F32Value(_) => size_of::<f32>(),
216 ValueData::F64Value(_) => size_of::<f64>(),
217 ValueData::BoolValue(_) => size_of::<bool>(),
218 ValueData::BinaryValue(v) => v.len() * size_of::<u8>(),
219 ValueData::StringValue(v) => v.len(),
220 ValueData::DateValue(_) => size_of::<i32>(),
221 ValueData::DatetimeValue(_) => size_of::<i64>(),
222 ValueData::TimestampSecondValue(_) => size_of::<i64>(),
223 ValueData::TimestampMillisecondValue(_) => size_of::<i64>(),
224 ValueData::TimestampMicrosecondValue(_) => size_of::<i64>(),
225 ValueData::TimestampNanosecondValue(_) => size_of::<i64>(),
226 ValueData::TimeSecondValue(_) => size_of::<i64>(),
227 ValueData::TimeMillisecondValue(_) => size_of::<i64>(),
228 ValueData::TimeMicrosecondValue(_) => size_of::<i64>(),
229 ValueData::TimeNanosecondValue(_) => size_of::<i64>(),
230 ValueData::IntervalYearMonthValue(_) => size_of::<i32>(),
231 ValueData::IntervalDayTimeValue(_) => size_of::<i64>(),
232 ValueData::IntervalMonthDayNanoValue(_) => size_of::<IntervalMonthDayNano>(),
233 ValueData::Decimal128Value(_) => size_of::<Decimal128>(),
234 }
235 }
236}
237
#[cfg(test)]
mod tests {
    use api::v1::column::Values;
    use api::v1::greptime_request::Request;
    use api::v1::{Column, InsertRequest};

    use super::*;

    /// Builds an `Inserts` request whose measured size is `size` bytes. Each i8 value is
    /// charged as 4 bytes by the limiter, so `size / 4` values are generated.
    fn generate_request(size: usize) -> Request {
        let column = Column {
            values: Some(Values {
                i8_values: vec![0; size / 4],
                ..Default::default()
            }),
            ..Default::default()
        };
        let insert = InsertRequest {
            columns: vec![column],
            ..Default::default()
        };
        Request::Inserts(InsertRequests {
            inserts: vec![insert],
        })
    }

    #[tokio::test]
    async fn test_limiter() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));
        let tasks_count = 10;
        let request_data_size = 100;

        // Even if all ten 100-byte requests overlap, the 1024-byte budget is never exceeded.
        let handles: Vec<_> = (0..tasks_count)
            .map(|_| {
                let limiter = limiter_ref.clone();
                tokio::spawn(async move {
                    let guard = limiter.limit_request(&generate_request(request_data_size));
                    assert!(guard.is_some());
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[test]
    fn test_in_flight_write_bytes() {
        let limiter_ref: LimiterRef = Arc::new(Limiter::new(1024));

        let first = limiter_ref.limit_request(&generate_request(100));
        assert!(first.is_some());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 100);

        let second = limiter_ref.limit_request(&generate_request(200));
        assert!(second.is_some());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 300);

        // Dropping a guard releases exactly the bytes it reserved.
        drop(first.unwrap());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 200);

        drop(second.unwrap());
        assert_eq!(limiter_ref.in_flight_write_bytes(), 0);
    }
}