servers/
proto.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::btree_map::Entry;
16use std::collections::BTreeMap;
17use std::ops::Deref;
18use std::slice;
19
20use api::prom_store::remote::Sample;
21use bytes::{Buf, Bytes};
22use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE};
23use pipeline::{ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition, Value};
24use prost::encoding::message::merge;
25use prost::encoding::{decode_key, decode_varint, WireType};
26use prost::DecodeError;
27use session::context::QueryContextRef;
28use snafu::OptionExt;
29
30use crate::error::InternalSnafu;
31use crate::http::event::PipelineIngestRequest;
32use crate::http::PromValidationMode;
33use crate::pipeline::run_pipeline;
34use crate::prom_row_builder::{PromCtx, TablesBuilder};
35use crate::prom_store::{
36    DATABASE_LABEL_BYTES, METRIC_NAME_LABEL_BYTES, PHYSICAL_TABLE_LABEL_BYTES,
37};
38use crate::query_handler::PipelineHandlerRef;
39use crate::repeated_field::{Clear, RepeatedField};
40
41impl Clear for Sample {
42    fn clear(&mut self) {
43        self.timestamp = 0;
44        self.value = 0.0;
45    }
46}
47
48#[derive(Default, Clone, Debug)]
49pub struct PromLabel {
50    pub name: Bytes,
51    pub value: Bytes,
52}
53
54impl Clear for PromLabel {
55    fn clear(&mut self) {
56        self.name.clear();
57        self.value.clear();
58    }
59}
60
61impl PromLabel {
62    pub fn merge_field(
63        &mut self,
64        tag: u32,
65        wire_type: WireType,
66        buf: &mut Bytes,
67    ) -> Result<(), DecodeError> {
68        const STRUCT_NAME: &str = "PromLabel";
69        match tag {
70            1u32 => {
71                // decode label name
72                let value = &mut self.name;
73                merge_bytes(value, buf).map_err(|mut error| {
74                    error.push(STRUCT_NAME, "name");
75                    error
76                })
77            }
78            2u32 => {
79                // decode label value
80                let value = &mut self.value;
81                merge_bytes(value, buf).map_err(|mut error| {
82                    error.push(STRUCT_NAME, "value");
83                    error
84                })
85            }
86            _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
87        }
88    }
89}
90
91#[inline(always)]
92fn copy_to_bytes(data: &mut Bytes, len: usize) -> Bytes {
93    if len == data.remaining() {
94        std::mem::replace(data, Bytes::new())
95    } else {
96        let ret = split_to(data, len);
97        data.advance(len);
98        ret
99    }
100}
101
102/// Similar to `Bytes::split_to`, but directly operates on underlying memory region.
103/// # Safety
104/// This function is safe as long as `data` is backed by a consecutive region of memory,
105/// for example `Vec<u8>` or `&[u8]`, and caller must ensure that `buf` outlives
106/// the `Bytes` returned.
107#[inline(always)]
108fn split_to(buf: &mut Bytes, end: usize) -> Bytes {
109    let len = buf.len();
110    assert!(
111        end <= len,
112        "range end out of bounds: {:?} <= {:?}",
113        end,
114        len,
115    );
116
117    if end == 0 {
118        return Bytes::new();
119    }
120
121    let ptr = buf.as_ptr();
122    let x = unsafe { slice::from_raw_parts(ptr, end) };
123    // `Bytes::drop` does nothing when it's built via `from_static`.
124    Bytes::from_static(x)
125}
126
127/// Reads a variable-length encoded bytes field from `buf` and assign it to `value`.
128/// # Safety
129/// Callers must ensure `buf` outlives `value`.
130#[inline(always)]
131fn merge_bytes(value: &mut Bytes, buf: &mut Bytes) -> Result<(), DecodeError> {
132    let len = decode_varint(buf)?;
133    if len > buf.remaining() as u64 {
134        return Err(DecodeError::new(format!(
135            "buffer underflow, len: {}, remaining: {}",
136            len,
137            buf.remaining()
138        )));
139    }
140    *value = copy_to_bytes(buf, len as usize);
141    Ok(())
142}
143
144#[derive(Default, Debug)]
145pub struct PromTimeSeries {
146    pub table_name: String,
147    // specified using `__database__` label
148    pub schema: Option<String>,
149    // specified using `__physical_table__` label
150    pub physical_table: Option<String>,
151
152    pub labels: RepeatedField<PromLabel>,
153    pub samples: RepeatedField<Sample>,
154}
155
156impl Clear for PromTimeSeries {
157    fn clear(&mut self) {
158        self.table_name.clear();
159        self.labels.clear();
160        self.samples.clear();
161    }
162}
163
164impl PromTimeSeries {
165    pub fn merge_field(
166        &mut self,
167        tag: u32,
168        wire_type: WireType,
169        buf: &mut Bytes,
170        prom_validation_mode: PromValidationMode,
171    ) -> Result<(), DecodeError> {
172        const STRUCT_NAME: &str = "PromTimeSeries";
173        match tag {
174            1u32 => {
175                // decode labels
176                let label = self.labels.push_default();
177
178                let len = decode_varint(buf).map_err(|mut error| {
179                    error.push(STRUCT_NAME, "labels");
180                    error
181                })?;
182                let remaining = buf.remaining();
183                if len > remaining as u64 {
184                    return Err(DecodeError::new("buffer underflow"));
185                }
186
187                let limit = remaining - len as usize;
188                while buf.remaining() > limit {
189                    let (tag, wire_type) = decode_key(buf)?;
190                    label.merge_field(tag, wire_type, buf)?;
191                }
192                if buf.remaining() != limit {
193                    return Err(DecodeError::new("delimited length exceeded"));
194                }
195
196                match label.name.deref() {
197                    METRIC_NAME_LABEL_BYTES => {
198                        self.table_name = prom_validation_mode.decode_string(&label.value)?;
199                        self.labels.truncate(self.labels.len() - 1); // remove last label
200                    }
201                    DATABASE_LABEL_BYTES => {
202                        self.schema = Some(prom_validation_mode.decode_string(&label.value)?);
203                        self.labels.truncate(self.labels.len() - 1); // remove last label
204                    }
205                    PHYSICAL_TABLE_LABEL_BYTES => {
206                        self.physical_table =
207                            Some(prom_validation_mode.decode_string(&label.value)?);
208                        self.labels.truncate(self.labels.len() - 1); // remove last label
209                    }
210                    _ => {}
211                }
212
213                Ok(())
214            }
215            2u32 => {
216                let sample = self.samples.push_default();
217                merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
218                    |mut error| {
219                        error.push(STRUCT_NAME, "samples");
220                        error
221                    },
222                )?;
223                Ok(())
224            }
225            // todo(hl): exemplars are skipped temporarily
226            3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
227            _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
228        }
229    }
230
231    fn add_to_table_data(
232        &mut self,
233        table_builders: &mut TablesBuilder,
234        prom_validation_mode: PromValidationMode,
235    ) -> Result<(), DecodeError> {
236        let label_num = self.labels.len();
237        let row_num = self.samples.len();
238
239        let prom_ctx = PromCtx {
240            schema: self.schema.take(),
241            physical_table: self.physical_table.take(),
242        };
243
244        let table_data = table_builders.get_or_create_table_builder(
245            prom_ctx,
246            std::mem::take(&mut self.table_name),
247            label_num,
248            row_num,
249        );
250        table_data.add_labels_and_samples(
251            self.labels.as_slice(),
252            self.samples.as_slice(),
253            prom_validation_mode,
254        )?;
255
256        Ok(())
257    }
258}
259
260#[derive(Default, Debug)]
261pub struct PromWriteRequest {
262    table_data: TablesBuilder,
263    series: PromTimeSeries,
264}
265
266impl Clear for PromWriteRequest {
267    fn clear(&mut self) {
268        self.table_data.clear();
269    }
270}
271
272impl PromWriteRequest {
273    pub fn as_row_insert_requests(&mut self) -> ContextReq {
274        self.table_data.as_insert_requests()
275    }
276
277    // todo(hl): maybe use &[u8] can reduce the overhead introduced with Bytes.
278    pub fn merge(
279        &mut self,
280        mut buf: Bytes,
281        prom_validation_mode: PromValidationMode,
282        processor: &mut PromSeriesProcessor,
283    ) -> Result<(), DecodeError> {
284        const STRUCT_NAME: &str = "PromWriteRequest";
285        while buf.has_remaining() {
286            let (tag, wire_type) = decode_key(&mut buf)?;
287            assert_eq!(WireType::LengthDelimited, wire_type);
288            match tag {
289                1u32 => {
290                    // decode TimeSeries
291                    let len = decode_varint(&mut buf).map_err(|mut e| {
292                        e.push(STRUCT_NAME, "timeseries");
293                        e
294                    })?;
295                    let remaining = buf.remaining();
296                    if len > remaining as u64 {
297                        return Err(DecodeError::new("buffer underflow"));
298                    }
299
300                    let limit = remaining - len as usize;
301                    while buf.remaining() > limit {
302                        let (tag, wire_type) = decode_key(&mut buf)?;
303                        self.series
304                            .merge_field(tag, wire_type, &mut buf, prom_validation_mode)?;
305                    }
306                    if buf.remaining() != limit {
307                        return Err(DecodeError::new("delimited length exceeded"));
308                    }
309
310                    if processor.use_pipeline {
311                        processor.consume_series_to_pipeline_map(
312                            &mut self.series,
313                            prom_validation_mode,
314                        )?;
315                    } else {
316                        self.series
317                            .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
318                    }
319
320                    // clear state
321                    self.series.labels.clear();
322                    self.series.samples.clear();
323                }
324                3u32 => {
325                    // todo(hl): metadata are skipped.
326                    prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?;
327                }
328                _ => prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?,
329            }
330        }
331
332        Ok(())
333    }
334}
335
336/// A hook to be injected into the PromWriteRequest decoding process.
337/// It was originally designed with two usage:
338/// 1. consume one series to desired type, in this case, the pipeline map
339/// 2. convert itself to RowInsertRequests
340///
341/// Since the origin conversion is coupled with PromWriteRequest,
342/// let's keep it that way for now.
343pub struct PromSeriesProcessor {
344    pub(crate) use_pipeline: bool,
345    pub(crate) table_values: BTreeMap<String, Vec<Value>>,
346
347    // optional fields for pipeline
348    pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
349    pub(crate) query_ctx: Option<QueryContextRef>,
350    pub(crate) pipeline_def: Option<PipelineDefinition>,
351}
352
353impl PromSeriesProcessor {
354    pub fn default_processor() -> Self {
355        Self {
356            use_pipeline: false,
357            table_values: BTreeMap::new(),
358            pipeline_handler: None,
359            query_ctx: None,
360            pipeline_def: None,
361        }
362    }
363
364    pub fn set_pipeline(
365        &mut self,
366        handler: PipelineHandlerRef,
367        query_ctx: QueryContextRef,
368        pipeline_def: PipelineDefinition,
369    ) {
370        self.use_pipeline = true;
371        self.pipeline_handler = Some(handler);
372        self.query_ctx = Some(query_ctx);
373        self.pipeline_def = Some(pipeline_def);
374    }
375
376    // convert one series to pipeline map
377    pub(crate) fn consume_series_to_pipeline_map(
378        &mut self,
379        series: &mut PromTimeSeries,
380        prom_validation_mode: PromValidationMode,
381    ) -> Result<(), DecodeError> {
382        let mut vec_pipeline_map: Vec<Value> = Vec::new();
383        let mut pipeline_map = BTreeMap::new();
384        for l in series.labels.iter() {
385            let name = prom_validation_mode.decode_string(&l.name)?;
386            let value = prom_validation_mode.decode_string(&l.value)?;
387            pipeline_map.insert(name, Value::String(value));
388        }
389
390        let one_sample = series.samples.len() == 1;
391
392        for s in series.samples.iter() {
393            let timestamp = s.timestamp;
394            pipeline_map.insert(GREPTIME_TIMESTAMP.to_string(), Value::Int64(timestamp));
395            pipeline_map.insert(GREPTIME_VALUE.to_string(), Value::Float64(s.value));
396            if one_sample {
397                vec_pipeline_map.push(Value::Map(pipeline_map.into()));
398                break;
399            } else {
400                vec_pipeline_map.push(Value::Map(pipeline_map.clone().into()));
401            }
402        }
403
404        let table_name = std::mem::take(&mut series.table_name);
405        match self.table_values.entry(table_name) {
406            Entry::Occupied(mut occupied_entry) => {
407                occupied_entry.get_mut().append(&mut vec_pipeline_map);
408            }
409            Entry::Vacant(vacant_entry) => {
410                vacant_entry.insert(vec_pipeline_map);
411            }
412        }
413
414        Ok(())
415    }
416
417    pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
418        // prepare params
419        let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
420            err_msg: "pipeline handler is not set",
421        })?;
422        let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
423            err_msg: "pipeline definition is not set",
424        })?;
425        let pipeline_param = GreptimePipelineParams::default();
426        let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
427            err_msg: "query context is not set",
428        })?;
429
430        let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
431
432        // run pipeline
433        let mut req = ContextReq::default();
434        for (table_name, pipeline_maps) in self.table_values.iter_mut() {
435            let pipeline_req = PipelineIngestRequest {
436                table: table_name.clone(),
437                values: pipeline_maps.clone(),
438            };
439            let row_req =
440                run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
441            req.merge(row_req);
442        }
443
444        Ok(req)
445    }
446}
447
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use crate::http::PromValidationMode;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::proto::{PromSeriesProcessor, PromWriteRequest};
    use crate::repeated_field::Clear;

    /// Reorders the columns of `rows` (schema and every row's values) by column name so
    /// that two row sets can be compared regardless of column order.
    fn sort_rows(rows: Rows) -> Rows {
        let permutation =
            permutation::sort_by_key(&rows.schema, |schema| schema.column_name.clone());
        let schema = permutation.apply_slice(&rows.schema);
        let mut inner_rows = vec![];
        for row in rows.rows {
            let values = permutation.apply_slice(&row.values);
            inner_rows.push(Row { values });
        }
        Rows {
            schema,
            rows: inner_rows,
        }
    }

    /// Decodes `data` with `prom_write_request` and checks the produced insert requests
    /// against `expected_rows` / `expected_samples`.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &Bytes,
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut p = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .merge(data.clone(), PromValidationMode::Strict, &mut p)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();

        let samples = req
            .ref_all_req()
            .filter_map(|r| r.rows.as_ref().map(|r| r.rows.len()))
            .sum::<usize>();
        let prom_rows = RowInsertRequests {
            inserts: req.all_req().collect::<Vec<_>>(),
        };

        assert_eq!(expected_samples, samples);
        assert_eq!(expected_rows.inserts.len(), prom_rows.inserts.len());

        let expected_rows_map = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for r in &prom_rows.inserts {
            // check value
            let expected_rows = expected_rows_map.get(&r.table_name).unwrap().clone();
            assert_eq!(sort_rows(expected_rows), sort_rows(r.rows.clone().unwrap()));
        }
    }

    // Ensures `WriteRequest` and `PromWriteRequest` produce the same gRPC request.
    #[test]
    fn test_decode_write_request() {
        let mut d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("benches");
        d.push("write_request.pb.data");
        let data = Bytes::from(std::fs::read(d).unwrap());

        let (expected_rows, expected_samples) =
            to_grpc_row_insert_requests(&WriteRequest::decode(data.clone()).unwrap()).unwrap();

        // Decode repeatedly with the same instance to verify that reuse is clean.
        let mut prom_write_request = PromWriteRequest::default();
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Strict.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_strict_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Strict.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_strict_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = PromValidationMode::Strict.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_strict_mode_invalid_utf8() {
        // Invalid UTF-8 sequence
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Strict.decode_string(&invalid_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_strict_mode_incomplete_utf8() {
        // Incomplete UTF-8 sequence (missing continuation bytes)
        let incomplete_utf8 = Bytes::from(vec![0xC2]); // Start of 2-byte sequence but missing second byte
        let result = PromValidationMode::Strict.decode_string(&incomplete_utf8);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_lossy_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Lossy.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_lossy_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Lossy.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_lossy_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = PromValidationMode::Lossy.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_lossy_mode_invalid_utf8() {
        // Invalid UTF-8 sequence - should be replaced with replacement character
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Lossy.decode_string(&invalid_utf8);
        assert!(result.is_ok());
        // Each invalid byte should be replaced with the Unicode replacement character.
        // Written as an escape so the expectation survives encoding mishaps.
        assert_eq!(result.unwrap(), "\u{FFFD}\u{FFFD}\u{FFFD}");
    }

    #[test]
    fn test_decode_string_lossy_mode_mixed_valid_invalid() {
        // Mix of valid and invalid UTF-8
        let mut mixed = Vec::new();
        mixed.extend_from_slice(b"hello");
        mixed.push(0xFF); // Invalid byte
        mixed.extend_from_slice(b"world");
        let mixed_utf8 = Bytes::from(mixed);

        let result = PromValidationMode::Lossy.decode_string(&mixed_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello\u{FFFD}world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Unchecked.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_empty() {
        let empty = Bytes::new();
        let result = PromValidationMode::Unchecked.decode_string(&empty);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_unchecked_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = PromValidationMode::Unchecked.decode_string(&unicode);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_unchecked_mode_invalid_utf8() {
        // Invalid UTF-8 sequence - unchecked mode doesn't validate
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = PromValidationMode::Unchecked.decode_string(&invalid_utf8);
        // This should succeed but the resulting string may contain invalid UTF-8
        assert!(result.is_ok());
        // We can't easily test the exact content since it's invalid UTF-8,
        // but we can verify it doesn't panic and returns something
        let _string = result.unwrap();
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");

        // All modes should handle ASCII identically
        let strict_result = PromValidationMode::Strict.decode_string(&ascii).unwrap();
        let lossy_result = PromValidationMode::Lossy.decode_string(&ascii).unwrap();
        let unchecked_result = PromValidationMode::Unchecked.decode_string(&ascii).unwrap();

        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
        assert_eq!(strict_result, lossy_result);
        assert_eq!(lossy_result, unchecked_result);
    }
}