1use std::collections::btree_map::Entry;
16use std::collections::BTreeMap;
17use std::ops::Deref;
18use std::slice;
19
20use api::prom_store::remote::Sample;
21use api::v1::RowInsertRequest;
22use bytes::{Buf, Bytes};
23use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE};
24use common_telemetry::debug;
25use pipeline::{
26 ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition, PipelineMap, Value,
27};
28use prost::encoding::message::merge;
29use prost::encoding::{decode_key, decode_varint, WireType};
30use prost::DecodeError;
31use session::context::QueryContextRef;
32use snafu::OptionExt;
33
34use crate::error::InternalSnafu;
35use crate::http::event::PipelineIngestRequest;
36use crate::http::PromValidationMode;
37use crate::pipeline::run_pipeline;
38use crate::prom_row_builder::TablesBuilder;
39use crate::prom_store::METRIC_NAME_LABEL_BYTES;
40use crate::query_handler::PipelineHandlerRef;
41use crate::repeated_field::{Clear, RepeatedField};
42
43impl Clear for Sample {
44 fn clear(&mut self) {
45 self.timestamp = 0;
46 self.value = 0.0;
47 }
48}
49
/// A single Prometheus label (name/value pair) decoded from a remote-write request.
///
/// Both fields hold raw, not-yet-validated bytes. When filled by `merge_bytes`,
/// a field may alias the decoded request buffer without owning it (see
/// `split_to`). Such a label must not be read after that buffer has been
/// released; this also applies to clones produced by the derived `Clone`.
#[derive(Default, Clone, Debug)]
pub struct PromLabel {
    /// Label name bytes (protobuf field 1).
    pub name: Bytes,
    /// Label value bytes (protobuf field 2).
    pub value: Bytes,
}
55
56impl Clear for PromLabel {
57 fn clear(&mut self) {
58 self.name.clear();
59 self.value.clear();
60 }
61}
62
63impl PromLabel {
64 pub fn merge_field(
65 &mut self,
66 tag: u32,
67 wire_type: WireType,
68 buf: &mut Bytes,
69 ) -> Result<(), DecodeError> {
70 const STRUCT_NAME: &str = "PromLabel";
71 match tag {
72 1u32 => {
73 let value = &mut self.name;
75 merge_bytes(value, buf).map_err(|mut error| {
76 error.push(STRUCT_NAME, "name");
77 error
78 })
79 }
80 2u32 => {
81 let value = &mut self.value;
83 merge_bytes(value, buf).map_err(|mut error| {
84 error.push(STRUCT_NAME, "value");
85 error
86 })
87 }
88 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
89 }
90 }
91}
92
93#[inline(always)]
94fn copy_to_bytes(data: &mut Bytes, len: usize) -> Bytes {
95 if len == data.remaining() {
96 std::mem::replace(data, Bytes::new())
97 } else {
98 let ret = split_to(data, len);
99 data.advance(len);
100 ret
101 }
102}
103
/// Returns a `Bytes` that views the first `end` bytes of `buf`.
///
/// It does this without copying, without touching a reference count, and
/// without advancing `buf`. Callers advance `buf` themselves, as
/// `copy_to_bytes` does.
///
/// The result is built with `Bytes::from_static` over a slice whose lifetime
/// is produced by `slice::from_raw_parts`. It therefore does NOT keep `buf`'s
/// backing storage alive. The returned value is only valid while that storage
/// is still alive, and it must not be read after the buffer it was split from
/// has been released.
///
/// NOTE(review): this obligation is not expressed in the signature (this is not
/// an `unsafe fn`), so callers uphold it purely by convention.
///
/// # Panics
///
/// Panics if `end > buf.len()`.
#[inline(always)]
fn split_to(buf: &mut Bytes, end: usize) -> Bytes {
    let len = buf.len();
    assert!(
        end <= len,
        "range end out of bounds: {:?} <= {:?}",
        end,
        len,
    );

    if end == 0 {
        return Bytes::new();
    }

    let ptr = buf.as_ptr();
    // SAFETY: `end <= len` is asserted above, so `ptr..ptr + end` lies inside the
    // initialized bytes of `buf`. The `'static` lifetime handed to `from_static`
    // is not real: the returned `Bytes` points into storage owned by `buf`
    // without holding a reference to it, so it must not outlive that storage.
    let x = unsafe { slice::from_raw_parts(ptr, end) };
    Bytes::from_static(x)
}
128
129#[inline(always)]
133fn merge_bytes(value: &mut Bytes, buf: &mut Bytes) -> Result<(), DecodeError> {
134 let len = decode_varint(buf)?;
135 if len > buf.remaining() as u64 {
136 return Err(DecodeError::new(format!(
137 "buffer underflow, len: {}, remaining: {}",
138 len,
139 buf.remaining()
140 )));
141 }
142 *value = copy_to_bytes(buf, len as usize);
143 Ok(())
144}
145
/// Representation of one Prometheus `TimeSeries` while it is being decoded.
///
/// `PromWriteRequest` reuses a single instance for every series in a payload.
#[derive(Default, Debug)]
pub struct PromTimeSeries {
    /// Value of the label whose name equals `METRIC_NAME_LABEL_BYTES`, decoded via
    /// `decode_string`. Empty if no such label has been seen for this series.
    pub table_name: String,
    /// Decoded labels, excluding the metric-name label.
    pub labels: RepeatedField<PromLabel>,
    /// Decoded samples.
    pub samples: RepeatedField<Sample>,
}
152
153impl Clear for PromTimeSeries {
154 fn clear(&mut self) {
155 self.table_name.clear();
156 self.labels.clear();
157 self.samples.clear();
158 }
159}
160
161impl PromTimeSeries {
162 pub fn merge_field(
163 &mut self,
164 tag: u32,
165 wire_type: WireType,
166 buf: &mut Bytes,
167 prom_validation_mode: PromValidationMode,
168 ) -> Result<(), DecodeError> {
169 const STRUCT_NAME: &str = "PromTimeSeries";
170 match tag {
171 1u32 => {
172 let label = self.labels.push_default();
174
175 let len = decode_varint(buf).map_err(|mut error| {
176 error.push(STRUCT_NAME, "labels");
177 error
178 })?;
179 let remaining = buf.remaining();
180 if len > remaining as u64 {
181 return Err(DecodeError::new("buffer underflow"));
182 }
183
184 let limit = remaining - len as usize;
185 while buf.remaining() > limit {
186 let (tag, wire_type) = decode_key(buf)?;
187 label.merge_field(tag, wire_type, buf)?;
188 }
189 if buf.remaining() != limit {
190 return Err(DecodeError::new("delimited length exceeded"));
191 }
192 if label.name.deref() == METRIC_NAME_LABEL_BYTES {
193 self.table_name = decode_string(&label.value, prom_validation_mode)?;
194 self.labels.truncate(self.labels.len() - 1); }
196 Ok(())
197 }
198 2u32 => {
199 let sample = self.samples.push_default();
200 merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
201 |mut error| {
202 error.push(STRUCT_NAME, "samples");
203 error
204 },
205 )?;
206 Ok(())
207 }
208 3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
210 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
211 }
212 }
213
214 fn add_to_table_data(
215 &mut self,
216 table_builders: &mut TablesBuilder,
217 prom_validation_mode: PromValidationMode,
218 ) -> Result<(), DecodeError> {
219 let label_num = self.labels.len();
220 let row_num = self.samples.len();
221 let table_data = table_builders.get_or_create_table_builder(
222 std::mem::take(&mut self.table_name),
223 label_num,
224 row_num,
225 );
226 table_data.add_labels_and_samples(
227 self.labels.as_slice(),
228 self.samples.as_slice(),
229 prom_validation_mode,
230 )?;
231
232 Ok(())
233 }
234}
235
236pub(crate) fn decode_string(
238 bytes: &Bytes,
239 mode: PromValidationMode,
240) -> Result<String, DecodeError> {
241 let result = match mode {
242 PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
243 Ok(s) => s,
244 Err(e) => {
245 debug!(
246 "Invalid UTF-8 string value: {:?}, error: {:?}",
247 &bytes[..],
248 e
249 );
250 return Err(DecodeError::new("invalid utf-8"));
251 }
252 },
253 PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
254 PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
255 };
256 Ok(result)
257}
258
/// Decoder state for one Prometheus remote-write request.
#[derive(Default, Debug)]
pub struct PromWriteRequest {
    /// Per-table row builders that `merge` fills when no pipeline is used.
    table_data: TablesBuilder,
    /// Buffer reused for every time series while `merge` decodes.
    series: PromTimeSeries,
}
264
265impl Clear for PromWriteRequest {
266 fn clear(&mut self) {
267 self.table_data.clear();
268 }
269}
270
271impl PromWriteRequest {
272 pub fn as_row_insert_requests(&mut self) -> Vec<RowInsertRequest> {
273 self.table_data.as_insert_requests()
274 }
275
276 pub fn merge(
278 &mut self,
279 mut buf: Bytes,
280 prom_validation_mode: PromValidationMode,
281 processor: &mut PromSeriesProcessor,
282 ) -> Result<(), DecodeError> {
283 const STRUCT_NAME: &str = "PromWriteRequest";
284 while buf.has_remaining() {
285 let (tag, wire_type) = decode_key(&mut buf)?;
286 assert_eq!(WireType::LengthDelimited, wire_type);
287 match tag {
288 1u32 => {
289 let len = decode_varint(&mut buf).map_err(|mut e| {
291 e.push(STRUCT_NAME, "timeseries");
292 e
293 })?;
294 let remaining = buf.remaining();
295 if len > remaining as u64 {
296 return Err(DecodeError::new("buffer underflow"));
297 }
298
299 let limit = remaining - len as usize;
300 while buf.remaining() > limit {
301 let (tag, wire_type) = decode_key(&mut buf)?;
302 self.series
303 .merge_field(tag, wire_type, &mut buf, prom_validation_mode)?;
304 }
305 if buf.remaining() != limit {
306 return Err(DecodeError::new("delimited length exceeded"));
307 }
308
309 if processor.use_pipeline {
310 processor.consume_series_to_pipeline_map(&mut self.series)?;
311 } else {
312 self.series
313 .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
314 }
315
316 self.series.labels.clear();
318 self.series.samples.clear();
319 }
320 3u32 => {
321 prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?;
323 }
324 _ => prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?,
325 }
326 }
327
328 Ok(())
329 }
330}
331
/// Decides where decoded time series go.
///
/// Series are written either straight into `PromWriteRequest`'s table builders,
/// or (when a pipeline is configured via `set_pipeline`) collected as per-table
/// `PipelineMap` rows for `exec_pipeline`.
pub struct PromSeriesProcessor {
    /// When true, `PromWriteRequest::merge` passes each series to
    /// `consume_series_to_pipeline_map` instead of the table builders.
    pub(crate) use_pipeline: bool,
    /// Pipeline input rows grouped by table name.
    pub(crate) table_values: BTreeMap<String, Vec<PipelineMap>>,

    /// Pipeline execution dependencies, all set together by `set_pipeline`.
    pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
    pub(crate) query_ctx: Option<QueryContextRef>,
    pub(crate) pipeline_def: Option<PipelineDefinition>,
}
348
349impl PromSeriesProcessor {
350 pub fn default_processor() -> Self {
351 Self {
352 use_pipeline: false,
353 table_values: BTreeMap::new(),
354 pipeline_handler: None,
355 query_ctx: None,
356 pipeline_def: None,
357 }
358 }
359
360 pub fn set_pipeline(
361 &mut self,
362 handler: PipelineHandlerRef,
363 query_ctx: QueryContextRef,
364 pipeline_def: PipelineDefinition,
365 ) {
366 self.use_pipeline = true;
367 self.pipeline_handler = Some(handler);
368 self.query_ctx = Some(query_ctx);
369 self.pipeline_def = Some(pipeline_def);
370 }
371
372 pub(crate) fn consume_series_to_pipeline_map(
374 &mut self,
375 series: &mut PromTimeSeries,
376 ) -> Result<(), DecodeError> {
377 let mut vec_pipeline_map: Vec<PipelineMap> = Vec::new();
378 let mut pipeline_map = PipelineMap::new();
379 for l in series.labels.iter() {
380 let name = String::from_utf8(l.name.to_vec())
381 .map_err(|_| DecodeError::new("invalid utf-8"))?;
382 let value = String::from_utf8(l.value.to_vec())
383 .map_err(|_| DecodeError::new("invalid utf-8"))?;
384 pipeline_map.insert(name, Value::String(value));
385 }
386
387 let one_sample = series.samples.len() == 1;
388
389 for s in series.samples.iter() {
390 let timestamp = s.timestamp;
391 pipeline_map.insert(GREPTIME_TIMESTAMP.to_string(), Value::Int64(timestamp));
392 pipeline_map.insert(GREPTIME_VALUE.to_string(), Value::Float64(s.value));
393 if one_sample {
394 vec_pipeline_map.push(pipeline_map);
395 break;
396 } else {
397 vec_pipeline_map.push(pipeline_map.clone());
398 }
399 }
400
401 let table_name = std::mem::take(&mut series.table_name);
402 match self.table_values.entry(table_name) {
403 Entry::Occupied(mut occupied_entry) => {
404 occupied_entry.get_mut().append(&mut vec_pipeline_map);
405 }
406 Entry::Vacant(vacant_entry) => {
407 vacant_entry.insert(vec_pipeline_map);
408 }
409 }
410
411 Ok(())
412 }
413
414 pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
415 let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
417 err_msg: "pipeline handler is not set",
418 })?;
419 let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
420 err_msg: "pipeline definition is not set",
421 })?;
422 let pipeline_param = GreptimePipelineParams::default();
423 let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
424 err_msg: "query context is not set",
425 })?;
426
427 let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
428
429 let mut req = ContextReq::default();
431 for (table_name, pipeline_maps) in self.table_values.iter_mut() {
432 let pipeline_req = PipelineIngestRequest {
433 table: table_name.clone(),
434 values: pipeline_maps.clone(),
435 };
436 let row_req =
437 run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
438 req.merge(row_req);
439 }
440
441 Ok(req)
442 }
443}
444
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use crate::http::PromValidationMode;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::proto::{decode_string, PromSeriesProcessor, PromWriteRequest};
    use crate::repeated_field::Clear;

    /// Non-ASCII fixture: "Hello", two CJK characters and an emoji.
    ///
    /// It is spelled with `\u{..}` escapes so the test source stays pure ASCII
    /// and cannot be silently changed if the file is re-encoded.
    const UNICODE_TEXT: &str = "Hello \u{4e16}\u{754c} \u{1f30d}";

    /// Reorders the columns of `rows` by column name so that row sets can be
    /// compared regardless of column order.
    fn sort_rows(rows: Rows) -> Rows {
        let permutation =
            permutation::sort_by_key(&rows.schema, |schema| schema.column_name.clone());
        let schema = permutation.apply_slice(&rows.schema);
        let mut inner_rows = vec![];
        for row in rows.rows {
            let values = permutation.apply_slice(&row.values);
            inner_rows.push(Row { values });
        }
        Rows {
            schema,
            rows: inner_rows,
        }
    }

    /// Decodes `data` with `PromWriteRequest` and checks the result against the
    /// reference conversion.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &Bytes,
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut p = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .merge(data.clone(), PromValidationMode::Strict, &mut p)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();
        let samples = req
            .iter()
            .filter_map(|r| r.rows.as_ref().map(|r| r.rows.len()))
            .sum::<usize>();
        let prom_rows = RowInsertRequests { inserts: req };

        assert_eq!(expected_samples, samples);
        assert_eq!(expected_rows.inserts.len(), prom_rows.inserts.len());

        let expected_rows_map = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for r in &prom_rows.inserts {
            let expected_rows = expected_rows_map.get(&r.table_name).unwrap().clone();
            assert_eq!(sort_rows(expected_rows), sort_rows(r.rows.clone().unwrap()));
        }
    }

    #[test]
    fn test_decode_write_request() {
        let mut d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("benches");
        d.push("write_request.pb.data");
        let data = Bytes::from(std::fs::read(d).unwrap());

        let (expected_rows, expected_samples) =
            to_grpc_row_insert_requests(&WriteRequest::decode(data.clone()).unwrap()).unwrap();

        // Decode repeatedly with the same instance to exercise reuse after `clear`.
        let mut prom_write_request = PromWriteRequest::default();
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = decode_string(&valid_utf8, PromValidationMode::Strict);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_strict_mode_empty() {
        let empty = Bytes::new();
        let result = decode_string(&empty, PromValidationMode::Strict);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_strict_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = decode_string(&unicode, PromValidationMode::Strict);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_strict_mode_invalid_utf8() {
        // 0xFF, 0xFE and 0xFD can never start a UTF-8 sequence.
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = decode_string(&invalid_utf8, PromValidationMode::Strict);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_strict_mode_incomplete_utf8() {
        // 0xC2 starts a two-byte sequence that is never completed.
        let incomplete_utf8 = Bytes::from(vec![0xC2]);
        let result = decode_string(&incomplete_utf8, PromValidationMode::Strict);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_lossy_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = decode_string(&valid_utf8, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_lossy_mode_empty() {
        let empty = Bytes::new();
        let result = decode_string(&empty, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_lossy_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = decode_string(&unicode, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_lossy_mode_invalid_utf8() {
        // Each invalid byte becomes one U+FFFD REPLACEMENT CHARACTER.
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = decode_string(&invalid_utf8, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "\u{FFFD}\u{FFFD}\u{FFFD}");
    }

    #[test]
    fn test_decode_string_lossy_mode_mixed_valid_invalid() {
        let mut mixed = Vec::new();
        mixed.extend_from_slice(b"hello");
        mixed.push(0xFF);
        mixed.extend_from_slice(b"world");
        let mixed_utf8 = Bytes::from(mixed);

        let result = decode_string(&mixed_utf8, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello\u{FFFD}world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = decode_string(&valid_utf8, PromValidationMode::Unchecked);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_empty() {
        let empty = Bytes::new();
        let result = decode_string(&empty, PromValidationMode::Unchecked);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_unchecked_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = decode_string(&unicode, PromValidationMode::Unchecked);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_unchecked_mode_invalid_utf8() {
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = decode_string(&invalid_utf8, PromValidationMode::Unchecked);
        assert!(result.is_ok());
        // The resulting String violates the UTF-8 invariant. It is only dropped
        // here and never inspected as text.
        let _string = result.unwrap();
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");

        let strict_result = decode_string(&ascii, PromValidationMode::Strict).unwrap();
        let lossy_result = decode_string(&ascii, PromValidationMode::Lossy).unwrap();
        let unchecked_result = decode_string(&ascii, PromValidationMode::Unchecked).unwrap();

        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
        assert_eq!(strict_result, lossy_result);
        assert_eq!(lossy_result, unchecked_result);
    }
}