servers/
proto.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::btree_map::Entry;
16use std::collections::BTreeMap;
17use std::ops::Deref;
18use std::slice;
19
20use api::prom_store::remote::Sample;
21use api::v1::RowInsertRequest;
22use bytes::{Buf, Bytes};
23use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE};
24use common_telemetry::debug;
25use pipeline::{
26    ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition, PipelineMap, Value,
27};
28use prost::encoding::message::merge;
29use prost::encoding::{decode_key, decode_varint, WireType};
30use prost::DecodeError;
31use session::context::QueryContextRef;
32use snafu::OptionExt;
33
34use crate::error::InternalSnafu;
35use crate::http::event::PipelineIngestRequest;
36use crate::http::PromValidationMode;
37use crate::pipeline::run_pipeline;
38use crate::prom_row_builder::TablesBuilder;
39use crate::prom_store::METRIC_NAME_LABEL_BYTES;
40use crate::query_handler::PipelineHandlerRef;
41use crate::repeated_field::{Clear, RepeatedField};
42
43impl Clear for Sample {
44    fn clear(&mut self) {
45        self.timestamp = 0;
46        self.value = 0.0;
47    }
48}
49
50#[derive(Default, Clone, Debug)]
51pub struct PromLabel {
52    pub name: Bytes,
53    pub value: Bytes,
54}
55
56impl Clear for PromLabel {
57    fn clear(&mut self) {
58        self.name.clear();
59        self.value.clear();
60    }
61}
62
63impl PromLabel {
64    pub fn merge_field(
65        &mut self,
66        tag: u32,
67        wire_type: WireType,
68        buf: &mut Bytes,
69    ) -> Result<(), DecodeError> {
70        const STRUCT_NAME: &str = "PromLabel";
71        match tag {
72            1u32 => {
73                // decode label name
74                let value = &mut self.name;
75                merge_bytes(value, buf).map_err(|mut error| {
76                    error.push(STRUCT_NAME, "name");
77                    error
78                })
79            }
80            2u32 => {
81                // decode label value
82                let value = &mut self.value;
83                merge_bytes(value, buf).map_err(|mut error| {
84                    error.push(STRUCT_NAME, "value");
85                    error
86                })
87            }
88            _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
89        }
90    }
91}
92
93#[inline(always)]
94fn copy_to_bytes(data: &mut Bytes, len: usize) -> Bytes {
95    if len == data.remaining() {
96        std::mem::replace(data, Bytes::new())
97    } else {
98        let ret = split_to(data, len);
99        data.advance(len);
100        ret
101    }
102}
103
104/// Similar to `Bytes::split_to`, but directly operates on underlying memory region.
105/// # Safety
106/// This function is safe as long as `data` is backed by a consecutive region of memory,
107/// for example `Vec<u8>` or `&[u8]`, and caller must ensure that `buf` outlives
108/// the `Bytes` returned.
109#[inline(always)]
110fn split_to(buf: &mut Bytes, end: usize) -> Bytes {
111    let len = buf.len();
112    assert!(
113        end <= len,
114        "range end out of bounds: {:?} <= {:?}",
115        end,
116        len,
117    );
118
119    if end == 0 {
120        return Bytes::new();
121    }
122
123    let ptr = buf.as_ptr();
124    let x = unsafe { slice::from_raw_parts(ptr, end) };
125    // `Bytes::drop` does nothing when it's built via `from_static`.
126    Bytes::from_static(x)
127}
128
129/// Reads a variable-length encoded bytes field from `buf` and assign it to `value`.
130/// # Safety
131/// Callers must ensure `buf` outlives `value`.
132#[inline(always)]
133fn merge_bytes(value: &mut Bytes, buf: &mut Bytes) -> Result<(), DecodeError> {
134    let len = decode_varint(buf)?;
135    if len > buf.remaining() as u64 {
136        return Err(DecodeError::new(format!(
137            "buffer underflow, len: {}, remaining: {}",
138            len,
139            buf.remaining()
140        )));
141    }
142    *value = copy_to_bytes(buf, len as usize);
143    Ok(())
144}
145
146#[derive(Default, Debug)]
147pub struct PromTimeSeries {
148    pub table_name: String,
149    pub labels: RepeatedField<PromLabel>,
150    pub samples: RepeatedField<Sample>,
151}
152
153impl Clear for PromTimeSeries {
154    fn clear(&mut self) {
155        self.table_name.clear();
156        self.labels.clear();
157        self.samples.clear();
158    }
159}
160
161impl PromTimeSeries {
162    pub fn merge_field(
163        &mut self,
164        tag: u32,
165        wire_type: WireType,
166        buf: &mut Bytes,
167        prom_validation_mode: PromValidationMode,
168    ) -> Result<(), DecodeError> {
169        const STRUCT_NAME: &str = "PromTimeSeries";
170        match tag {
171            1u32 => {
172                // decode labels
173                let label = self.labels.push_default();
174
175                let len = decode_varint(buf).map_err(|mut error| {
176                    error.push(STRUCT_NAME, "labels");
177                    error
178                })?;
179                let remaining = buf.remaining();
180                if len > remaining as u64 {
181                    return Err(DecodeError::new("buffer underflow"));
182                }
183
184                let limit = remaining - len as usize;
185                while buf.remaining() > limit {
186                    let (tag, wire_type) = decode_key(buf)?;
187                    label.merge_field(tag, wire_type, buf)?;
188                }
189                if buf.remaining() != limit {
190                    return Err(DecodeError::new("delimited length exceeded"));
191                }
192                if label.name.deref() == METRIC_NAME_LABEL_BYTES {
193                    self.table_name = decode_string(&label.value, prom_validation_mode)?;
194                    self.labels.truncate(self.labels.len() - 1); // remove last label
195                }
196                Ok(())
197            }
198            2u32 => {
199                let sample = self.samples.push_default();
200                merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
201                    |mut error| {
202                        error.push(STRUCT_NAME, "samples");
203                        error
204                    },
205                )?;
206                Ok(())
207            }
208            // todo(hl): exemplars are skipped temporarily
209            3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
210            _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
211        }
212    }
213
214    fn add_to_table_data(
215        &mut self,
216        table_builders: &mut TablesBuilder,
217        prom_validation_mode: PromValidationMode,
218    ) -> Result<(), DecodeError> {
219        let label_num = self.labels.len();
220        let row_num = self.samples.len();
221        let table_data = table_builders.get_or_create_table_builder(
222            std::mem::take(&mut self.table_name),
223            label_num,
224            row_num,
225        );
226        table_data.add_labels_and_samples(
227            self.labels.as_slice(),
228            self.samples.as_slice(),
229            prom_validation_mode,
230        )?;
231
232        Ok(())
233    }
234}
235
236/// Decodes bytes into String values according provided validation mode.
237pub(crate) fn decode_string(
238    bytes: &Bytes,
239    mode: PromValidationMode,
240) -> Result<String, DecodeError> {
241    let result = match mode {
242        PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
243            Ok(s) => s,
244            Err(e) => {
245                debug!(
246                    "Invalid UTF-8 string value: {:?}, error: {:?}",
247                    &bytes[..],
248                    e
249                );
250                return Err(DecodeError::new("invalid utf-8"));
251            }
252        },
253        PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
254        PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
255    };
256    Ok(result)
257}
258
259#[derive(Default, Debug)]
260pub struct PromWriteRequest {
261    table_data: TablesBuilder,
262    series: PromTimeSeries,
263}
264
265impl Clear for PromWriteRequest {
266    fn clear(&mut self) {
267        self.table_data.clear();
268    }
269}
270
271impl PromWriteRequest {
272    pub fn as_row_insert_requests(&mut self) -> Vec<RowInsertRequest> {
273        self.table_data.as_insert_requests()
274    }
275
276    // todo(hl): maybe use &[u8] can reduce the overhead introduced with Bytes.
277    pub fn merge(
278        &mut self,
279        mut buf: Bytes,
280        prom_validation_mode: PromValidationMode,
281        processor: &mut PromSeriesProcessor,
282    ) -> Result<(), DecodeError> {
283        const STRUCT_NAME: &str = "PromWriteRequest";
284        while buf.has_remaining() {
285            let (tag, wire_type) = decode_key(&mut buf)?;
286            assert_eq!(WireType::LengthDelimited, wire_type);
287            match tag {
288                1u32 => {
289                    // decode TimeSeries
290                    let len = decode_varint(&mut buf).map_err(|mut e| {
291                        e.push(STRUCT_NAME, "timeseries");
292                        e
293                    })?;
294                    let remaining = buf.remaining();
295                    if len > remaining as u64 {
296                        return Err(DecodeError::new("buffer underflow"));
297                    }
298
299                    let limit = remaining - len as usize;
300                    while buf.remaining() > limit {
301                        let (tag, wire_type) = decode_key(&mut buf)?;
302                        self.series
303                            .merge_field(tag, wire_type, &mut buf, prom_validation_mode)?;
304                    }
305                    if buf.remaining() != limit {
306                        return Err(DecodeError::new("delimited length exceeded"));
307                    }
308
309                    if processor.use_pipeline {
310                        processor.consume_series_to_pipeline_map(&mut self.series)?;
311                    } else {
312                        self.series
313                            .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
314                    }
315
316                    // clear state
317                    self.series.labels.clear();
318                    self.series.samples.clear();
319                }
320                3u32 => {
321                    // todo(hl): metadata are skipped.
322                    prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?;
323                }
324                _ => prost::encoding::skip_field(wire_type, tag, &mut buf, Default::default())?,
325            }
326        }
327
328        Ok(())
329    }
330}
331
332/// A hook to be injected into the PromWriteRequest decoding process.
333/// It was originally designed with two usage:
334/// 1. consume one series to desired type, in this case, the pipeline map
335/// 2. convert itself to RowInsertRequests
336///
337/// Since the origin conversion is coupled with PromWriteRequest,
338/// let's keep it that way for now.
339pub struct PromSeriesProcessor {
340    pub(crate) use_pipeline: bool,
341    pub(crate) table_values: BTreeMap<String, Vec<PipelineMap>>,
342
343    // optional fields for pipeline
344    pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
345    pub(crate) query_ctx: Option<QueryContextRef>,
346    pub(crate) pipeline_def: Option<PipelineDefinition>,
347}
348
349impl PromSeriesProcessor {
350    pub fn default_processor() -> Self {
351        Self {
352            use_pipeline: false,
353            table_values: BTreeMap::new(),
354            pipeline_handler: None,
355            query_ctx: None,
356            pipeline_def: None,
357        }
358    }
359
360    pub fn set_pipeline(
361        &mut self,
362        handler: PipelineHandlerRef,
363        query_ctx: QueryContextRef,
364        pipeline_def: PipelineDefinition,
365    ) {
366        self.use_pipeline = true;
367        self.pipeline_handler = Some(handler);
368        self.query_ctx = Some(query_ctx);
369        self.pipeline_def = Some(pipeline_def);
370    }
371
372    // convert one series to pipeline map
373    pub(crate) fn consume_series_to_pipeline_map(
374        &mut self,
375        series: &mut PromTimeSeries,
376    ) -> Result<(), DecodeError> {
377        let mut vec_pipeline_map: Vec<PipelineMap> = Vec::new();
378        let mut pipeline_map = PipelineMap::new();
379        for l in series.labels.iter() {
380            let name = String::from_utf8(l.name.to_vec())
381                .map_err(|_| DecodeError::new("invalid utf-8"))?;
382            let value = String::from_utf8(l.value.to_vec())
383                .map_err(|_| DecodeError::new("invalid utf-8"))?;
384            pipeline_map.insert(name, Value::String(value));
385        }
386
387        let one_sample = series.samples.len() == 1;
388
389        for s in series.samples.iter() {
390            let timestamp = s.timestamp;
391            pipeline_map.insert(GREPTIME_TIMESTAMP.to_string(), Value::Int64(timestamp));
392            pipeline_map.insert(GREPTIME_VALUE.to_string(), Value::Float64(s.value));
393            if one_sample {
394                vec_pipeline_map.push(pipeline_map);
395                break;
396            } else {
397                vec_pipeline_map.push(pipeline_map.clone());
398            }
399        }
400
401        let table_name = std::mem::take(&mut series.table_name);
402        match self.table_values.entry(table_name) {
403            Entry::Occupied(mut occupied_entry) => {
404                occupied_entry.get_mut().append(&mut vec_pipeline_map);
405            }
406            Entry::Vacant(vacant_entry) => {
407                vacant_entry.insert(vec_pipeline_map);
408            }
409        }
410
411        Ok(())
412    }
413
414    pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
415        // prepare params
416        let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
417            err_msg: "pipeline handler is not set",
418        })?;
419        let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
420            err_msg: "pipeline definition is not set",
421        })?;
422        let pipeline_param = GreptimePipelineParams::default();
423        let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
424            err_msg: "query context is not set",
425        })?;
426
427        let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
428
429        // run pipeline
430        let mut req = ContextReq::default();
431        for (table_name, pipeline_maps) in self.table_values.iter_mut() {
432            let pipeline_req = PipelineIngestRequest {
433                table: table_name.clone(),
434                values: pipeline_maps.clone(),
435            };
436            let row_req =
437                run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
438            req.merge(row_req);
439        }
440
441        Ok(req)
442    }
443}
444
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use crate::http::PromValidationMode;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::proto::{decode_string, PromSeriesProcessor, PromWriteRequest};
    use crate::repeated_field::Clear;

    /// Reorders the columns of `rows` by column name so that two `Rows` with
    /// the same content but different column order compare equal.
    fn sort_rows(rows: Rows) -> Rows {
        let permutation =
            permutation::sort_by_key(&rows.schema, |schema| schema.column_name.clone());
        let schema = permutation.apply_slice(&rows.schema);
        let mut inner_rows = vec![];
        for row in rows.rows {
            let values = permutation.apply_slice(&row.values);
            inner_rows.push(Row { values });
        }
        Rows {
            schema,
            rows: inner_rows,
        }
    }

    /// Decodes `data` with `prom_write_request` and checks the produced insert
    /// requests against `expected_rows` / `expected_samples`.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &Bytes,
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut p = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .merge(data.clone(), PromValidationMode::Strict, &mut p)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();
        let samples = req
            .iter()
            .filter_map(|r| r.rows.as_ref().map(|r| r.rows.len()))
            .sum::<usize>();
        let prom_rows = RowInsertRequests { inserts: req };

        assert_eq!(expected_samples, samples);
        assert_eq!(expected_rows.inserts.len(), prom_rows.inserts.len());

        let expected_rows_map = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for r in &prom_rows.inserts {
            // check value
            let expected_rows = expected_rows_map.get(&r.table_name).unwrap().clone();
            assert_eq!(sort_rows(expected_rows), sort_rows(r.rows.clone().unwrap()));
        }
    }

    // Ensures `WriteRequest` and `PromWriteRequest` produce the same gRPC request.
    #[test]
    fn test_decode_write_request() {
        let mut d = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("benches");
        d.push("write_request.pb.data");
        let data = Bytes::from(std::fs::read(d).unwrap());

        let (expected_rows, expected_samples) =
            to_grpc_row_insert_requests(&WriteRequest::decode(data.clone()).unwrap()).unwrap();

        let mut prom_write_request = PromWriteRequest::default();
        // Decode repeatedly with the same instance to cover state reuse.
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = decode_string(&valid_utf8, PromValidationMode::Strict);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_strict_mode_empty() {
        let empty = Bytes::new();
        let result = decode_string(&empty, PromValidationMode::Strict);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_strict_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = decode_string(&unicode, PromValidationMode::Strict);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_strict_mode_invalid_utf8() {
        // Invalid UTF-8 sequence
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = decode_string(&invalid_utf8, PromValidationMode::Strict);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_strict_mode_incomplete_utf8() {
        // Incomplete UTF-8 sequence (missing continuation bytes)
        let incomplete_utf8 = Bytes::from(vec![0xC2]); // Start of 2-byte sequence but missing second byte
        let result = decode_string(&incomplete_utf8, PromValidationMode::Strict);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "failed to decode Protobuf message: invalid utf-8"
        );
    }

    #[test]
    fn test_decode_string_lossy_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = decode_string(&valid_utf8, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_lossy_mode_empty() {
        let empty = Bytes::new();
        let result = decode_string(&empty, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_lossy_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = decode_string(&unicode, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_lossy_mode_invalid_utf8() {
        // Invalid UTF-8 sequence - should be replaced with replacement character
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = decode_string(&invalid_utf8, PromValidationMode::Lossy);
        assert!(result.is_ok());
        // Each invalid byte should be replaced with the Unicode replacement character.
        // Written as escapes so the expectation survives any re-encoding of this file.
        assert_eq!(result.unwrap(), "\u{FFFD}\u{FFFD}\u{FFFD}");
    }

    #[test]
    fn test_decode_string_lossy_mode_mixed_valid_invalid() {
        // Mix of valid and invalid UTF-8
        let mut mixed = Vec::new();
        mixed.extend_from_slice(b"hello");
        mixed.push(0xFF); // Invalid byte
        mixed.extend_from_slice(b"world");
        let mixed_utf8 = Bytes::from(mixed);

        let result = decode_string(&mixed_utf8, PromValidationMode::Lossy);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello\u{FFFD}world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = decode_string(&valid_utf8, PromValidationMode::Unchecked);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_unchecked_mode_empty() {
        let empty = Bytes::new();
        let result = decode_string(&empty, PromValidationMode::Unchecked);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    #[test]
    fn test_decode_string_unchecked_mode_unicode() {
        let unicode = Bytes::from("Hello 世界 🌍");
        let result = decode_string(&unicode, PromValidationMode::Unchecked);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Hello 世界 🌍");
    }

    #[test]
    fn test_decode_string_unchecked_mode_invalid_utf8() {
        // Invalid UTF-8 sequence - unchecked mode doesn't validate
        let invalid_utf8 = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let result = decode_string(&invalid_utf8, PromValidationMode::Unchecked);
        // This should succeed but the resulting string may contain invalid UTF-8
        assert!(result.is_ok());
        // We can't easily test the exact content since it's invalid UTF-8,
        // but we can verify it doesn't panic and returns something
        let _string = result.unwrap();
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");

        // All modes should handle ASCII identically
        let strict_result = decode_string(&ascii, PromValidationMode::Strict).unwrap();
        let lossy_result = decode_string(&ascii, PromValidationMode::Lossy).unwrap();
        let unchecked_result = decode_string(&ascii, PromValidationMode::Unchecked).unwrap();

        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
        assert_eq!(strict_result, lossy_result);
        assert_eq!(lossy_result, unchecked_result);
    }
}