1use std::collections::btree_map::Entry;
16use std::collections::BTreeMap;
17use std::ops::Deref;
18use std::slice;
19
20use api::prom_store::remote::Sample;
21use bytes::{Buf, Bytes};
22use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE};
23use common_telemetry::warn;
24use pipeline::{ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition};
25use prost::encoding::message::merge;
26use prost::encoding::{decode_key, decode_varint, WireType};
27use prost::DecodeError;
28use session::context::QueryContextRef;
29use snafu::OptionExt;
30use vrl::prelude::NotNan;
31use vrl::value::{KeyString, Value as VrlValue};
32
33use crate::error::InternalSnafu;
34use crate::http::event::PipelineIngestRequest;
35use crate::http::PromValidationMode;
36use crate::pipeline::run_pipeline;
37use crate::prom_row_builder::{PromCtx, TablesBuilder};
38use crate::prom_store::{
39 DATABASE_LABEL_BYTES, METRIC_NAME_LABEL_BYTES, PHYSICAL_TABLE_LABEL_BYTES,
40};
41use crate::query_handler::PipelineHandlerRef;
42use crate::repeated_field::{Clear, RepeatedField};
43
44impl Clear for Sample {
45 fn clear(&mut self) {
46 self.timestamp = 0;
47 self.value = 0.0;
48 }
49}
50
/// A single label of a time series, kept as a raw name/value pair.
///
/// Both parts are stored as undecoded [`Bytes`]; conversion to `String` is
/// done by the callers through `PromValidationMode::decode_string`.
#[derive(Default, Clone, Debug)]
pub struct PromLabel {
    /// Label name bytes (filled by `merge_field` from field tag 1).
    pub name: Bytes,
    /// Label value bytes (filled by `merge_field` from field tag 2).
    pub value: Bytes,
}
56
57impl Clear for PromLabel {
58 fn clear(&mut self) {
59 self.name.clear();
60 self.value.clear();
61 }
62}
63
64impl PromLabel {
65 pub fn merge_field(
66 &mut self,
67 tag: u32,
68 wire_type: WireType,
69 buf: &mut Bytes,
70 ) -> Result<(), DecodeError> {
71 const STRUCT_NAME: &str = "PromLabel";
72 match tag {
73 1u32 => {
74 let value = &mut self.name;
76 merge_bytes(value, buf).map_err(|mut error| {
77 error.push(STRUCT_NAME, "name");
78 error
79 })
80 }
81 2u32 => {
82 let value = &mut self.value;
84 merge_bytes(value, buf).map_err(|mut error| {
85 error.push(STRUCT_NAME, "value");
86 error
87 })
88 }
89 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
90 }
91 }
92}
93
94#[inline(always)]
95fn copy_to_bytes(data: &mut Bytes, len: usize) -> Bytes {
96 if len == data.remaining() {
97 std::mem::replace(data, Bytes::new())
98 } else {
99 let ret = split_to(data, len);
100 data.advance(len);
101 ret
102 }
103}
104
105#[inline(always)]
111fn split_to(buf: &mut Bytes, end: usize) -> Bytes {
112 let len = buf.len();
113 assert!(
114 end <= len,
115 "range end out of bounds: {:?} <= {:?}",
116 end,
117 len,
118 );
119
120 if end == 0 {
121 return Bytes::new();
122 }
123
124 let ptr = buf.as_ptr();
125 let x = unsafe { slice::from_raw_parts(ptr, end) };
126 Bytes::from_static(x)
128}
129
130#[inline(always)]
134fn merge_bytes(value: &mut Bytes, buf: &mut Bytes) -> Result<(), DecodeError> {
135 let len = decode_varint(buf)?;
136 if len > buf.remaining() as u64 {
137 return Err(DecodeError::new(format!(
138 "buffer underflow, len: {}, remaining: {}",
139 len,
140 buf.remaining()
141 )));
142 }
143 *value = copy_to_bytes(buf, len as usize);
144 Ok(())
145}
146
/// Scratch representation of one time series while it is being decoded.
///
/// Labels whose name matches one of the special label constants are not kept
/// in `labels`; their value is moved into the dedicated fields below (see
/// `merge_field`).
#[derive(Default, Debug)]
pub struct PromTimeSeries {
    /// Value of the label named `METRIC_NAME_LABEL_BYTES`; used as the table name.
    pub table_name: String,
    /// Value of the label named `DATABASE_LABEL_BYTES`, if the series had one.
    pub schema: Option<String>,
    /// Value of the label named `PHYSICAL_TABLE_LABEL_BYTES`, if the series had one.
    pub physical_table: Option<String>,

    /// Remaining (non-special) labels, still as raw bytes.
    pub labels: RepeatedField<PromLabel>,
    /// Samples decoded for this series.
    pub samples: RepeatedField<Sample>,
}
158
159impl Clear for PromTimeSeries {
160 fn clear(&mut self) {
161 self.table_name.clear();
162 self.labels.clear();
163 self.samples.clear();
164 }
165}
166
167impl PromTimeSeries {
168 pub fn merge_field(
169 &mut self,
170 tag: u32,
171 wire_type: WireType,
172 buf: &mut Bytes,
173 prom_validation_mode: PromValidationMode,
174 ) -> Result<(), DecodeError> {
175 const STRUCT_NAME: &str = "PromTimeSeries";
176 match tag {
177 1u32 => {
178 let label = self.labels.push_default();
180
181 let len = decode_varint(buf).map_err(|mut error| {
182 error.push(STRUCT_NAME, "labels");
183 error
184 })?;
185 let remaining = buf.remaining();
186 if len > remaining as u64 {
187 return Err(DecodeError::new("buffer underflow"));
188 }
189
190 let limit = remaining - len as usize;
191 while buf.remaining() > limit {
192 let (tag, wire_type) = decode_key(buf)?;
193 label.merge_field(tag, wire_type, buf)?;
194 }
195 if buf.remaining() != limit {
196 return Err(DecodeError::new("delimited length exceeded"));
197 }
198
199 match label.name.deref() {
200 METRIC_NAME_LABEL_BYTES => {
201 self.table_name = prom_validation_mode.decode_string(&label.value)?;
202 self.labels.truncate(self.labels.len() - 1); }
204 DATABASE_LABEL_BYTES => {
205 self.schema = Some(prom_validation_mode.decode_string(&label.value)?);
206 self.labels.truncate(self.labels.len() - 1); }
208 PHYSICAL_TABLE_LABEL_BYTES => {
209 self.physical_table =
210 Some(prom_validation_mode.decode_string(&label.value)?);
211 self.labels.truncate(self.labels.len() - 1); }
213 _ => {}
214 }
215
216 Ok(())
217 }
218 2u32 => {
219 let sample = self.samples.push_default();
220 merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
221 |mut error| {
222 error.push(STRUCT_NAME, "samples");
223 error
224 },
225 )?;
226 Ok(())
227 }
228 3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
230 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
231 }
232 }
233
234 fn add_to_table_data(
235 &mut self,
236 table_builders: &mut TablesBuilder,
237 prom_validation_mode: PromValidationMode,
238 ) -> Result<(), DecodeError> {
239 let label_num = self.labels.len();
240 let row_num = self.samples.len();
241
242 let prom_ctx = PromCtx {
243 schema: self.schema.take(),
244 physical_table: self.physical_table.take(),
245 };
246
247 let table_data = table_builders.get_or_create_table_builder(
248 prom_ctx,
249 std::mem::take(&mut self.table_name),
250 label_num,
251 row_num,
252 );
253 table_data.add_labels_and_samples(
254 self.labels.as_slice(),
255 self.samples.as_slice(),
256 prom_validation_mode,
257 )?;
258
259 Ok(())
260 }
261}
262
/// Decoder state for one remote-write request.
#[derive(Default, Debug)]
pub struct PromWriteRequest {
    /// Builder that accumulates decoded series when no pipeline is used.
    pub(crate) table_data: TablesBuilder,
    /// Scratch series reused for every time series decoded by `merge`.
    series: PromTimeSeries,
}
268
269impl Clear for PromWriteRequest {
270 fn clear(&mut self) {
271 self.table_data.clear();
272 }
273}
274
275impl PromWriteRequest {
276 pub fn as_row_insert_requests(&mut self) -> ContextReq {
277 self.table_data.as_insert_requests()
278 }
279
280 pub fn merge(
282 &mut self,
283 mut buf: Bytes,
284 prom_validation_mode: PromValidationMode,
285 processor: &mut PromSeriesProcessor,
286 ) -> Result<(), DecodeError> {
287 const STRUCT_NAME: &str = "PromWriteRequest";
288 while buf.has_remaining() {
289 let (tag, wire_type) = decode_key(&mut buf)?;
290 assert_eq!(WireType::LengthDelimited, wire_type);
291 match tag {
292 1u32 => {
293 let len = decode_varint(&mut buf).map_err(|mut e| {
295 e.push(STRUCT_NAME, "timeseries");
296 e
297 })?;
298 let remaining = buf.remaining();
299 if len > remaining as u64 {
300 return Err(DecodeError::new("buffer underflow"));
301 }
302
303 let limit = remaining - len as usize;
304 while buf.remaining() > limit {
305 let (tag, wire_type) = decode_key(&mut buf)?;
306 self.series
307 .merge_field(tag, wire_type, &mut buf, prom_validation_mode)?;
308 }
309 if buf.remaining() != limit {
310 return Err(DecodeError::new("delimited length exceeded"));
311 }
312
313 if processor.use_pipeline {
314 processor.consume_series_to_pipeline_map(
315 &mut self.series,
316 prom_validation_mode,
317 )?;
318 } else {
319 self.series
320 .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
321 }
322
323 self.series.labels.clear();
325 self.series.samples.clear();
326 }
327 3u32 => {
328 prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?;
330 }
331 _ => prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?,
332 }
333 }
334
335 Ok(())
336 }
337}
338
/// Collects decoded series for pipeline execution.
///
/// With `use_pipeline == false` (see `default_processor`) the processor is
/// not used by `PromWriteRequest::merge`; `set_pipeline` enables it and fills
/// in the three optional fields.
pub struct PromSeriesProcessor {
    /// Whether decoded series are routed to the pipeline instead of table builders.
    pub(crate) use_pipeline: bool,
    /// Pipeline input values accumulated so far, keyed by table name.
    pub(crate) table_values: BTreeMap<String, Vec<VrlValue>>,

    /// Handler passed to `run_pipeline`; set by `set_pipeline`.
    pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
    /// Query context passed to `run_pipeline`; set by `set_pipeline`.
    pub(crate) query_ctx: Option<QueryContextRef>,
    /// Definition of the pipeline to run; set by `set_pipeline`.
    pub(crate) pipeline_def: Option<PipelineDefinition>,
}
355
356impl PromSeriesProcessor {
357 pub fn default_processor() -> Self {
358 Self {
359 use_pipeline: false,
360 table_values: BTreeMap::new(),
361 pipeline_handler: None,
362 query_ctx: None,
363 pipeline_def: None,
364 }
365 }
366
367 pub fn set_pipeline(
368 &mut self,
369 handler: PipelineHandlerRef,
370 query_ctx: QueryContextRef,
371 pipeline_def: PipelineDefinition,
372 ) {
373 self.use_pipeline = true;
374 self.pipeline_handler = Some(handler);
375 self.query_ctx = Some(query_ctx);
376 self.pipeline_def = Some(pipeline_def);
377 }
378
379 pub(crate) fn consume_series_to_pipeline_map(
381 &mut self,
382 series: &mut PromTimeSeries,
383 prom_validation_mode: PromValidationMode,
384 ) -> Result<(), DecodeError> {
385 let mut vec_pipeline_map = Vec::new();
386 let mut pipeline_map = BTreeMap::new();
387 for l in series.labels.iter() {
388 let name = prom_validation_mode.decode_string(&l.name)?;
389 let value = prom_validation_mode.decode_string(&l.value)?;
390 pipeline_map.insert(KeyString::from(name), VrlValue::Bytes(value.into()));
391 }
392
393 let one_sample = series.samples.len() == 1;
394
395 for s in series.samples.iter() {
396 let Ok(value) = NotNan::new(s.value) else {
397 warn!("Invalid float value: {}", s.value);
398 continue;
399 };
400
401 let timestamp = s.timestamp;
402 pipeline_map.insert(
403 KeyString::from(GREPTIME_TIMESTAMP),
404 VrlValue::Integer(timestamp),
405 );
406 pipeline_map.insert(KeyString::from(GREPTIME_VALUE), VrlValue::Float(value));
407 if one_sample {
408 vec_pipeline_map.push(VrlValue::Object(pipeline_map));
409 break;
410 } else {
411 vec_pipeline_map.push(VrlValue::Object(pipeline_map.clone()));
412 }
413 }
414
415 let table_name = std::mem::take(&mut series.table_name);
416 match self.table_values.entry(table_name) {
417 Entry::Occupied(mut occupied_entry) => {
418 occupied_entry.get_mut().append(&mut vec_pipeline_map);
419 }
420 Entry::Vacant(vacant_entry) => {
421 vacant_entry.insert(vec_pipeline_map);
422 }
423 }
424
425 Ok(())
426 }
427
428 pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
429 let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
431 err_msg: "pipeline handler is not set",
432 })?;
433 let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
434 err_msg: "pipeline definition is not set",
435 })?;
436 let pipeline_param = GreptimePipelineParams::default();
437 let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
438 err_msg: "query context is not set",
439 })?;
440
441 let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
442
443 let mut req = ContextReq::default();
445 let table_values = std::mem::take(&mut self.table_values);
446 for (table_name, pipeline_maps) in table_values.into_iter() {
447 let pipeline_req = PipelineIngestRequest {
448 table: table_name,
449 values: pipeline_maps,
450 };
451 let row_req =
452 run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
453 req.merge(row_req);
454 }
455
456 Ok(req)
457 }
458}
459
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use crate::http::PromValidationMode;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::proto::{PromSeriesProcessor, PromWriteRequest};
    use crate::repeated_field::Clear;

    /// Reorders the columns of `rows` (schema and every row) by column name
    /// so that two `Rows` can be compared regardless of column order.
    fn sort_rows(rows: Rows) -> Rows {
        let permutation =
            permutation::sort_by_key(&rows.schema, |schema| schema.column_name.clone());
        let schema = permutation.apply_slice(&rows.schema);
        let mut inner_rows = vec![];
        for row in rows.rows {
            let values = permutation.apply_slice(&row.values);
            inner_rows.push(Row { values });
        }
        Rows {
            schema,
            rows: inner_rows,
        }
    }

    /// Decodes `data` with `prom_write_request` (after clearing it) and checks
    /// the sample count and per-table rows against the expected values.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &Bytes,
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut p = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .merge(data.clone(), PromValidationMode::Strict, &mut p)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();

        let samples = req
            .ref_all_req()
            .filter_map(|r| r.rows.as_ref().map(|r| r.rows.len()))
            .sum::<usize>();
        let prom_rows = RowInsertRequests {
            inserts: req.all_req().collect::<Vec<_>>(),
        };

        assert_eq!(expected_samples, samples);
        assert_eq!(expected_rows.inserts.len(), prom_rows.inserts.len());

        let expected_rows_map = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for r in &prom_rows.inserts {
            let expected_rows = expected_rows_map.get(&r.table_name).unwrap().clone();
            assert_eq!(sort_rows(expected_rows), sort_rows(r.rows.clone().unwrap()));
        }
    }

    // Decodes the same payload several times with one `PromWriteRequest` to
    // make sure the decoder can be reused after `clear`.
    #[test]
    fn test_decode_write_request() {
        let mut d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("benches");
        d.push("write_request.pb.data");
        let data = Bytes::from(std::fs::read(d).unwrap());

        let (expected_rows, expected_samples) =
            to_grpc_row_insert_requests(&WriteRequest::decode(data.clone()).unwrap()).unwrap();

        let mut prom_write_request = PromWriteRequest::default();
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Strict.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_strict_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Strict.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_strict_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = PromValidationMode::Strict.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_strict_mode_invalid_utf8() {
        // 0xFF, 0xFE and 0xFD can never appear in well-formed UTF-8.
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Strict.decode_string(&invalid_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_strict_mode_incomplete_utf8() {
        // 0xC2 starts a two-byte sequence whose continuation byte is missing.
        let incomplete_utf8 = Bytes::from(vec![0xC2]);
        let result = PromValidationMode::Strict.decode_string(&incomplete_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_lossy_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Lossy.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_lossy_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Lossy.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_lossy_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = PromValidationMode::Lossy.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_lossy_mode_invalid_utf8() {
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Lossy.decode_string(&invalid_utf8);
        assert!(result.is_ok());
        // Each invalid byte becomes one U+FFFD REPLACEMENT CHARACTER.
        assert_eq!(result.unwrap(), "\u{FFFD}\u{FFFD}\u{FFFD}");
    }

    #[test]
    fn test_decode_string_lossy_mode_mixed_valid_invalid() {
        let mut mixed = Vec::new();
        mixed.extend_from_slice(b"hello");
        mixed.push(0xFF); // invalid byte between two valid runs
        mixed.extend_from_slice(b"world");
        let mixed_utf8 = Bytes::from(mixed);

        let result = PromValidationMode::Lossy.decode_string(&mixed_utf8);
        assert!(result.is_ok());
        // The invalid byte is replaced by U+FFFD; the valid text is untouched.
        assert_eq!(result.unwrap(), "hello\u{FFFD}world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Unchecked.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Unchecked.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_unchecked_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = PromValidationMode::Unchecked.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_unchecked_mode_invalid_utf8() {
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Unchecked.decode_string(&invalid_utf8);
        // Only the `Ok` status is checked; the content of the resulting
        // string is deliberately not inspected.
        assert!(result.is_ok());
        let _string = result.unwrap();
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");

        let strict_result = PromValidationMode::Strict.decode_string(&ascii).unwrap();
        let lossy_result = PromValidationMode::Lossy.decode_string(&ascii).unwrap();
        let unchecked_result = PromValidationMode::Unchecked.decode_string(&ascii).unwrap();

        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
        assert_eq!(strict_result, lossy_result);
        assert_eq!(lossy_result, unchecked_result);
    }
}