1use axum::body::Body;
16use axum::http::Request;
17use axum::middleware::Next;
18use axum::response::Response;
19use common_telemetry::debug;
20use session::context::QueryContext;
21use session::hints::is_reserved_extension_key;
22
23use crate::hint_headers;
24
25pub async fn extract_hints(mut request: Request<Body>, next: Next) -> Response {
26 let hints = hint_headers::extract_hints(request.headers());
27 if let Some(query_ctx) = request.extensions_mut().get_mut::<QueryContext>() {
28 apply_hints(query_ctx, hints);
29 }
30 next.run(request).await
31}
32
33fn apply_hints(query_ctx: &mut QueryContext, hints: Vec<(String, String)>) {
34 for (key, value) in hints {
35 if is_reserved_extension_key(&key) {
36 debug!(
37 key = key.as_str(),
38 "Ignoring reserved external query context extension key"
39 );
40 continue;
41 }
42 query_ctx.set_extension(key, value);
43 }
44}
45
#[cfg(test)]
mod tests {
    use common_query::request::INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY as COMMON_INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY;
    use session::context::{QueryContextBuilder, generate_remote_query_id};
    use session::hints::{
        INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY, REMOTE_QUERY_ID_EXTENSION_KEY,
    };

    use super::apply_hints;

    /// Builds an owned `(key, value)` hint pair, as produced by header extraction.
    fn hint(key: &str, value: &str) -> (String, String) {
        (key.to_string(), value.to_string())
    }

    /// Reserved keys in incoming hints must neither overwrite an existing extension
    /// nor create a new one, while ordinary hints are still applied.
    #[test]
    fn test_apply_hints_ignores_reserved_extension_keys() {
        let trusted_query_id = generate_remote_query_id();
        let mut ctx = QueryContextBuilder::default()
            .set_extension(
                REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
                trusted_query_id.clone(),
            )
            .build();

        let incoming = vec![
            hint(REMOTE_QUERY_ID_EXTENSION_KEY, "spoofed"),
            hint(
                INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY,
                "spoofed-regs",
            ),
            hint("ttl", "7d"),
        ];
        apply_hints(&mut ctx, incoming);

        // The pre-existing reserved value survives the spoof attempt.
        assert_eq!(ctx.remote_query_id(), Some(trusted_query_id.as_str()));
        // A reserved key that was absent stays absent.
        let spoofed_regs = ctx.extension(INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY);
        assert!(spoofed_regs.is_none());
        // Non-reserved hints pass through untouched.
        assert_eq!(ctx.extension("ttl"), Some("7d"));
    }

    /// The session-level reserved key must stay in sync with the `common_query` constant.
    #[test]
    fn test_initial_dyn_filter_registration_key_matches_common_query_constant() {
        let session_key = INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY;
        let common_key = COMMON_INITIAL_REMOTE_DYN_FILTER_REGISTRATIONS_EXTENSION_KEY;
        assert_eq!(session_key, common_key);
    }
}