Skip to main content

servers/http/
client_ip.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16
17use axum::body::Body;
18use axum::extract::{ConnectInfo, MatchedPath};
19use axum::http::Request;
20use axum::middleware::Next;
21use axum::response::Response;
22use common_telemetry::warn;
23
24/// Middleware that logs HTTP error responses (4xx/5xx) with client IP address.
25///
26/// Extracts client address from [`ConnectInfo`] if available.
27pub async fn log_error_with_client_ip(req: Request<Body>, next: Next) -> Response {
28    let request_info = if is_public_http_api_path(req.uri().path()) {
29        req.extensions()
30            .get::<ConnectInfo<SocketAddr>>()
31            .map(|c| c.0)
32            .map(|addr| {
33                let method = req.method().clone();
34                let uri = req.uri().clone();
35                let matched_path = req.extensions().get::<MatchedPath>().cloned();
36                (addr, method, uri, matched_path)
37            })
38    } else {
39        None
40    };
41
42    let response = next.run(req).await;
43
44    if (response.status().is_client_error() || response.status().is_server_error())
45        && let Some((addr, method, uri, matched_path)) = request_info
46    {
47        warn!(
48            "HTTP error response {} for {} {} (matched: {}) from client {}",
49            response.status(),
50            method,
51            uri,
52            matched_path
53                .as_ref()
54                .map(|p| p.as_str())
55                .unwrap_or("<unknown>"),
56            addr
57        );
58    }
59
60    response
61}
62
63fn is_public_http_api_path(path: &str) -> bool {
64    path == super::HTTP_API_PREFIX_WITHOUT_TRAILING_SLASH
65        || path.starts_with(super::HTTP_API_PREFIX)
66}
67
#[cfg(test)]
mod tests {
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use tower::ServiceExt;

    use super::*;

    #[test]
    fn test_public_http_api_path_matches_v1_prefix() {
        assert!(is_public_http_api_path("/v1"));
        assert!(is_public_http_api_path("/v1/"));
        assert!(is_public_http_api_path("/v1/sql"));
        assert!(is_public_http_api_path("/v1/prometheus/api/v1/query"));

        assert!(!is_public_http_api_path("/"));
        assert!(!is_public_http_api_path("/health"));
        assert!(!is_public_http_api_path("/status"));
        assert!(!is_public_http_api_path("/metrics"));
        assert!(!is_public_http_api_path("/v10/sql"));
        // Prefix must be followed by a slash, and must be at the start of the path.
        assert!(!is_public_http_api_path("/v1sql"));
        assert!(!is_public_http_api_path("/api/v1/sql"));
    }

    #[tokio::test]
    async fn test_middleware_passes_error_response() {
        async fn not_found_handler() -> StatusCode {
            StatusCode::NOT_FOUND
        }

        let app = Router::new()
            .route("/not-found", get(not_found_handler))
            .layer(axum::middleware::from_fn(log_error_with_client_ip));

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/not-found")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    /// Exercises the logging branch: public API path, client address present,
    /// error status, and a query string that must not break request handling.
    #[tokio::test]
    async fn test_middleware_passes_error_response_on_public_api_path() {
        async fn not_found_handler() -> StatusCode {
            StatusCode::NOT_FOUND
        }

        let app = Router::new()
            .route("/v1/not-found", get(not_found_handler))
            .layer(axum::middleware::from_fn(log_error_with_client_ip));

        let client_addr = SocketAddr::from(([127, 0, 0, 1], 4000));
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/v1/not-found?u=user&p=secret")
                    .extension(ConnectInfo(client_addr))
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_middleware_passes_success_response() {
        async fn ok_handler() -> StatusCode {
            StatusCode::OK
        }

        let app = Router::new()
            .route("/ok", get(ok_handler))
            .layer(axum::middleware::from_fn(log_error_with_client_ip));

        let response = app
            .oneshot(Request::builder().uri("/ok").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }
}