common_telemetry/
tracing_sampler.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use opentelemetry::trace::{
18    Link, SamplingDecision, SamplingResult, SpanKind, TraceContextExt, TraceId, TraceState,
19};
20use opentelemetry::KeyValue;
21use opentelemetry_sdk::trace::{Sampler, ShouldSample};
22use serde::{Deserialize, Serialize};
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
25#[serde(default)]
26pub struct TracingSampleOptions {
27    pub default_ratio: f64,
28    pub rules: Vec<TracingSampleRule>,
29}
30
31impl Default for TracingSampleOptions {
32    fn default() -> Self {
33        Self {
34            default_ratio: 1.0,
35            rules: vec![],
36        }
37    }
38}
39
40/// Determine the sampling rate of a span according to the `rules` provided in `RuleSampler`.
41/// For spans that do not hit any `rules`, the `default_ratio` is used.
42#[derive(Clone, Default, Debug, Serialize, Deserialize)]
43#[serde(default)]
44pub struct TracingSampleRule {
45    pub protocol: String,
46    pub request_types: HashSet<String>,
47    pub ratio: f64,
48}
49
50impl TracingSampleRule {
51    pub fn match_rule(&self, protocol: &str, request_type: Option<&str>) -> Option<f64> {
52        if protocol == self.protocol {
53            if self.request_types.is_empty() {
54                Some(self.ratio)
55            } else if let Some(t) = request_type
56                && self.request_types.contains(t)
57            {
58                Some(self.ratio)
59            } else {
60                None
61            }
62        } else {
63            None
64        }
65    }
66}
67
68impl PartialEq for TracingSampleOptions {
69    fn eq(&self, other: &Self) -> bool {
70        self.default_ratio == other.default_ratio && self.rules == other.rules
71    }
72}
73impl PartialEq for TracingSampleRule {
74    fn eq(&self, other: &Self) -> bool {
75        self.protocol == other.protocol
76            && self.request_types == other.request_types
77            && self.ratio == other.ratio
78    }
79}
80
// Marker impls declaring the `PartialEq` implementations above to be full
// equivalence relations.
// NOTE(review): both types hold `f64` ratios, and `f64` itself is not `Eq`
// because NaN != NaN. These impls only uphold the `Eq` contract if ratios are
// never NaN — confirm that configuration loading guarantees this.
impl Eq for TracingSampleOptions {}
impl Eq for TracingSampleRule {}
83
84pub fn create_sampler(opt: &TracingSampleOptions) -> Box<dyn ShouldSample> {
85    if opt.rules.is_empty() {
86        Box::new(Sampler::TraceIdRatioBased(opt.default_ratio))
87    } else {
88        Box::new(opt.clone())
89    }
90}
91
92impl ShouldSample for TracingSampleOptions {
93    fn should_sample(
94        &self,
95        parent_context: Option<&opentelemetry::Context>,
96        trace_id: TraceId,
97        _name: &str,
98        _span_kind: &SpanKind,
99        attributes: &[KeyValue],
100        _links: &[Link],
101    ) -> SamplingResult {
102        let (mut protocol, mut request_type) = (None, None);
103        for kv in attributes {
104            match kv.key.as_str() {
105                "protocol" => protocol = Some(kv.value.as_str()),
106                "request_type" => request_type = Some(kv.value.as_str()),
107                _ => (),
108            }
109        }
110        let ratio = protocol
111            .and_then(|p| {
112                self.rules
113                    .iter()
114                    .find_map(|rule| rule.match_rule(p.as_ref(), request_type.as_deref()))
115            })
116            .unwrap_or(self.default_ratio);
117        SamplingResult {
118            decision: sample_based_on_probability(ratio, trace_id),
119            // No extra attributes ever set by the SDK samplers.
120            attributes: Vec::new(),
121            // all sampler in SDK will not modify trace state.
122            trace_state: match parent_context {
123                Some(ctx) => ctx.span().span_context().trace_state().clone(),
124                None => TraceState::default(),
125            },
126        }
127    }
128}
129
130/// The code here mainly refers to the relevant implementation of
131/// [opentelemetry](https://github.com/open-telemetry/opentelemetry-rust/blob/ef4701055cc39d3448d5e5392812ded00cdd4476/opentelemetry-sdk/src/trace/sampler.rs#L229),
132/// and determines whether the span needs to be collected based on the `TraceId` and sampling rate (i.e. `prob`).
133fn sample_based_on_probability(prob: f64, trace_id: TraceId) -> SamplingDecision {
134    if prob >= 1.0 {
135        SamplingDecision::RecordAndSample
136    } else {
137        let prob_upper_bound = (prob.max(0.0) * (1u64 << 63) as f64) as u64;
138        let bytes = trace_id.to_bytes();
139        let (_, low) = bytes.split_at(8);
140        let trace_id_low = u64::from_be_bytes(low.try_into().unwrap());
141        let rnd_from_trace_id = trace_id_low >> 1;
142
143        if rnd_from_trace_id < prob_upper_bound {
144            SamplingDecision::RecordAndSample
145        } else {
146            SamplingDecision::Drop
147        }
148    }
149}
150
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::tracing_sampler::TracingSampleRule;

    /// Builds an `http` rule with ratio `1.0` limited to `request_types`.
    fn http_rule(request_types: &[&str]) -> TracingSampleRule {
        TracingSampleRule {
            protocol: "http".to_string(),
            request_types: request_types
                .iter()
                .map(|t| t.to_string())
                .collect::<HashSet<_>>(),
            ratio: 1.0,
        }
    }

    #[test]
    fn test_rule() {
        // A rule without request types covers the whole protocol.
        let rule = http_rule(&[]);
        let unrestricted_cases = [
            ("not_http", None, None),
            ("http", None, Some(1.0)),
            ("http", Some("abc"), Some(1.0)),
        ];
        for (protocol, request_type, expected) in unrestricted_cases {
            assert_eq!(rule.match_rule(protocol, request_type), expected);
        }

        // A rule with request types only matches the listed ones.
        let rule1 = http_rule(&["mysql"]);
        let restricted_cases = [
            ("http", None, None),
            ("http", Some("abc"), None),
            ("http", Some("mysql"), Some(1.0)),
        ];
        for (protocol, request_type, expected) in restricted_cases {
            assert_eq!(rule1.match_rule(protocol, request_type), expected);
        }
    }
}