servers/prom_remote_write/
decode.rs1use std::collections::BTreeMap;
18use std::collections::btree_map::Entry;
19
20use api::prom_store::remote::Sample;
21use bytes::Buf;
22use common_query::prelude::{greptime_timestamp, greptime_value};
23use pipeline::{ContextReq, GreptimePipelineParams, PipelineContext, PipelineDefinition};
24use prost::DecodeError;
25use prost::encoding::message::merge;
26use prost::encoding::{WireType, decode_key, decode_varint};
27use session::context::QueryContextRef;
28use snafu::OptionExt;
29use vrl::prelude::NotNan;
30use vrl::value::{KeyString, Value as VrlValue};
31
32use crate::error::InternalSnafu;
33use crate::http::event::PipelineIngestRequest;
34use crate::pipeline::run_pipeline;
35use crate::prom_remote_write::row_builder::{PromCtx, TablesBuilder};
36use crate::prom_remote_write::types::PromLabel;
37use crate::prom_remote_write::validation::PromValidationMode;
38#[allow(deprecated)]
39use crate::prom_store::{
40 DATABASE_LABEL_ALT_BYTES, DATABASE_LABEL_BYTES, METRIC_NAME_LABEL_BYTES,
41 PHYSICAL_TABLE_LABEL_ALT_BYTES, PHYSICAL_TABLE_LABEL_BYTES, SCHEMA_LABEL_BYTES,
42};
43use crate::query_handler::PipelineHandlerRef;
44use crate::repeated_field::{Clear, RepeatedField};
45
46#[derive(Default, Debug)]
47pub(crate) struct PromTimeSeries {
48 pub(crate) table_name: String,
49 pub(crate) schema: Option<String>,
50 pub(crate) physical_table: Option<String>,
51
52 pub(crate) labels: RepeatedField<PromLabel>,
53 pub(crate) samples: RepeatedField<Sample>,
54}
55
56impl Clear for PromTimeSeries {
57 fn clear(&mut self) {
58 self.table_name.clear();
59 self.labels.clear();
60 self.samples.clear();
61 }
62}
63
64impl PromTimeSeries {
65 pub fn merge_field(
66 &mut self,
67 tag: u32,
68 wire_type: WireType,
69 buf: &mut &[u8],
70 prom_validation_mode: PromValidationMode,
71 ) -> Result<(), DecodeError> {
72 const STRUCT_NAME: &str = "PromTimeSeries";
73 match tag {
74 1u32 => {
75 let label = self.labels.push_default();
76
77 let len = decode_varint(buf).map_err(|mut error| {
78 error.push(STRUCT_NAME, "labels");
79 error
80 })?;
81 let remaining = buf.remaining();
82 if len > remaining as u64 {
83 return Err(DecodeError::new("buffer underflow"));
84 }
85
86 let limit = remaining - len as usize;
87 while buf.remaining() > limit {
88 let (tag, wire_type) = decode_key(buf)?;
89 label.merge_field(tag, wire_type, buf)?;
90 }
91 if buf.remaining() != limit {
92 return Err(DecodeError::new("delimited length exceeded"));
93 }
94
95 #[allow(deprecated)]
96 match label.name {
97 METRIC_NAME_LABEL_BYTES => {
98 self.table_name = prom_validation_mode.decode_string(label.value)?;
99 self.labels.truncate(self.labels.len() - 1);
100 }
101 SCHEMA_LABEL_BYTES => {
102 self.schema = Some(prom_validation_mode.decode_string(label.value)?);
103 self.labels.truncate(self.labels.len() - 1);
104 }
105 DATABASE_LABEL_BYTES | DATABASE_LABEL_ALT_BYTES => {
106 if self.schema.is_none() {
107 self.schema = Some(prom_validation_mode.decode_string(label.value)?);
108 }
109 self.labels.truncate(self.labels.len() - 1);
110 }
111 PHYSICAL_TABLE_LABEL_BYTES | PHYSICAL_TABLE_LABEL_ALT_BYTES => {
112 self.physical_table =
113 Some(prom_validation_mode.decode_string(label.value)?);
114 self.labels.truncate(self.labels.len() - 1);
115 }
116 _ => {}
117 }
118
119 Ok(())
120 }
121 2u32 => {
122 let sample = self.samples.push_default();
123 merge(WireType::LengthDelimited, sample, buf, Default::default()).map_err(
124 |mut error| {
125 error.push(STRUCT_NAME, "samples");
126 error
127 },
128 )?;
129 Ok(())
130 }
131 3u32 => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
132 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default()),
133 }
134 }
135
136 fn add_to_table_data<'a>(
137 &mut self,
138 table_builders: &mut TablesBuilder<'a>,
139 prom_validation_mode: PromValidationMode,
140 ) -> Result<(), DecodeError> {
141 let label_num = self.labels.len();
142 let row_num = self.samples.len();
143
144 let prom_ctx = PromCtx {
145 schema: self.schema.take(),
146 physical_table: self.physical_table.take(),
147 };
148
149 let table_data = table_builders.get_or_create_table_builder(
150 prom_ctx,
151 std::mem::take(&mut self.table_name),
152 label_num,
153 row_num,
154 );
155 table_data.add_labels_and_samples(
156 self.labels.as_slice(),
157 self.samples.as_slice(),
158 prom_validation_mode,
159 )?;
160
161 Ok(())
162 }
163}
164
165#[derive(Default, Debug)]
166pub struct PromWriteRequest<'a> {
167 pub(crate) table_data: TablesBuilder<'a>,
168 series: PromTimeSeries,
169}
170
171impl<'a> Clear for PromWriteRequest<'a> {
172 fn clear(&mut self) {
173 self.table_data.clear();
174 }
175}
176
177impl<'a> PromWriteRequest<'a> {
178 pub fn as_row_insert_requests(&mut self) -> ContextReq {
179 self.table_data.as_insert_requests()
180 }
181
182 pub fn decode(
183 &mut self,
184 buf: Vec<u8>,
185 prom_validation_mode: PromValidationMode,
186 processor: &mut PromSeriesProcessor,
187 ) -> Result<(), DecodeError> {
188 const STRUCT_NAME: &str = "PromWriteRequest";
189 self.table_data.set_raw_data(buf);
190 let mut offset = 0;
191 while offset < self.table_data.raw_data.len() {
192 let mut should_add_to_table_data = false;
193 let mut decoded_timeseries = false;
194 {
195 let raw_data = &self.table_data.raw_data;
196 let buf = &mut &raw_data[offset..];
197 let (tag, wire_type) = decode_key(buf)?;
198 if wire_type != WireType::LengthDelimited {
199 return Err(DecodeError::new(format!(
200 "invalid wire type: {:?}",
201 wire_type
202 )));
203 }
204 match tag {
205 1u32 => {
206 let len = decode_varint(buf).map_err(|mut e| {
207 e.push(STRUCT_NAME, "timeseries");
208 e
209 })?;
210 let remaining = buf.remaining();
211 if len > remaining as u64 {
212 return Err(DecodeError::new("buffer underflow"));
213 }
214
215 let limit = remaining - len as usize;
216 while buf.remaining() > limit {
217 let (tag, wire_type) = decode_key(buf)?;
218 self.series
219 .merge_field(tag, wire_type, buf, prom_validation_mode)?;
220 }
221 if buf.remaining() != limit {
222 return Err(DecodeError::new("delimited length exceeded"));
223 }
224
225 if processor.use_pipeline {
226 processor.consume_series_to_pipeline_map(
227 &mut self.series,
228 prom_validation_mode,
229 )?;
230 } else {
231 should_add_to_table_data = true;
232 }
233
234 decoded_timeseries = true;
235 }
236 3u32 => {
237 prost::encoding::skip_field(wire_type, tag, buf, Default::default())?;
238 }
239 _ => prost::encoding::skip_field(wire_type, tag, buf, Default::default())?,
240 }
241 offset = raw_data.len() - buf.remaining();
242 }
243
244 if should_add_to_table_data {
245 self.series
246 .add_to_table_data(&mut self.table_data, prom_validation_mode)?;
247 }
248
249 if decoded_timeseries {
250 self.series.labels.clear();
251 self.series.samples.clear();
252 }
253 }
254
255 Ok(())
256 }
257}
258
259pub struct PromSeriesProcessor {
261 pub(crate) use_pipeline: bool,
262 pub(crate) table_values: BTreeMap<String, Vec<VrlValue>>,
263
264 pub(crate) pipeline_handler: Option<PipelineHandlerRef>,
265 pub(crate) query_ctx: Option<QueryContextRef>,
266 pub(crate) pipeline_def: Option<PipelineDefinition>,
267}
268
269impl PromSeriesProcessor {
270 pub fn default_processor() -> Self {
271 Self {
272 use_pipeline: false,
273 table_values: BTreeMap::new(),
274 pipeline_handler: None,
275 query_ctx: None,
276 pipeline_def: None,
277 }
278 }
279
280 pub fn set_pipeline(
281 &mut self,
282 handler: PipelineHandlerRef,
283 query_ctx: QueryContextRef,
284 pipeline_def: PipelineDefinition,
285 ) {
286 self.use_pipeline = true;
287 self.pipeline_handler = Some(handler);
288 self.query_ctx = Some(query_ctx);
289 self.pipeline_def = Some(pipeline_def);
290 }
291
292 pub(crate) fn consume_series_to_pipeline_map(
293 &mut self,
294 series: &mut PromTimeSeries,
295 prom_validation_mode: PromValidationMode,
296 ) -> Result<(), DecodeError> {
297 let mut vec_pipeline_map = Vec::new();
298 let mut pipeline_map = BTreeMap::new();
299 for l in series.labels.iter() {
300 let name = prom_validation_mode.decode_label_name(l.name)?;
301 let value = prom_validation_mode.decode_string(l.value)?;
302 pipeline_map.insert(KeyString::from(name), VrlValue::Bytes(value.into()));
303 }
304
305 let one_sample = series.samples.len() == 1;
306
307 for s in series.samples.iter() {
308 let Ok(value) = NotNan::new(s.value) else {
309 common_telemetry::warn!("Invalid float value: {}", s.value);
310 continue;
311 };
312
313 let timestamp = s.timestamp;
314 pipeline_map.insert(
315 KeyString::from(greptime_timestamp()),
316 VrlValue::Integer(timestamp),
317 );
318 pipeline_map.insert(KeyString::from(greptime_value()), VrlValue::Float(value));
319 if one_sample {
320 vec_pipeline_map.push(VrlValue::Object(pipeline_map));
321 break;
322 } else {
323 vec_pipeline_map.push(VrlValue::Object(pipeline_map.clone()));
324 }
325 }
326
327 let table_name = std::mem::take(&mut series.table_name);
328 match self.table_values.entry(table_name) {
329 Entry::Occupied(mut occupied_entry) => {
330 occupied_entry.get_mut().append(&mut vec_pipeline_map);
331 }
332 Entry::Vacant(vacant_entry) => {
333 vacant_entry.insert(vec_pipeline_map);
334 }
335 }
336
337 Ok(())
338 }
339
340 pub(crate) async fn exec_pipeline(&mut self) -> crate::error::Result<ContextReq> {
341 let handler = self.pipeline_handler.as_ref().context(InternalSnafu {
342 err_msg: "pipeline handler is not set",
343 })?;
344 let pipeline_def = self.pipeline_def.as_ref().context(InternalSnafu {
345 err_msg: "pipeline definition is not set",
346 })?;
347 let pipeline_param = GreptimePipelineParams::default();
348 let query_ctx = self.query_ctx.as_ref().context(InternalSnafu {
349 err_msg: "query context is not set",
350 })?;
351
352 let pipeline_ctx = PipelineContext::new(pipeline_def, &pipeline_param, query_ctx.channel());
353
354 let mut req = ContextReq::default();
355 let table_values = std::mem::take(&mut self.table_values);
356 for (table_name, pipeline_maps) in table_values.into_iter() {
357 let pipeline_req = PipelineIngestRequest {
358 table: table_name,
359 values: pipeline_maps,
360 };
361 let row_req =
362 run_pipeline(handler, &pipeline_ctx, pipeline_req, query_ctx, true).await?;
363 req.merge(row_req);
364 }
365
366 Ok(req)
367 }
368}
369
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use api::prom_store::remote::WriteRequest;
    use api::v1::{Row, RowInsertRequests, Rows};
    use bytes::Bytes;
    use prost::Message;

    use super::*;
    use crate::prom_store::to_grpc_row_insert_requests;
    use crate::repeated_field::Clear;

    /// Reorders the columns of `rows` by column name so that two row sets can
    /// be compared regardless of column order.
    fn sort_rows(rows: Rows) -> Rows {
        let permutation =
            permutation::sort_by_key(&rows.schema, |schema| schema.column_name.clone());
        let schema = permutation.apply_slice(&rows.schema);
        let inner_rows = rows
            .rows
            .into_iter()
            .map(|row| Row {
                values: permutation.apply_slice(&row.values),
            })
            .collect();
        Rows {
            schema,
            rows: inner_rows,
        }
    }

    /// Decodes `data` with the hand-written decoder and checks that the sample
    /// count and the per-table rows match the expected values.
    fn check_deserialized(
        prom_write_request: &mut PromWriteRequest,
        data: &[u8],
        expected_samples: usize,
        expected_rows: &RowInsertRequests,
    ) {
        let mut processor = PromSeriesProcessor::default_processor();
        prom_write_request.clear();
        prom_write_request
            .decode(data.to_owned(), PromValidationMode::Strict, &mut processor)
            .unwrap();

        let req = prom_write_request.as_row_insert_requests();

        let samples = req
            .ref_all_req()
            .filter_map(|r| r.rows.as_ref().map(|r| r.rows.len()))
            .sum::<usize>();
        let prom_rows = RowInsertRequests {
            inserts: req.all_req().collect::<Vec<_>>(),
        };

        assert_eq!(expected_samples, samples);
        assert_eq!(expected_rows.inserts.len(), prom_rows.inserts.len());

        let expected_by_table = expected_rows
            .inserts
            .iter()
            .map(|insert| (insert.table_name.clone(), insert.rows.clone().unwrap()))
            .collect::<HashMap<_, _>>();

        for actual in &prom_rows.inserts {
            let expected = expected_by_table.get(&actual.table_name).unwrap().clone();
            assert_eq!(sort_rows(expected), sort_rows(actual.rows.clone().unwrap()));
        }
    }

    #[test]
    fn test_decode_write_request() {
        let path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("benches")
            .join("write_request.pb.data");
        let data = std::fs::read(path).unwrap();

        let (expected_rows, expected_samples) =
            to_grpc_row_insert_requests(&WriteRequest::decode(&data[..]).unwrap()).unwrap();

        // Decode several times with the same request to exercise reuse.
        let mut prom_write_request = PromWriteRequest::default();
        for _ in 0..3 {
            check_deserialized(
                &mut prom_write_request,
                &data,
                expected_samples,
                &expected_rows,
            );
        }
    }

    #[test]
    fn test_decode_string_strict_mode_valid_utf8() {
        let valid_utf8 = Bytes::from("hello world");
        let result = PromValidationMode::Strict.decode_string(&valid_utf8);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello world");
    }

    #[test]
    fn test_decode_string_all_modes_ascii() {
        let ascii = Bytes::from("simple_ascii_123");
        let strict_result = PromValidationMode::Strict.decode_string(&ascii).unwrap();
        let lossy_result = PromValidationMode::Lossy.decode_string(&ascii).unwrap();
        let unchecked_result = PromValidationMode::Unchecked.decode_string(&ascii).unwrap();
        assert_eq!(strict_result, "simple_ascii_123");
        assert_eq!(lossy_result, "simple_ascii_123");
        assert_eq!(unchecked_result, "simple_ascii_123");
    }
}