servers/prom_remote_write/
decode.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Decoding of Prometheus remote write protobuf payloads.
16
17use std::collections::BTreeMap;
18use std::collections::btree_map::Entry;
19
20use api::prom_store::remote::Sample;
21use bytes::Buf;
22use common_query::prelude::{greptime_timestamp, greptime_value};
23use pipeline::{ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition};
24use prost::DecodeError;
25use prost::encoding::message::merge;
26use prost::encoding::{WireType, decode_key, decode_varint};
27use session::context::QueryContextRef;
28use snafu::OptionExt;
29use vrl::prelude::NotNan;
30use vrl::value::{KeyString, Value as VrlValue};
31
32use crate::error::InternalSnafu;
33use crate::http::event::PipelineIngestRequest;
34use crate::pipeline::run_pipeline;
35use crate::prom_remote_write::row_builder::{PromCtx, TablesBuilder};
36use crate::prom_remote_write::types::PromLabel;
37use crate::prom_remote_write::validation::PromValidationMode;
38#[allow(deprecated)]
39use crate::prom_store::{
40    DATABASE_LABEL_ALT_BYTES, DATABASE_LABEL_BYTES, METRIC_NAME_LABEL_BYTES,
41    PHYSICAL_TABLE_LABEL_ALT_BYTES, PHYSICAL_TABLE_LABEL_BYTES, SCHEMA_LABEL_BYTES,
42};
43use crate::query_handler::PipelineHandlerRef;
44use crate::repeated_field::{Clear, RepeatedField};
45
/// Scratch state for one time series while a remote write payload is decoded.
///
/// Labels with a reserved name are not kept in `labels`; their values are
/// stored in `table_name`, `schema` or `physical_table` instead.
#[derive(Default, Debug)]
pub(crate) struct PromTimeSeries {
    /// Decoded value of the label named `METRIC_NAME_LABEL_BYTES`.
    pub(crate) table_name: String,
    /// Decoded value of the schema label, or of a database label when no
    /// schema has been set yet for this series.
    pub(crate) schema: Option<String>,
    /// Decoded value of the physical table label, if the series carried one.
    pub(crate) physical_table: Option<String>,

    /// Labels of the series, excluding the reserved ones described above.
    pub(crate) labels: RepeatedField<PromLabel>,
    /// Samples decoded for the series.
    pub(crate) samples: RepeatedField<Sample>,
}
55
56impl Clear for PromTimeSeries {
57    fn clear(&mut self) {
58        self.table_name.clear();
59        self.labels.clear();
60        self.samples.clear();
61    }
62}
63
64impl PromTimeSeries {
65    pub fn merge_field(
66        &mut self,
67        tag: u32,
68        wire_type: WireType,
69        buf: &mut &[u8],
70        prom_validation_mode: PromValidationMode,
71    ) -> Result<(), DecodeError> {
72        const STRUCT_NAME: &str = "PromTimeSeries";
73        match tag {
74            1u32 => {
75                let label = self.labels.push_default();
76
77                let len = decode_varint(buf).map_err(|mut error| {
78                    error.push(STRUCT_NAME, "labels");
79                    error
80                })?;
81                let remaining = buf.remaining();
82                if len > remaining as u64 {
83                    return Err(DecodeError::new("buffer underflow"));
84                }
85
86                let limit = remaining - len as usize;
87                while buf.remaining() > limit {
88                    let (tag, wire_type) = decode_key(buf)?;
89                    label.merge_field(tag, wire_type, buf)?;
90                }
91                if buf.remaining() != limit {
92                    return Err(DecodeError::new("delimited length exceeded"));
93                }
94
95                #[allow(deprecated)]
96                match label.name {
97                    METRIC_NAME_LABEL_BYTES => {
98                        self.table_name = prom_validation_mode.decode_string(label.value)?;
99                        self.labels.truncate(self.labels.len() - 1);
100                    }
101                    SCHEMA_LABEL_BYTES => {
102                        self.schema = Some(prom_validation_mode.decode_string(label.value)?);
103                        self.labels.truncate(self.labels.len() - 1);
104                    }
105                    DATABASE_LABEL_BYTES | DATABASE_LABEL_ALT_BYTES => {
106                        if self.schema.is_none() {
107                            self.schema = Some(prom_validation_mode.decode_string(label.value)?);
108                        }
109                        self.labels.truncate(self.labels.len() - 1);
110                    }
111                    PHYSICAL_TABLE_LABEL_BYTES | PHYSICAL_TABLE_LABEL_ALT_BYTES => {
112                        self.physical_table =
113                            Some(prom_validation_mode.decode_string(label.value)?);
114                        self.labels.truncate(self.labels.len() - 1);
115                    }
116                    _ => {}
117                }
118
119                Ok(())
120            }
121            2u32 => {
122                let sample = self.samples.push_default();
123                merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
124                    |mut error| {
125                        error.push(STRUCT_NAME, "samples");
126                        error
127                    },
128                )?;
129                Ok(())
130            }
131            3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
132            _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
133        }
134    }
135
136    fn add_to_table_data<'a>(
137        &mut self,
138        table_builders: &mut TablesBuilder<'a>,
139        prom_validation_mode: PromValidationMode,
140    ) -> Result<(), DecodeError> {
141        let label_num = self.labels.len();
142        let row_num = self.samples.len();
143
144        let prom_ctx = PromCtx {
145            schema: self.schema.take(),
146            physical_table: self.physical_table.take(),
147        };
148
149        let table_data = table_builders.get_or_create_table_builder(
150            prom_ctx,
151            std::mem::take(&mut self.table_name),
152            label_num,
153            row_num,
154        );
155        table_data.add_labels_and_samples(
156            self.labels.as_slice(),
157            self.samples.as_slice(),
158            prom_validation_mode,
159        )?;
160
161        Ok(())
162    }
163}
164
/// Decoder state for one Prometheus remote write request.
#[derive(Default, Debug)]
pub struct PromWriteRequest<'a> {
    /// Holds the raw payload handed to `decode` and the rows built from it.
    pub(crate) table_data: TablesBuilder<'a>,
    /// Scratch series that is filled and emptied once per decoded time series.
    series: PromTimeSeries,
}
170
171impl<'a> Clear for PromWriteRequest<'a> {
172    fn clear(&mut self) {
173        self.table_data.clear();
174    }
175}
176
177impl<'a> PromWriteRequest<'a> {
178    pub fn as_row_insert_requests(&mut self) -> ContextReq {
179        self.table_data.as_insert_requests()
180    }
181
182    pub fn decode(
183        &mut self,
184        buf: Vec<u8>,
185        prom_validation_mode: PromValidationMode,
186        processor: &mut PromSeriesProcessor,
187    ) -> Result<(), DecodeError> {
188        const STRUCT_NAME: &str = "PromWriteRequest";
189        self.table_data.set_raw_data(buf);
190        let mut offset = 0;
191        while offset < self.table_data.raw_data.len() {
192            let mut should_add_to_table_data = false;
193            let mut decoded_timeseries = false;
194            {
195                let raw_data = &self.table_data.raw_data;
196                let buf = &mut &raw_data[offset..];
197                let (tag, wire_type) = decode_key(buf)?;
198                if wire_type != WireType::LengthDelimited {
199                    return Err(DecodeError::new(format!(
200                        "invalid wire type: {:?}",
201                        wire_type
202                    )));
203                }
204                match tag {
205                    1u32 => {
206                        let len = decode_varint(buf).map_err(|mut e| {
207                            e.push(STRUCT_NAME, "timeseries");
208                            e
209                        })?;
210                        let remaining = buf.remaining();
211                        if len > remaining as u64 {
212                            return Err(DecodeError::new("buffer underflow"));
213                        }
214
215                        let limit = remaining - len as usize;
216                        while buf.remaining() > limit {
217                            let (tag, wire_type) = decode_key(buf)?;
218                            self.series
219                                .merge_field(tag, wire_type, buf, prom_validation_mode)?;
220                        }
221                        if buf.remaining() != limit {
222                            return Err(DecodeError::new("delimited length exceeded"));
223                        }
224
225                        if processor.use_pipeline {
226                            processor.consume_series_to_pipeline_map(
227                                &mut self.series,
228                                prom_validation_mode,
229                            )?;
230                        } else {
231                            should_add_to_table_data = true;
232                        }
233
234                        decoded_timeseries = true;
235                    }
236                    3u32 => {
237                        prost::encoding::skip_field(wire_type, tag, buf, Default::default())?;
238                    }
239                    _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default())?,
240                }
241                offset = raw_data.len() - buf.remaining();
242            }
243
244            if should_add_to_table_data {
245                self.series
246                    .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
247            }
248
249            if decoded_timeseries {
250                self.series.labels.clear();
251                self.series.samples.clear();
252            }
253        }
254
255        Ok(())
256    }
257}
258
/// Hook injected into the PromWriteRequest decoding process.
///
/// When `use_pipeline` is set, decoded series are buffered in `table_values`
/// instead of being added to the request's table data.
pub struct PromSeriesProcessor {
    /// Whether decoded series are routed to the pipeline buffer.
    pub(crate) use_pipeline: bool,
    /// Buffered pipeline input, keyed by table name.
    pub(crate) table_values: BTreeMap<String, Vec<VrlValue>>,

    /// Handler used by `exec_pipeline`; set by `set_pipeline`.
    pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
    /// Query context used by `exec_pipeline`; set by `set_pipeline`.
    pub(crate) query_ctx: Option<QueryContextRef>,
    /// Pipeline definition used by `exec_pipeline`; set by `set_pipeline`.
    pub(crate) pipeline_def: Option<PipelineDefinition>,
}
268
269impl PromSeriesProcessor {
270    pub fn default_processor() -> Self {
271        Self {
272            use_pipeline: false,
273            table_values: BTreeMap::new(),
274            pipeline_handler: None,
275            query_ctx: None,
276            pipeline_def: None,
277        }
278    }
279
280    pub fn set_pipeline(
281        &mut self,
282        handler: PipelineHandlerRef,
283        query_ctx: QueryContextRef,
284        pipeline_def: PipelineDefinition,
285    ) {
286        self.use_pipeline = true;
287        self.pipeline_handler = Some(handler);
288        self.query_ctx = Some(query_ctx);
289        self.pipeline_def = Some(pipeline_def);
290    }
291
292    pub(crate) fn consume_series_to_pipeline_map(
293        &mut self,
294        series: &mut PromTimeSeries,
295        prom_validation_mode: PromValidationMode,
296    ) -> Result<(), DecodeError> {
297        let mut vec_pipeline_map = Vec::new();
298        let mut pipeline_map = BTreeMap::new();
299        for l in series.labels.iter() {
300            let name = prom_validation_mode.decode_label_name(l.name)?;
301            let value = prom_validation_mode.decode_string(l.value)?;
302            pipeline_map.insert(KeyString::from(name), VrlValue::Bytes(value.into()));
303        }
304
305        let one_sample = series.samples.len() == 1;
306
307        for s in series.samples.iter() {
308            let Ok(value) = NotNan::new(s.value) else {
309                common_telemetry::warn!("Invalid float value: {}", s.value);
310                continue;
311            };
312
313            let timestamp = s.timestamp;
314            pipeline_map.insert(
315                KeyString::from(greptime_timestamp()),
316                VrlValue::Integer(timestamp),
317            );
318            pipeline_map.insert(KeyString::from(greptime_value()), VrlValue::Float(value));
319            if one_sample {
320                vec_pipeline_map.push(VrlValue::Object(pipeline_map));
321                break;
322            } else {
323                vec_pipeline_map.push(VrlValue::Object(pipeline_map.clone()));
324            }
325        }
326
327        let table_name = std::mem::take(&mut series.table_name);
328        match self.table_values.entry(table_name) {
329            Entry::Occupied(mut occupied_entry) => {
330                occupied_entry.get_mut().append(&mut vec_pipeline_map);
331            }
332            Entry::Vacant(vacant_entry) => {
333                vacant_entry.insert(vec_pipeline_map);
334            }
335        }
336
337        Ok(())
338    }
339
340    pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
341        let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
342            err_msg: "pipeline handler is not set",
343        })?;
344        let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
345            err_msg: "pipeline definition is not set",
346        })?;
347        let pipeline_param = GreptimePipelineParams::default();
348        let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
349            err_msg: "query context is not set",
350        })?;
351
352        let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
353
354        let mut req = ContextReq::default();
355        let table_values = std::mem::take(&mut self.table_values);
356        for (table_name, pipeline_maps) in table_values.into_iter() {
357            let pipeline_req = PipelineIngestRequest {
358                table: table_name,
359                values: pipeline_maps,
360            };
361            let row_req =
362                run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
363            req.merge(row_req);
364        }
365
366        Ok(req)
367    }
368}
369
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use super::*;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::repeated_field::Clear;

    /// Reorders the columns of `rows` (schema and values alike) by column name
    /// so that two row sets can be compared regardless of column order.
    fn sort_rows(rows: Rows) -> Rows {
        let by_column_name =
            permutation::sort_by_key(&rows.schema, |column| column.column_name.clone());
        let schema = by_column_name.apply_slice(&rows.schema);
        let sorted_rows = rows
            .rows
            .into_iter()
            .map(|row| Row {
                values: by_column_name.apply_slice(&row.values),
            })
            .collect();
        Rows {
            schema,
            rows: sorted_rows,
        }
    }

    /// Decodes `data` with a cleared `prom_write_request` and checks the sample
    /// count and the per-table rows against the expected values.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &[u8],
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut processor = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .decode(data.to_owned(), PromValidationMode::Strict, &mut processor)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();

        let actual_samples: usize = req
            .ref_all_req()
            .filter_map(|insert| insert.rows.as_ref().map(|rows| rows.rows.len()))
            .sum();
        let actual = RowInsertRequests {
            inserts: req.all_req().collect::<Vec<_>>(),
        };

        assert_eq!(expected_samples, actual_samples);
        assert_eq!(expected_rows.inserts.len(), actual.inserts.len());

        let expected_by_table = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for insert in &actual.inserts {
            let expected = expected_by_table.get(&insert.table_name).unwrap().clone();
            let decoded = insert.rows.clone().unwrap();
            assert_eq!(sort_rows(expected), sort_rows(decoded));
        }
    }

    #[test]
    fn test_decode_write_request() {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("benches")
            .join("write_request.pb.data");
        let data = std::fs::read(path).unwrap();

        let reference = WriteRequest::decode(&data[..]).unwrap();
        let (expected_rows, expected_samples) = to_grpc_row_insert_requests(&reference).unwrap();

        // Decode repeatedly with the same instance to exercise reuse.
        let mut prom_write_request = PromWriteRequest::default();
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Strict.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");
        let strict_result = PromValidationMode::Strict.decode_string(&ascii).unwrap();
        let lossy_result = PromValidationMode::Lossy.decode_string(&ascii).unwrap();
        let unchecked_result = PromValidationMode::Unchecked.decode_string(&ascii).unwrap();
        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
    }
}