1use std::collections::BTreeMap;
16use std::collections::btree_map::Entry;
17use std::ops::Deref;
18use std::slice;
19
20use api::prom_store::remote::Sample;
21use bytes::{Buf, Bytes};
22use common_query::prelude::{greptime_timestamp, greptime_value};
23use common_telemetry::warn;
24use pipeline::{ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition};
25use prost::DecodeError;
26use prost::encoding::message::merge;
27use prost::encoding::{WireType, decode_key, decode_varint};
28use session::context::QueryContextRef;
29use snafu::OptionExt;
30use vrl::prelude::NotNan;
31use vrl::value::{KeyString, Value as VrlValue};
32
33use crate::error::InternalSnafu;
34use crate::http::PromValidationMode;
35use crate::http::event::PipelineIngestRequest;
36use crate::pipeline::run_pipeline;
37use crate::prom_row_builder::{PromCtx, TablesBuilder};
38use crate::prom_store::{
39 DATABASE_LABEL_ALT_BYTES, DATABASE_LABEL_BYTES, METRIC_NAME_LABEL_BYTES,
40 PHYSICAL_TABLE_LABEL_ALT_BYTES, PHYSICAL_TABLE_LABEL_BYTES,
41};
42use crate::query_handler::PipelineHandlerRef;
43use crate::repeated_field::{Clear, RepeatedField};
44
45impl Clear for Sample {
46 fn clear(&mut self) {
47 self.timestamp = 0;
48 self.value = 0.0;
49 }
50}
51
/// One `name`/`value` label pair of a remote-write series, held as raw, undecoded bytes.
///
/// NOTE(review): these fields are filled by `merge_bytes`, which may produce a `Bytes`
/// that points into the request buffer without owning it (via `Bytes::from_static` in
/// `split_to`). Do not read them after that buffer has been dropped.
#[derive(Default, Clone, Debug)]
pub struct PromLabel {
    /// Label name bytes (not yet validated as UTF-8).
    pub name: Bytes,
    /// Label value bytes (not yet validated as UTF-8).
    pub value: Bytes,
}
57
58impl Clear for PromLabel {
59 fn clear(&mut self) {
60 self.name.clear();
61 self.value.clear();
62 }
63}
64
65impl PromLabel {
66 pub fn merge_field(
67 &mut self,
68 tag: u32,
69 wire_type: WireType,
70 buf: &mut Bytes,
71 ) -> Result<(), DecodeError> {
72 const STRUCT_NAME: &str = "PromLabel";
73 match tag {
74 1u32 => {
75 let value = &mut self.name;
77 merge_bytes(value, buf).map_err(|mut error| {
78 error.push(STRUCT_NAME, "name");
79 error
80 })
81 }
82 2u32 => {
83 let value = &mut self.value;
85 merge_bytes(value, buf).map_err(|mut error| {
86 error.push(STRUCT_NAME, "value");
87 error
88 })
89 }
90 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
91 }
92 }
93}
94
95#[inline(always)]
96fn copy_to_bytes(data: &mut Bytes, len: usize) -> Bytes {
97 if len == data.remaining() {
98 std::mem::replace(data, Bytes::new())
99 } else {
100 let ret = split_to(data, len);
101 data.advance(len);
102 ret
103 }
104}
105
/// Returns a `Bytes` covering the first `end` bytes of `buf`, without advancing `buf`
/// and without touching its reference count.
///
/// Unlike `Bytes::split_to`, `buf` itself is left unchanged; the caller is expected to
/// advance it separately (as `copy_to_bytes` does).
///
/// # Panics
///
/// Panics if `end > buf.len()`.
#[inline(always)]
fn split_to(buf: &mut Bytes, end: usize) -> Bytes {
    let len = buf.len();
    assert!(
        end <= len,
        "range end out of bounds: {:?} <= {:?}",
        end,
        len,
    );

    // An empty prefix needs no pointer into `buf`.
    if end == 0 {
        return Bytes::new();
    }

    let ptr = buf.as_ptr();
    // SAFETY: `ptr` is the start of `buf`'s initialised bytes and the assert above
    // guarantees `end <= buf.len()`, so the slice stays in bounds.
    let x = unsafe { slice::from_raw_parts(ptr, end) };
    // NOTE(review): `from_static` claims a `'static` lifetime that the memory does not
    // actually have. The returned `Bytes` neither owns nor reference-counts the
    // allocation behind `buf`, so it must not be read after that allocation is freed.
    // This trades soundness-by-construction for avoiding a refcount bump per label.
    // Every caller must uphold this contract.
    Bytes::from_static(x)
}
130
131#[inline(always)]
135fn merge_bytes(value: &mut Bytes, buf: &mut Bytes) -> Result<(), DecodeError> {
136 let len = decode_varint(buf)?;
137 if len > buf.remaining() as u64 {
138 return Err(DecodeError::new(format!(
139 "buffer underflow, len: {}, remaining: {}",
140 len,
141 buf.remaining()
142 )));
143 }
144 *value = copy_to_bytes(buf, len as usize);
145 Ok(())
146}
147
/// Scratch state for decoding a single remote-write `TimeSeries`.
///
/// Reserved labels are not kept in `labels`; `merge_field` routes them into the
/// dedicated fields below instead.
#[derive(Default, Debug)]
pub struct PromTimeSeries {
    /// Destination table name, taken from the metric-name label.
    pub table_name: String,
    /// Target schema, from the deprecated schema label or a database label.
    pub schema: Option<String>,
    /// Physical table, from a physical-table label.
    pub physical_table: Option<String>,

    /// All non-reserved labels of the series.
    pub labels: RepeatedField<PromLabel>,
    /// Samples of the series.
    pub samples: RepeatedField<Sample>,
}
159
160impl Clear for PromTimeSeries {
161 fn clear(&mut self) {
162 self.table_name.clear();
163 self.labels.clear();
164 self.samples.clear();
165 }
166}
167
168impl PromTimeSeries {
169 pub fn merge_field(
170 &mut self,
171 tag: u32,
172 wire_type: WireType,
173 buf: &mut Bytes,
174 prom_validation_mode: PromValidationMode,
175 ) -> Result<(), DecodeError> {
176 const STRUCT_NAME: &str = "PromTimeSeries";
177 match tag {
178 1u32 => {
179 let label = self.labels.push_default();
181
182 let len = decode_varint(buf).map_err(|mut error| {
183 error.push(STRUCT_NAME, "labels");
184 error
185 })?;
186 let remaining = buf.remaining();
187 if len > remaining as u64 {
188 return Err(DecodeError::new("buffer underflow"));
189 }
190
191 let limit = remaining - len as usize;
192 while buf.remaining() > limit {
193 let (tag, wire_type) = decode_key(buf)?;
194 label.merge_field(tag, wire_type, buf)?;
195 }
196 if buf.remaining() != limit {
197 return Err(DecodeError::new("delimited length exceeded"));
198 }
199
200 match label.name.deref() {
201 METRIC_NAME_LABEL_BYTES => {
202 self.table_name = prom_validation_mode.decode_string(&label.value)?;
203 self.labels.truncate(self.labels.len() - 1); }
205 #[allow(deprecated)]
206 crate::prom_store::SCHEMA_LABEL_BYTES => {
207 self.schema = Some(prom_validation_mode.decode_string(&label.value)?);
208 self.labels.truncate(self.labels.len() - 1); }
210 DATABASE_LABEL_BYTES | DATABASE_LABEL_ALT_BYTES => {
211 if self.schema.is_none() {
213 self.schema = Some(prom_validation_mode.decode_string(&label.value)?);
214 }
215 self.labels.truncate(self.labels.len() - 1); }
217 PHYSICAL_TABLE_LABEL_BYTES | PHYSICAL_TABLE_LABEL_ALT_BYTES => {
218 self.physical_table =
219 Some(prom_validation_mode.decode_string(&label.value)?);
220 self.labels.truncate(self.labels.len() - 1); }
222 _ => {}
223 }
224
225 Ok(())
226 }
227 2u32 => {
228 let sample = self.samples.push_default();
229 merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
230 |mut error| {
231 error.push(STRUCT_NAME, "samples");
232 error
233 },
234 )?;
235 Ok(())
236 }
237 3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
239 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
240 }
241 }
242
243 fn add_to_table_data(
244 &mut self,
245 table_builders: &mut TablesBuilder,
246 prom_validation_mode: PromValidationMode,
247 ) -> Result<(), DecodeError> {
248 let label_num = self.labels.len();
249 let row_num = self.samples.len();
250
251 let prom_ctx = PromCtx {
252 schema: self.schema.take(),
253 physical_table: self.physical_table.take(),
254 };
255
256 let table_data = table_builders.get_or_create_table_builder(
257 prom_ctx,
258 std::mem::take(&mut self.table_name),
259 label_num,
260 row_num,
261 );
262 table_data.add_labels_and_samples(
263 self.labels.as_slice(),
264 self.samples.as_slice(),
265 prom_validation_mode,
266 )?;
267
268 Ok(())
269 }
270}
271
/// Streaming decoder for a remote-write `WriteRequest`.
///
/// Series are decoded one at a time into `series`. Each one is then handed either to
/// `table_data` or to a `PromSeriesProcessor`, depending on whether a pipeline is in use.
#[derive(Default, Debug)]
pub struct PromWriteRequest {
    /// Per-table row builders filled on the non-pipeline path.
    pub(crate) table_data: TablesBuilder,
    /// Scratch space for the series currently being decoded.
    series: PromTimeSeries,
}
277
278impl Clear for PromWriteRequest {
279 fn clear(&mut self) {
280 self.table_data.clear();
281 }
282}
283
284impl PromWriteRequest {
285 pub fn as_row_insert_requests(&mut self) -> ContextReq {
286 self.table_data.as_insert_requests()
287 }
288
289 pub fn merge(
291 &mut self,
292 mut buf: Bytes,
293 prom_validation_mode: PromValidationMode,
294 processor: &mut PromSeriesProcessor,
295 ) -> Result<(), DecodeError> {
296 const STRUCT_NAME: &str = "PromWriteRequest";
297 while buf.has_remaining() {
298 let (tag, wire_type) = decode_key(&mut buf)?;
299 assert_eq!(WireType::LengthDelimited, wire_type);
300 match tag {
301 1u32 => {
302 let len = decode_varint(&mut buf).map_err(|mut e| {
304 e.push(STRUCT_NAME, "timeseries");
305 e
306 })?;
307 let remaining = buf.remaining();
308 if len > remaining as u64 {
309 return Err(DecodeError::new("buffer underflow"));
310 }
311
312 let limit = remaining - len as usize;
313 while buf.remaining() > limit {
314 let (tag, wire_type) = decode_key(&mut buf)?;
315 self.series
316 .merge_field(tag, wire_type, &mut buf, prom_validation_mode)?;
317 }
318 if buf.remaining() != limit {
319 return Err(DecodeError::new("delimited length exceeded"));
320 }
321
322 if processor.use_pipeline {
323 processor.consume_series_to_pipeline_map(
324 &mut self.series,
325 prom_validation_mode,
326 )?;
327 } else {
328 self.series
329 .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
330 }
331
332 self.series.labels.clear();
334 self.series.samples.clear();
335 }
336 3u32 => {
337 prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?;
339 }
340 _ => prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?,
341 }
342 }
343
344 Ok(())
345 }
346}
347
/// Decides where decoded series go: straight to the table builders, or (when a pipeline
/// is configured) into per-table VRL objects that `exec_pipeline` later runs through
/// the pipeline.
pub struct PromSeriesProcessor {
    /// When `true`, series are buffered in `table_values` instead of written directly.
    pub(crate) use_pipeline: bool,
    /// Buffered pipeline input: one VRL object per non-NaN sample, keyed by table name.
    pub(crate) table_values: BTreeMap<String, Vec<VrlValue>>,

    /// Pipeline handler. It is required by `exec_pipeline` and set by `set_pipeline`.
    pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
    /// Query context passed to the pipeline. It is required by `exec_pipeline`.
    pub(crate) query_ctx: Option<QueryContextRef>,
    /// Pipeline definition to execute. It is required by `exec_pipeline`.
    pub(crate) pipeline_def: Option<PipelineDefinition>,
}
364
365impl PromSeriesProcessor {
366 pub fn default_processor() -> Self {
367 Self {
368 use_pipeline: false,
369 table_values: BTreeMap::new(),
370 pipeline_handler: None,
371 query_ctx: None,
372 pipeline_def: None,
373 }
374 }
375
376 pub fn set_pipeline(
377 &mut self,
378 handler: PipelineHandlerRef,
379 query_ctx: QueryContextRef,
380 pipeline_def: PipelineDefinition,
381 ) {
382 self.use_pipeline = true;
383 self.pipeline_handler = Some(handler);
384 self.query_ctx = Some(query_ctx);
385 self.pipeline_def = Some(pipeline_def);
386 }
387
388 pub(crate) fn consume_series_to_pipeline_map(
390 &mut self,
391 series: &mut PromTimeSeries,
392 prom_validation_mode: PromValidationMode,
393 ) -> Result<(), DecodeError> {
394 let mut vec_pipeline_map = Vec::new();
395 let mut pipeline_map = BTreeMap::new();
396 for l in series.labels.iter() {
397 let name = prom_validation_mode.decode_string(&l.name)?;
398 let value = prom_validation_mode.decode_string(&l.value)?;
399 pipeline_map.insert(KeyString::from(name), VrlValue::Bytes(value.into()));
400 }
401
402 let one_sample = series.samples.len() == 1;
403
404 for s in series.samples.iter() {
405 let Ok(value) = NotNan::new(s.value) else {
406 warn!("Invalid float value: {}", s.value);
407 continue;
408 };
409
410 let timestamp = s.timestamp;
411 pipeline_map.insert(
412 KeyString::from(greptime_timestamp()),
413 VrlValue::Integer(timestamp),
414 );
415 pipeline_map.insert(KeyString::from(greptime_value()), VrlValue::Float(value));
416 if one_sample {
417 vec_pipeline_map.push(VrlValue::Object(pipeline_map));
418 break;
419 } else {
420 vec_pipeline_map.push(VrlValue::Object(pipeline_map.clone()));
421 }
422 }
423
424 let table_name = std::mem::take(&mut series.table_name);
425 match self.table_values.entry(table_name) {
426 Entry::Occupied(mut occupied_entry) => {
427 occupied_entry.get_mut().append(&mut vec_pipeline_map);
428 }
429 Entry::Vacant(vacant_entry) => {
430 vacant_entry.insert(vec_pipeline_map);
431 }
432 }
433
434 Ok(())
435 }
436
437 pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
438 let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
440 err_msg: "pipeline handler is not set",
441 })?;
442 let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
443 err_msg: "pipeline definition is not set",
444 })?;
445 let pipeline_param = GreptimePipelineParams::default();
446 let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
447 err_msg: "query context is not set",
448 })?;
449
450 let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
451
452 let mut req = ContextReq::default();
454 let table_values = std::mem::take(&mut self.table_values);
455 for (table_name, pipeline_maps) in table_values.into_iter() {
456 let pipeline_req = PipelineIngestRequest {
457 table: table_name,
458 values: pipeline_maps,
459 };
460 let row_req =
461 run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
462 req.merge(row_req);
463 }
464
465 Ok(req)
466 }
467}
468
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use crate::http::PromValidationMode;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::proto::{PromSeriesProcessor, PromWriteRequest};
    use crate::repeated_field::Clear;

    /// Mixed-width UTF-8 input: ASCII, two CJK ideographs (3 bytes each) and an emoji
    /// (4 bytes). Written with escapes so the source file's encoding cannot corrupt it.
    const UNICODE_TEXT: &str = "Hello \u{4e16}\u{754c} \u{1f30d}";

    /// Sorts the columns of `rows` by name and applies the same permutation to every
    /// row, so two `Rows` can be compared independently of column order.
    fn sort_rows(rows: Rows) -> Rows {
        let permutation =
            permutation::sort_by_key(&rows.schema, |schema| schema.column_name.clone());
        let schema = permutation.apply_slice(&rows.schema);
        let mut inner_rows = vec![];
        for row in rows.rows {
            let values = permutation.apply_slice(&row.values);
            inner_rows.push(Row { values });
        }
        Rows {
            schema,
            rows: inner_rows,
        }
    }

    /// Decodes `data` with the hand-written decoder and checks the result against the
    /// rows produced by the prost-generated `WriteRequest` path.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &Bytes,
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut p = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .merge(data.clone(), PromValidationMode::Strict, &mut p)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();

        let samples = req
            .ref_all_req()
            .filter_map(|r| r.rows.as_ref().map(|r| r.rows.len()))
            .sum::<usize>();
        let prom_rows = RowInsertRequests {
            inserts: req.all_req().collect::<Vec<_>>(),
        };

        assert_eq!(expected_samples, samples);
        assert_eq!(expected_rows.inserts.len(), prom_rows.inserts.len());

        let expected_rows_map = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for r in &prom_rows.inserts {
            // Column order may differ between the two decoders; compare sorted.
            let expected_rows = expected_rows_map.get(&r.table_name).unwrap().clone();
            assert_eq!(sort_rows(expected_rows), sort_rows(r.rows.clone().unwrap()));
        }
    }

    #[test]
    fn test_decode_write_request() {
        let mut d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("benches");
        d.push("write_request.pb.data");
        let data = Bytes::from(std::fs::read(d).unwrap());

        let (expected_rows, expected_samples) =
            to_grpc_row_insert_requests(&WriteRequest::decode(data.clone()).unwrap()).unwrap();

        // Decode several times with the same instance to exercise reuse after `clear`.
        let mut prom_write_request = PromWriteRequest::default();
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Strict.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_strict_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Strict.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_strict_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Strict.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_strict_mode_invalid_utf8() {
        // 0xFF, 0xFE and 0xFD can never appear in well-formed UTF-8.
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Strict.decode_string(&invalid_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_strict_mode_incomplete_utf8() {
        // 0xC2 starts a two-byte sequence, but no continuation byte follows.
        let incomplete_utf8 = Bytes::from(vec![0xC2]);
        let result = PromValidationMode::Strict.decode_string(&incomplete_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_lossy_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Lossy.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_lossy_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Lossy.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_lossy_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Lossy.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_lossy_mode_invalid_utf8() {
        // Each invalid byte becomes one U+FFFD REPLACEMENT CHARACTER.
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Lossy.decode_string(&invalid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "\u{FFFD}\u{FFFD}\u{FFFD}");
    }

    #[test]
    fn test_decode_string_lossy_mode_mixed_valid_invalid() {
        let mut mixed = Vec::new();
        mixed.extend_from_slice(b"hello");
        // A lone 0xFF is never valid UTF-8.
        mixed.push(0xFF);
        mixed.extend_from_slice(b"world");
        let mixed_utf8 = Bytes::from(mixed);

        let result = PromValidationMode::Lossy.decode_string(&mixed_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello\u{FFFD}world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Unchecked.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Unchecked.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_unchecked_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Unchecked.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_unchecked_mode_invalid_utf8() {
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Unchecked.decode_string(&invalid_utf8);
        assert!(result.is_ok());
        // NOTE(review): unchecked mode presumably skips UTF-8 validation, so the String
        // may hold invalid UTF-8. Only check that decoding succeeds and never inspect
        // the contents.
        let _string = result.unwrap();
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");

        let strict_result = PromValidationMode::Strict.decode_string(&ascii).unwrap();
        let lossy_result = PromValidationMode::Lossy.decode_string(&ascii).unwrap();
        let unchecked_result = PromValidationMode::Unchecked.decode_string(&ascii).unwrap();

        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
        assert_eq!(strict_result, lossy_result);
        assert_eq!(lossy_result, unchecked_result);
    }
}