1use std::collections::btree_map::Entry;
16use std::collections::BTreeMap;
17use std::ops::Deref;
18use std::slice;
19
20use api::prom_store::remote::Sample;
21use bytes::{Buf, Bytes};
22use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE};
23use pipeline::{ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition, Value};
24use prost::encoding::message::merge;
25use prost::encoding::{decode_key, decode_varint, WireType};
26use prost::DecodeError;
27use session::context::QueryContextRef;
28use snafu::OptionExt;
29
30use crate::error::InternalSnafu;
31use crate::http::event::PipelineIngestRequest;
32use crate::http::PromValidationMode;
33use crate::pipeline::run_pipeline;
34use crate::prom_row_builder::{PromCtx, TablesBuilder};
35use crate::prom_store::{
36 DATABASE_LABEL_BYTES, METRIC_NAME_LABEL_BYTES, PHYSICAL_TABLE_LABEL_BYTES,
37};
38use crate::query_handler::PipelineHandlerRef;
39use crate::repeated_field::{Clear, RepeatedField};
40
41impl Clear for Sample {
42 fn clear(&mut self) {
43 self.timestamp = 0;
44 self.value = 0.0;
45 }
46}
47
48#[derive(Default, Clone, Debug)]
49pub struct PromLabel {
50 pub name: Bytes,
51 pub value: Bytes,
52}
53
54impl Clear for PromLabel {
55 fn clear(&mut self) {
56 self.name.clear();
57 self.value.clear();
58 }
59}
60
61impl PromLabel {
62 pub fn merge_field(
63 &mut self,
64 tag: u32,
65 wire_type: WireType,
66 buf: &mut Bytes,
67 ) -> Result<(), DecodeError> {
68 const STRUCT_NAME: &str = "PromLabel";
69 match tag {
70 1u32 => {
71 let value = &mut self.name;
73 merge_bytes(value, buf).map_err(|mut error| {
74 error.push(STRUCT_NAME, "name");
75 error
76 })
77 }
78 2u32 => {
79 let value = &mut self.value;
81 merge_bytes(value, buf).map_err(|mut error| {
82 error.push(STRUCT_NAME, "value");
83 error
84 })
85 }
86 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
87 }
88 }
89}
90
91#[inline(always)]
92fn copy_to_bytes(data: &mut Bytes, len: usize) -> Bytes {
93 if len == data.remaining() {
94 std::mem::replace(data, Bytes::new())
95 } else {
96 let ret = split_to(data, len);
97 data.advance(len);
98 ret
99 }
100}
101
/// Returns the first `end` bytes of `buf` as a `Bytes`, *without* advancing `buf` and
/// without taking a reference count on the underlying storage.
///
/// This is a zero-copy shortcut. The returned value is built with `Bytes::from_static` over a
/// slice whose `'static` lifetime is fabricated.
///
/// # Soundness (not enforced by the compiler)
///
/// The returned `Bytes` points into the memory backing `buf`. It must not be read after that
/// memory has been freed: callers must drop, clear or overwrite every value produced here
/// before the buffer they decoded from is released.
///
/// # Panics
///
/// Panics if `end > buf.len()`.
#[inline(always)]
fn split_to(buf: &mut Bytes, end: usize) -> Bytes {
    let len = buf.len();
    assert!(
        end <= len,
        "range end out of bounds: {:?} <= {:?}",
        end,
        len,
    );

    if end == 0 {
        return Bytes::new();
    }

    let ptr = buf.as_ptr();
    // SAFETY: `end <= buf.len()` is asserted above, so `ptr..ptr + end` lies within `buf`'s
    // initialized bytes. The `'static` lifetime is NOT real; the caller must uphold the
    // invariant described in the function docs.
    let x = unsafe { slice::from_raw_parts(ptr, end) };
    Bytes::from_static(x)
}
126
127#[inline(always)]
131fn merge_bytes(value: &mut Bytes, buf: &mut Bytes) -> Result<(), DecodeError> {
132 let len = decode_varint(buf)?;
133 if len > buf.remaining() as u64 {
134 return Err(DecodeError::new(format!(
135 "buffer underflow, len: {}, remaining: {}",
136 len,
137 buf.remaining()
138 )));
139 }
140 *value = copy_to_bytes(buf, len as usize);
141 Ok(())
142}
143
144#[derive(Default, Debug)]
145pub struct PromTimeSeries {
146 pub table_name: String,
147 pub schema: Option<String>,
149 pub physical_table: Option<String>,
151
152 pub labels: RepeatedField<PromLabel>,
153 pub samples: RepeatedField<Sample>,
154}
155
156impl Clear for PromTimeSeries {
157 fn clear(&mut self) {
158 self.table_name.clear();
159 self.labels.clear();
160 self.samples.clear();
161 }
162}
163
164impl PromTimeSeries {
165 pub fn merge_field(
166 &mut self,
167 tag: u32,
168 wire_type: WireType,
169 buf: &mut Bytes,
170 prom_validation_mode: PromValidationMode,
171 ) -> Result<(), DecodeError> {
172 const STRUCT_NAME: &str = "PromTimeSeries";
173 match tag {
174 1u32 => {
175 let label = self.labels.push_default();
177
178 let len = decode_varint(buf).map_err(|mut error| {
179 error.push(STRUCT_NAME, "labels");
180 error
181 })?;
182 let remaining = buf.remaining();
183 if len > remaining as u64 {
184 return Err(DecodeError::new("buffer underflow"));
185 }
186
187 let limit = remaining - len as usize;
188 while buf.remaining() > limit {
189 let (tag, wire_type) = decode_key(buf)?;
190 label.merge_field(tag, wire_type, buf)?;
191 }
192 if buf.remaining() != limit {
193 return Err(DecodeError::new("delimited length exceeded"));
194 }
195
196 match label.name.deref() {
197 METRIC_NAME_LABEL_BYTES => {
198 self.table_name = prom_validation_mode.decode_string(&label.value)?;
199 self.labels.truncate(self.labels.len() - 1); }
201 DATABASE_LABEL_BYTES => {
202 self.schema = Some(prom_validation_mode.decode_string(&label.value)?);
203 self.labels.truncate(self.labels.len() - 1); }
205 PHYSICAL_TABLE_LABEL_BYTES => {
206 self.physical_table =
207 Some(prom_validation_mode.decode_string(&label.value)?);
208 self.labels.truncate(self.labels.len() - 1); }
210 _ => {}
211 }
212
213 Ok(())
214 }
215 2u32 => {
216 let sample = self.samples.push_default();
217 merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
218 |mut error| {
219 error.push(STRUCT_NAME, "samples");
220 error
221 },
222 )?;
223 Ok(())
224 }
225 3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
227 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
228 }
229 }
230
231 fn add_to_table_data(
232 &mut self,
233 table_builders: &mut TablesBuilder,
234 prom_validation_mode: PromValidationMode,
235 ) -> Result<(), DecodeError> {
236 let label_num = self.labels.len();
237 let row_num = self.samples.len();
238
239 let prom_ctx = PromCtx {
240 schema: self.schema.take(),
241 physical_table: self.physical_table.take(),
242 };
243
244 let table_data = table_builders.get_or_create_table_builder(
245 prom_ctx,
246 std::mem::take(&mut self.table_name),
247 label_num,
248 row_num,
249 );
250 table_data.add_labels_and_samples(
251 self.labels.as_slice(),
252 self.samples.as_slice(),
253 prom_validation_mode,
254 )?;
255
256 Ok(())
257 }
258}
259
260#[derive(Default, Debug)]
261pub struct PromWriteRequest {
262 table_data: TablesBuilder,
263 series: PromTimeSeries,
264}
265
266impl Clear for PromWriteRequest {
267 fn clear(&mut self) {
268 self.table_data.clear();
269 }
270}
271
272impl PromWriteRequest {
273 pub fn as_row_insert_requests(&mut self) -> ContextReq {
274 self.table_data.as_insert_requests()
275 }
276
277 pub fn merge(
279 &mut self,
280 mut buf: Bytes,
281 prom_validation_mode: PromValidationMode,
282 processor: &mut PromSeriesProcessor,
283 ) -> Result<(), DecodeError> {
284 const STRUCT_NAME: &str = "PromWriteRequest";
285 while buf.has_remaining() {
286 let (tag, wire_type) = decode_key(&mut buf)?;
287 assert_eq!(WireType::LengthDelimited, wire_type);
288 match tag {
289 1u32 => {
290 let len = decode_varint(&mut buf).map_err(|mut e| {
292 e.push(STRUCT_NAME, "timeseries");
293 e
294 })?;
295 let remaining = buf.remaining();
296 if len > remaining as u64 {
297 return Err(DecodeError::new("buffer underflow"));
298 }
299
300 let limit = remaining - len as usize;
301 while buf.remaining() > limit {
302 let (tag, wire_type) = decode_key(&mut buf)?;
303 self.series
304 .merge_field(tag, wire_type, &mut buf, prom_validation_mode)?;
305 }
306 if buf.remaining() != limit {
307 return Err(DecodeError::new("delimited length exceeded"));
308 }
309
310 if processor.use_pipeline {
311 processor.consume_series_to_pipeline_map(
312 &mut self.series,
313 prom_validation_mode,
314 )?;
315 } else {
316 self.series
317 .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
318 }
319
320 self.series.labels.clear();
322 self.series.samples.clear();
323 }
324 3u32 => {
325 prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?;
327 }
328 _ => prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?,
329 }
330 }
331
332 Ok(())
333 }
334}
335
336pub struct PromSeriesProcessor {
344 pub(crate) use_pipeline: bool,
345 pub(crate) table_values: BTreeMap<String, Vec<Value>>,
346
347 pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
349 pub(crate) query_ctx: Option<QueryContextRef>,
350 pub(crate) pipeline_def: Option<PipelineDefinition>,
351}
352
353impl PromSeriesProcessor {
354 pub fn default_processor() -> Self {
355 Self {
356 use_pipeline: false,
357 table_values: BTreeMap::new(),
358 pipeline_handler: None,
359 query_ctx: None,
360 pipeline_def: None,
361 }
362 }
363
364 pub fn set_pipeline(
365 &mut self,
366 handler: PipelineHandlerRef,
367 query_ctx: QueryContextRef,
368 pipeline_def: PipelineDefinition,
369 ) {
370 self.use_pipeline = true;
371 self.pipeline_handler = Some(handler);
372 self.query_ctx = Some(query_ctx);
373 self.pipeline_def = Some(pipeline_def);
374 }
375
376 pub(crate) fn consume_series_to_pipeline_map(
378 &mut self,
379 series: &mut PromTimeSeries,
380 prom_validation_mode: PromValidationMode,
381 ) -> Result<(), DecodeError> {
382 let mut vec_pipeline_map: Vec<Value> = Vec::new();
383 let mut pipeline_map = BTreeMap::new();
384 for l in series.labels.iter() {
385 let name = prom_validation_mode.decode_string(&l.name)?;
386 let value = prom_validation_mode.decode_string(&l.value)?;
387 pipeline_map.insert(name, Value::String(value));
388 }
389
390 let one_sample = series.samples.len() == 1;
391
392 for s in series.samples.iter() {
393 let timestamp = s.timestamp;
394 pipeline_map.insert(GREPTIME_TIMESTAMP.to_string(), Value::Int64(timestamp));
395 pipeline_map.insert(GREPTIME_VALUE.to_string(), Value::Float64(s.value));
396 if one_sample {
397 vec_pipeline_map.push(Value::Map(pipeline_map.into()));
398 break;
399 } else {
400 vec_pipeline_map.push(Value::Map(pipeline_map.clone().into()));
401 }
402 }
403
404 let table_name = std::mem::take(&mut series.table_name);
405 match self.table_values.entry(table_name) {
406 Entry::Occupied(mut occupied_entry) => {
407 occupied_entry.get_mut().append(&mut vec_pipeline_map);
408 }
409 Entry::Vacant(vacant_entry) => {
410 vacant_entry.insert(vec_pipeline_map);
411 }
412 }
413
414 Ok(())
415 }
416
417 pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
418 let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
420 err_msg: "pipeline handler is not set",
421 })?;
422 let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
423 err_msg: "pipeline definition is not set",
424 })?;
425 let pipeline_param = GreptimePipelineParams::default();
426 let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
427 err_msg: "query context is not set",
428 })?;
429
430 let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
431
432 let mut req = ContextReq::default();
434 for (table_name, pipeline_maps) in self.table_values.iter_mut() {
435 let pipeline_req = PipelineIngestRequest {
436 table: table_name.clone(),
437 values: pipeline_maps.clone(),
438 };
439 let row_req =
440 run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
441 req.merge(row_req);
442 }
443
444 Ok(req)
445 }
446}
447
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use crate::http::PromValidationMode;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::proto::{PromSeriesProcessor, PromWriteRequest};
    use crate::repeated_field::Clear;

    /// Valid non-ASCII UTF-8 text ("Hello 世界 🌍"). It is spelled with escapes so the tests
    /// do not depend on how this source file is encoded.
    const UNICODE_TEXT: &str = "Hello \u{4e16}\u{754c} \u{1f30d}";

    /// Reorders the columns of `rows` by column name so row sets can be compared regardless
    /// of column order.
    fn sort_rows(rows: Rows) -> Rows {
        let permutation =
            permutation::sort_by_key(&rows.schema, |schema| schema.column_name.clone());
        let schema = permutation.apply_slice(&rows.schema);
        let mut inner_rows = vec![];
        for row in rows.rows {
            let values = permutation.apply_slice(&row.values);
            inner_rows.push(Row { values });
        }
        Rows {
            schema,
            rows: inner_rows,
        }
    }

    /// Decodes `data` with `prom_write_request` and compares the result against the
    /// reference conversion in `expected_rows`.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &Bytes,
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut p = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .merge(data.clone(), PromValidationMode::Strict, &mut p)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();

        let samples = req
            .ref_all_req()
            .filter_map(|r| r.rows.as_ref().map(|r| r.rows.len()))
            .sum::<usize>();
        let prom_rows = RowInsertRequests {
            inserts: req.all_req().collect::<Vec<_>>(),
        };

        assert_eq!(expected_samples, samples);
        assert_eq!(expected_rows.inserts.len(), prom_rows.inserts.len());

        let expected_rows_map = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for r in &prom_rows.inserts {
            let expected_rows = expected_rows_map.get(&r.table_name).unwrap().clone();
            assert_eq!(sort_rows(expected_rows), sort_rows(r.rows.clone().unwrap()));
        }
    }

    #[test]
    fn test_decode_write_request() {
        let mut d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("benches");
        d.push("write_request.pb.data");
        let data = Bytes::from(std::fs::read(d).unwrap());

        let (expected_rows, expected_samples) =
            to_grpc_row_insert_requests(&WriteRequest::decode(data.clone()).unwrap()).unwrap();

        // Decode several times with the same request object to exercise reuse.
        let mut prom_write_request = PromWriteRequest::default();
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Strict.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_strict_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Strict.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_strict_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Strict.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_strict_mode_invalid_utf8() {
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Strict.decode_string(&invalid_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_strict_mode_incomplete_utf8() {
        // 0xC2 starts a two-byte sequence that is never completed.
        let incomplete_utf8 = Bytes::from(vec![0xC2]);
        let result = PromValidationMode::Strict.decode_string(&incomplete_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_lossy_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Lossy.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_lossy_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Lossy.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_lossy_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Lossy.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_lossy_mode_invalid_utf8() {
        // Each invalid byte is replaced by one U+FFFD REPLACEMENT CHARACTER.
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Lossy.decode_string(&invalid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "\u{FFFD}\u{FFFD}\u{FFFD}");
    }

    #[test]
    fn test_decode_string_lossy_mode_mixed_valid_invalid() {
        // Valid ASCII surrounding a single invalid byte.
        let mut mixed = Vec::new();
        mixed.extend_from_slice(b"hello");
        mixed.push(0xFF);
        mixed.extend_from_slice(b"world");
        let mixed_utf8 = Bytes::from(mixed);

        let result = PromValidationMode::Lossy.decode_string(&mixed_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello\u{FFFD}world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Unchecked.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Unchecked.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_unchecked_mode_unicode() {
        let unicode = Bytes::from(UNICODE_TEXT);
        let result = PromValidationMode::Unchecked.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), UNICODE_TEXT);
    }

    #[test]
    fn test_decode_string_unchecked_mode_invalid_utf8() {
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Unchecked.decode_string(&invalid_utf8);
        assert!(result.is_ok());
        // The decoded content is not valid UTF-8; only check that decoding does not fail.
        let _string = result.unwrap();
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");

        let strict_result = PromValidationMode::Strict.decode_string(&ascii).unwrap();
        let lossy_result = PromValidationMode::Lossy.decode_string(&ascii).unwrap();
        let unchecked_result = PromValidationMode::Unchecked.decode_string(&ascii).unwrap();

        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
        assert_eq!(strict_result, lossy_result);
        assert_eq!(lossy_result, unchecked_result);
    }
}