1use std::collections::btree_map::Entry;
16use std::collections::BTreeMap;
17use std::ops::Deref;
18use std::slice;
19
20use api::prom_store::remote::Sample;
21use bytes::{Buf, Bytes};
22use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE};
23use common_telemetry::warn;
24use pipeline::{ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition};
25use prost::encoding::message::merge;
26use prost::encoding::{decode_key, decode_varint, WireType};
27use prost::DecodeError;
28use session::context::QueryContextRef;
29use snafu::OptionExt;
30use vrl::prelude::NotNan;
31use vrl::value::{KeyString, Value as VrlValue};
32
33use crate::error::InternalSnafu;
34use crate::http::event::PipelineIngestRequest;
35use crate::http::PromValidationMode;
36use crate::pipeline::run_pipeline;
37use crate::prom_row_builder::{PromCtx, TablesBuilder};
38use crate::prom_store::{
39 DATABASE_LABEL_BYTES, METRIC_NAME_LABEL_BYTES, PHYSICAL_TABLE_LABEL_BYTES, SCHEMA_LABEL_BYTES,
40};
41use crate::query_handler::PipelineHandlerRef;
42use crate::repeated_field::{Clear, RepeatedField};
43
44impl Clear for Sample {
45 fn clear(&mut self) {
46 self.timestamp = 0;
47 self.value = 0.0;
48 }
49}
50
51#[derive(Default, Clone, Debug)]
52pub struct PromLabel {
53 pub name: Bytes,
54 pub value: Bytes,
55}
56
57impl Clear for PromLabel {
58 fn clear(&mut self) {
59 self.name.clear();
60 self.value.clear();
61 }
62}
63
64impl PromLabel {
65 pub fn merge_field(
66 &mut self,
67 tag: u32,
68 wire_type: WireType,
69 buf: &mut Bytes,
70 ) -> Result<(), DecodeError> {
71 const STRUCT_NAME: &str = "PromLabel";
72 match tag {
73 1u32 => {
74 let value = &mut self.name;
76 merge_bytes(value, buf).map_err(|mut error| {
77 error.push(STRUCT_NAME, "name");
78 error
79 })
80 }
81 2u32 => {
82 let value = &mut self.value;
84 merge_bytes(value, buf).map_err(|mut error| {
85 error.push(STRUCT_NAME, "value");
86 error
87 })
88 }
89 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
90 }
91 }
92}
93
94#[inline(always)]
95fn copy_to_bytes(data: &mut Bytes, len: usize) -> Bytes {
96 if len == data.remaining() {
97 std::mem::replace(data, Bytes::new())
98 } else {
99 let ret = split_to(data, len);
100 data.advance(len);
101 ret
102 }
103}
104
105#[inline(always)]
111fn split_to(buf: &mut Bytes, end: usize) -> Bytes {
112 let len = buf.len();
113 assert!(
114 end <= len,
115 "range end out of bounds: {:?} <= {:?}",
116 end,
117 len,
118 );
119
120 if end == 0 {
121 return Bytes::new();
122 }
123
124 let ptr = buf.as_ptr();
125 let x = unsafe { slice::from_raw_parts(ptr, end) };
126 Bytes::from_static(x)
128}
129
130#[inline(always)]
134fn merge_bytes(value: &mut Bytes, buf: &mut Bytes) -> Result<(), DecodeError> {
135 let len = decode_varint(buf)?;
136 if len > buf.remaining() as u64 {
137 return Err(DecodeError::new(format!(
138 "buffer underflow, len: {}, remaining: {}",
139 len,
140 buf.remaining()
141 )));
142 }
143 *value = copy_to_bytes(buf, len as usize);
144 Ok(())
145}
146
147#[derive(Default, Debug)]
148pub struct PromTimeSeries {
149 pub table_name: String,
150 pub schema: Option<String>,
152 pub physical_table: Option<String>,
154
155 pub labels: RepeatedField<PromLabel>,
156 pub samples: RepeatedField<Sample>,
157}
158
159impl Clear for PromTimeSeries {
160 fn clear(&mut self) {
161 self.table_name.clear();
162 self.labels.clear();
163 self.samples.clear();
164 }
165}
166
167impl PromTimeSeries {
168 pub fn merge_field(
169 &mut self,
170 tag: u32,
171 wire_type: WireType,
172 buf: &mut Bytes,
173 prom_validation_mode: PromValidationMode,
174 ) -> Result<(), DecodeError> {
175 const STRUCT_NAME: &str = "PromTimeSeries";
176 match tag {
177 1u32 => {
178 let label = self.labels.push_default();
180
181 let len = decode_varint(buf).map_err(|mut error| {
182 error.push(STRUCT_NAME, "labels");
183 error
184 })?;
185 let remaining = buf.remaining();
186 if len > remaining as u64 {
187 return Err(DecodeError::new("buffer underflow"));
188 }
189
190 let limit = remaining - len as usize;
191 while buf.remaining() > limit {
192 let (tag, wire_type) = decode_key(buf)?;
193 label.merge_field(tag, wire_type, buf)?;
194 }
195 if buf.remaining() != limit {
196 return Err(DecodeError::new("delimited length exceeded"));
197 }
198
199 match label.name.deref() {
200 METRIC_NAME_LABEL_BYTES => {
201 self.table_name = prom_validation_mode.decode_string(&label.value)?;
202 self.labels.truncate(self.labels.len() - 1); }
204 SCHEMA_LABEL_BYTES => {
205 self.schema = Some(prom_validation_mode.decode_string(&label.value)?);
206 self.labels.truncate(self.labels.len() - 1); }
208 DATABASE_LABEL_BYTES => {
209 if self.schema.is_none() {
211 self.schema = Some(prom_validation_mode.decode_string(&label.value)?);
212 }
213 self.labels.truncate(self.labels.len() - 1); }
215 PHYSICAL_TABLE_LABEL_BYTES => {
216 self.physical_table =
217 Some(prom_validation_mode.decode_string(&label.value)?);
218 self.labels.truncate(self.labels.len() - 1); }
220 _ => {}
221 }
222
223 Ok(())
224 }
225 2u32 => {
226 let sample = self.samples.push_default();
227 merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
228 |mut error| {
229 error.push(STRUCT_NAME, "samples");
230 error
231 },
232 )?;
233 Ok(())
234 }
235 3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
237 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
238 }
239 }
240
241 fn add_to_table_data(
242 &mut self,
243 table_builders: &mut TablesBuilder,
244 prom_validation_mode: PromValidationMode,
245 ) -> Result<(), DecodeError> {
246 let label_num = self.labels.len();
247 let row_num = self.samples.len();
248
249 let prom_ctx = PromCtx {
250 schema: self.schema.take(),
251 physical_table: self.physical_table.take(),
252 };
253
254 let table_data = table_builders.get_or_create_table_builder(
255 prom_ctx,
256 std::mem::take(&mut self.table_name),
257 label_num,
258 row_num,
259 );
260 table_data.add_labels_and_samples(
261 self.labels.as_slice(),
262 self.samples.as_slice(),
263 prom_validation_mode,
264 )?;
265
266 Ok(())
267 }
268}
269
270#[derive(Default, Debug)]
271pub struct PromWriteRequest {
272 pub(crate) table_data: TablesBuilder,
273 series: PromTimeSeries,
274}
275
276impl Clear for PromWriteRequest {
277 fn clear(&mut self) {
278 self.table_data.clear();
279 }
280}
281
282impl PromWriteRequest {
283 pub fn as_row_insert_requests(&mut self) -> ContextReq {
284 self.table_data.as_insert_requests()
285 }
286
287 pub fn merge(
289 &mut self,
290 mut buf: Bytes,
291 prom_validation_mode: PromValidationMode,
292 processor: &mut PromSeriesProcessor,
293 ) -> Result<(), DecodeError> {
294 const STRUCT_NAME: &str = "PromWriteRequest";
295 while buf.has_remaining() {
296 let (tag, wire_type) = decode_key(&mut buf)?;
297 assert_eq!(WireType::LengthDelimited, wire_type);
298 match tag {
299 1u32 => {
300 let len = decode_varint(&mut buf).map_err(|mut e| {
302 e.push(STRUCT_NAME, "timeseries");
303 e
304 })?;
305 let remaining = buf.remaining();
306 if len > remaining as u64 {
307 return Err(DecodeError::new("buffer underflow"));
308 }
309
310 let limit = remaining - len as usize;
311 while buf.remaining() > limit {
312 let (tag, wire_type) = decode_key(&mut buf)?;
313 self.series
314 .merge_field(tag, wire_type, &mut buf, prom_validation_mode)?;
315 }
316 if buf.remaining() != limit {
317 return Err(DecodeError::new("delimited length exceeded"));
318 }
319
320 if processor.use_pipeline {
321 processor.consume_series_to_pipeline_map(
322 &mut self.series,
323 prom_validation_mode,
324 )?;
325 } else {
326 self.series
327 .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
328 }
329
330 self.series.labels.clear();
332 self.series.samples.clear();
333 }
334 3u32 => {
335 prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?;
337 }
338 _ => prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?,
339 }
340 }
341
342 Ok(())
343 }
344}
345
346pub struct PromSeriesProcessor {
354 pub(crate) use_pipeline: bool,
355 pub(crate) table_values: BTreeMap<String, Vec<VrlValue>>,
356
357 pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
359 pub(crate) query_ctx: Option<QueryContextRef>,
360 pub(crate) pipeline_def: Option<PipelineDefinition>,
361}
362
363impl PromSeriesProcessor {
364 pub fn default_processor() -> Self {
365 Self {
366 use_pipeline: false,
367 table_values: BTreeMap::new(),
368 pipeline_handler: None,
369 query_ctx: None,
370 pipeline_def: None,
371 }
372 }
373
374 pub fn set_pipeline(
375 &mut self,
376 handler: PipelineHandlerRef,
377 query_ctx: QueryContextRef,
378 pipeline_def: PipelineDefinition,
379 ) {
380 self.use_pipeline = true;
381 self.pipeline_handler = Some(handler);
382 self.query_ctx = Some(query_ctx);
383 self.pipeline_def = Some(pipeline_def);
384 }
385
386 pub(crate) fn consume_series_to_pipeline_map(
388 &mut self,
389 series: &mut PromTimeSeries,
390 prom_validation_mode: PromValidationMode,
391 ) -> Result<(), DecodeError> {
392 let mut vec_pipeline_map = Vec::new();
393 let mut pipeline_map = BTreeMap::new();
394 for l in series.labels.iter() {
395 let name = prom_validation_mode.decode_string(&l.name)?;
396 let value = prom_validation_mode.decode_string(&l.value)?;
397 pipeline_map.insert(KeyString::from(name), VrlValue::Bytes(value.into()));
398 }
399
400 let one_sample = series.samples.len() == 1;
401
402 for s in series.samples.iter() {
403 let Ok(value) = NotNan::new(s.value) else {
404 warn!("Invalid float value: {}", s.value);
405 continue;
406 };
407
408 let timestamp = s.timestamp;
409 pipeline_map.insert(
410 KeyString::from(GREPTIME_TIMESTAMP),
411 VrlValue::Integer(timestamp),
412 );
413 pipeline_map.insert(KeyString::from(GREPTIME_VALUE), VrlValue::Float(value));
414 if one_sample {
415 vec_pipeline_map.push(VrlValue::Object(pipeline_map));
416 break;
417 } else {
418 vec_pipeline_map.push(VrlValue::Object(pipeline_map.clone()));
419 }
420 }
421
422 let table_name = std::mem::take(&mut series.table_name);
423 match self.table_values.entry(table_name) {
424 Entry::Occupied(mut occupied_entry) => {
425 occupied_entry.get_mut().append(&mut vec_pipeline_map);
426 }
427 Entry::Vacant(vacant_entry) => {
428 vacant_entry.insert(vec_pipeline_map);
429 }
430 }
431
432 Ok(())
433 }
434
435 pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
436 let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
438 err_msg: "pipeline handler is not set",
439 })?;
440 let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
441 err_msg: "pipeline definition is not set",
442 })?;
443 let pipeline_param = GreptimePipelineParams::default();
444 let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
445 err_msg: "query context is not set",
446 })?;
447
448 let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
449
450 let mut req = ContextReq::default();
452 let table_values = std::mem::take(&mut self.table_values);
453 for (table_name, pipeline_maps) in table_values.into_iter() {
454 let pipeline_req = PipelineIngestRequest {
455 table: table_name,
456 values: pipeline_maps,
457 };
458 let row_req =
459 run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
460 req.merge(row_req);
461 }
462
463 Ok(req)
464 }
465}
466
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use crate::http::PromValidationMode;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::proto::{PromSeriesProcessor, PromWriteRequest};
    use crate::repeated_field::Clear;

    /// "Hello", two CJK characters and an emoji, exercising 3- and 4-byte UTF-8.
    /// Written with escapes so the test does not depend on the source encoding.
    const UNICODE_TEXT: &str = "Hello \u{4e16}\u{754c} \u{1f30d}";

    /// Sorts the columns of `rows` by name and permutes every row's values to
    /// match, so row sets can be compared regardless of column order.
    fn sort_rows(rows: Rows) -> Rows {
        let permutation =
            permutation::sort_by_key(&rows.schema, |schema| schema.column_name.clone());
        let schema = permutation.apply_slice(&rows.schema);
        let mut inner_rows = vec![];
        for row in rows.rows {
            let values = permutation.apply_slice(&row.values);
            inner_rows.push(Row { values });
        }
        Rows {
            schema,
            rows: inner_rows,
        }
    }

    /// Decodes `data` into `prom_write_request` and compares the resulting
    /// sample count and per-table rows with the expected values.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &Bytes,
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut p = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .merge(data.clone(), PromValidationMode::Strict, &mut p)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();

        let samples = req
            .ref_all_req()
            .filter_map(|r| r.rows.as_ref().map(|r| r.rows.len()))
            .sum::<usize>();
        let prom_rows = RowInsertRequests {
            inserts: req.all_req().collect::<Vec<_>>(),
        };

        assert_eq!(expected_samples, samples);
        assert_eq!(expected_rows.inserts.len(), prom_rows.inserts.len());

        let expected_rows_map = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for r in &prom_rows.inserts {
            let expected_rows = expected_rows_map.get(&r.table_name).unwrap().clone();
            assert_eq!(sort_rows(expected_rows), sort_rows(r.rows.clone().unwrap()));
        }
    }

    #[test]
    fn test_decode_write_request() {
        let mut d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("benches");
        d.push("write_request.pb.data");
        let data = Bytes::from(std::fs::read(d).unwrap());

        let (expected_rows, expected_samples) =
            to_grpc_row_insert_requests(&WriteRequest::decode(data.clone()).unwrap()).unwrap();

        // Decode several times into the same request to exercise reuse.
        let mut prom_write_request = PromWriteRequest::default();
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Strict.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_strict_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Strict.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_strict_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Strict.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_strict_mode_invalid_utf8() {
        // None of these bytes can start a UTF-8 sequence.
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Strict.decode_string(&invalid_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_strict_mode_incomplete_utf8() {
        // 0xC2 starts a two-byte sequence that is never completed.
        let incomplete_utf8 = Bytes::from(vec![0xC2]);
        let result = PromValidationMode::Strict.decode_string(&incomplete_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_lossy_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Lossy.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_lossy_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Lossy.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_lossy_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Lossy.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_lossy_mode_invalid_utf8() {
        // Each invalid byte is replaced by U+FFFD REPLACEMENT CHARACTER.
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Lossy.decode_string(&invalid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "\u{FFFD}\u{FFFD}\u{FFFD}");
    }

    #[test]
    fn test_decode_string_lossy_mode_mixed_valid_invalid() {
        // Valid ASCII around a single invalid byte.
        let mut mixed = Vec::new();
        mixed.extend_from_slice(b"hello");
        mixed.push(0xFF);
        mixed.extend_from_slice(b"world");
        let mixed_utf8 = Bytes::from(mixed);

        let result = PromValidationMode::Lossy.decode_string(&mixed_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello\u{FFFD}world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Unchecked.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Unchecked.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_unchecked_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Unchecked.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_unchecked_mode_invalid_utf8() {
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Unchecked.decode_string(&invalid_utf8);
        assert!(result.is_ok());
        // The returned String may not be valid UTF-8 in this mode, so only check
        // that decoding succeeds and do not inspect its contents.
        let _string = result.unwrap();
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");

        let strict_result = PromValidationMode::Strict.decode_string(&ascii).unwrap();
        let lossy_result = PromValidationMode::Lossy.decode_string(&ascii).unwrap();
        let unchecked_result = PromValidationMode::Unchecked.decode_string(&ascii).unwrap();

        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
        assert_eq!(strict_result, lossy_result);
        assert_eq!(lossy_result, unchecked_result);
    }
}