// common_telemetry/tracing_sampler.rs
use std::collections::HashSet;
16
17use opentelemetry::trace::{
18 Link, SamplingDecision, SamplingResult, SpanKind, TraceContextExt, TraceId, TraceState,
19};
20use opentelemetry::KeyValue;
21use opentelemetry_sdk::trace::{Sampler, ShouldSample};
22use serde::{Deserialize, Serialize};
23
24#[derive(Clone, Debug, Serialize, Deserialize)]
25#[serde(default)]
26pub struct TracingSampleOptions {
27 pub default_ratio: f64,
28 pub rules: Vec<TracingSampleRule>,
29}
30
31impl Default for TracingSampleOptions {
32 fn default() -> Self {
33 Self {
34 default_ratio: 1.0,
35 rules: vec![],
36 }
37 }
38}
39
40#[derive(Clone, Default, Debug, Serialize, Deserialize)]
43#[serde(default)]
44pub struct TracingSampleRule {
45 pub protocol: String,
46 pub request_types: HashSet<String>,
47 pub ratio: f64,
48}
49
50impl TracingSampleRule {
51 pub fn match_rule(&self, protocol: &str, request_type: Option<&str>) -> Option<f64> {
52 if protocol == self.protocol {
53 if self.request_types.is_empty() {
54 Some(self.ratio)
55 } else if let Some(t) = request_type
56 && self.request_types.contains(t)
57 {
58 Some(self.ratio)
59 } else {
60 None
61 }
62 } else {
63 None
64 }
65 }
66}
67
68impl PartialEq for TracingSampleOptions {
69 fn eq(&self, other: &Self) -> bool {
70 self.default_ratio == other.default_ratio && self.rules == other.rules
71 }
72}
73impl PartialEq for TracingSampleRule {
74 fn eq(&self, other: &Self) -> bool {
75 self.protocol == other.protocol
76 && self.request_types == other.request_types
77 && self.ratio == other.ratio
78 }
79}
80
81impl Eq for TracingSampleOptions {}
82impl Eq for TracingSampleRule {}
83
84pub fn create_sampler(opt: &TracingSampleOptions) -> Box<dyn ShouldSample> {
85 if opt.rules.is_empty() {
86 Box::new(Sampler::TraceIdRatioBased(opt.default_ratio))
87 } else {
88 Box::new(opt.clone())
89 }
90}
91
92impl ShouldSample for TracingSampleOptions {
93 fn should_sample(
94 &self,
95 parent_context: Option<&opentelemetry::Context>,
96 trace_id: TraceId,
97 _name: &str,
98 _span_kind: &SpanKind,
99 attributes: &[KeyValue],
100 _links: &[Link],
101 ) -> SamplingResult {
102 let (mut protocol, mut request_type) = (None, None);
103 for kv in attributes {
104 match kv.key.as_str() {
105 "protocol" => protocol = Some(kv.value.as_str()),
106 "request_type" => request_type = Some(kv.value.as_str()),
107 _ => (),
108 }
109 }
110 let ratio = protocol
111 .and_then(|p| {
112 self.rules
113 .iter()
114 .find_map(|rule| rule.match_rule(p.as_ref(), request_type.as_deref()))
115 })
116 .unwrap_or(self.default_ratio);
117 SamplingResult {
118 decision: sample_based_on_probability(ratio, trace_id),
119 attributes: Vec::new(),
121 trace_state: match parent_context {
123 Some(ctx) => ctx.span().span_context().trace_state().clone(),
124 None => TraceState::default(),
125 },
126 }
127 }
128}
129
130fn sample_based_on_probability(prob: f64, trace_id: TraceId) -> SamplingDecision {
134 if prob >= 1.0 {
135 SamplingDecision::RecordAndSample
136 } else {
137 let prob_upper_bound = (prob.max(0.0) * (1u64 << 63) as f64) as u64;
138 let bytes = trace_id.to_bytes();
139 let (_, low) = bytes.split_at(8);
140 let trace_id_low = u64::from_be_bytes(low.try_into().unwrap());
141 let rnd_from_trace_id = trace_id_low >> 1;
142
143 if rnd_from_trace_id < prob_upper_bound {
144 SamplingDecision::RecordAndSample
145 } else {
146 SamplingDecision::Drop
147 }
148 }
149}
150
#[cfg(test)]
mod test {
    use std::collections::HashSet;

    use crate::tracing_sampler::TracingSampleRule;

    /// Builds an `http` rule with ratio 1.0, restricted to the given request types.
    fn http_rule(request_types: &[&str]) -> TracingSampleRule {
        TracingSampleRule {
            protocol: "http".to_string(),
            request_types: request_types
                .iter()
                .map(|t| t.to_string())
                .collect::<HashSet<String>>(),
            ratio: 1.0,
        }
    }

    #[test]
    fn test_rule() {
        // A rule without request types matches every request of its protocol.
        let unrestricted = http_rule(&[]);
        let unrestricted_cases = [
            ("not_http", None, None),
            ("http", None, Some(1.0)),
            ("http", Some("abc"), Some(1.0)),
        ];
        for (protocol, request_type, expected) in unrestricted_cases {
            assert_eq!(unrestricted.match_rule(protocol, request_type), expected);
        }

        // A rule with request types only matches the listed types.
        let restricted = http_rule(&["mysql"]);
        let restricted_cases = [
            ("http", None, None),
            ("http", Some("abc"), None),
            ("http", Some("mysql"), Some(1.0)),
        ];
        for (protocol, request_type, expected) in restricted_cases {
            assert_eq!(restricted.match_rule(protocol, request_type), expected);
        }
    }
}