Skip to main content

servers/http/
client_ip.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::net::SocketAddr;
16
17use axum::body::Body;
18use axum::extract::{ConnectInfo, MatchedPath};
19use axum::http::Request;
20use axum::middleware::Next;
21use axum::response::Response;
22use common_telemetry::warn;
23
24/// Middleware that logs HTTP error responses (4xx/5xx) with client IP address.
25///
26/// Extracts client address from [`ConnectInfo`] if available.
27pub async fn log_error_with_client_ip(req: Request<Body>, next: Next) -> Response {
28    let request_info = req
29        .extensions()
30        .get::<ConnectInfo<SocketAddr>>()
31        .map(|c| c.0)
32        .map(|addr| {
33            let method = req.method().clone();
34            let uri = req.uri().clone();
35            let matched_path = req.extensions().get::<MatchedPath>().cloned();
36            (addr, method, uri, matched_path)
37        });
38
39    let response = next.run(req).await;
40
41    if (response.status().is_client_error() || response.status().is_server_error())
42        && let Some((addr, method, uri, matched_path)) = request_info
43    {
44        warn!(
45            "HTTP error response {} for {} {} (matched: {}) from client {}",
46            response.status(),
47            method,
48            uri,
49            matched_path
50                .as_ref()
51                .map(|p| p.as_str())
52                .unwrap_or("<unknown>"),
53            addr
54        );
55    }
56
57    response
58}
59
#[cfg(test)]
mod tests {
    use axum::Router;
    use axum::routing::get;
    use http::StatusCode;
    use tower::ServiceExt;

    use super::*;

    /// Fixed client address injected as `ConnectInfo` to exercise the logging branch.
    fn test_client_addr() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 54321))
    }

    /// Sends a `GET` for `uri` through `app` and returns the response status.
    ///
    /// When `client` is `Some`, the address is attached as a `ConnectInfo` extension,
    /// which is the extension the middleware reads the client address from.
    async fn send(app: Router, uri: &str, client: Option<SocketAddr>) -> StatusCode {
        let mut builder = Request::builder().uri(uri);
        if let Some(addr) = client {
            builder = builder.extension(ConnectInfo(addr));
        }
        app.oneshot(builder.body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status()
    }

    async fn not_found_handler() -> StatusCode {
        StatusCode::NOT_FOUND
    }

    async fn internal_error_handler() -> StatusCode {
        StatusCode::INTERNAL_SERVER_ERROR
    }

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    #[tokio::test]
    async fn test_middleware_passes_error_response() {
        let app = Router::new()
            .route("/not-found", get(not_found_handler))
            .layer(axum::middleware::from_fn(log_error_with_client_ip));

        assert_eq!(send(app, "/not-found", None).await, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_middleware_passes_success_response() {
        let app = Router::new()
            .route("/ok", get(ok_handler))
            .layer(axum::middleware::from_fn(log_error_with_client_ip));

        assert_eq!(send(app, "/ok", None).await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_middleware_logs_client_error_with_client_ip() {
        // With ConnectInfo present the logging branch runs; the query string
        // (which may contain credentials) must not affect the returned response.
        let app = Router::new()
            .route("/not-found", get(not_found_handler))
            .layer(axum::middleware::from_fn(log_error_with_client_ip));

        let status = send(app, "/not-found?u=user&p=secret", Some(test_client_addr())).await;
        assert_eq!(status, StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_middleware_logs_server_error_with_client_ip() {
        let app = Router::new()
            .route("/boom", get(internal_error_handler))
            .layer(axum::middleware::from_fn(log_error_with_client_ip));

        let status = send(app, "/boom", Some(test_client_addr())).await;
        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn test_middleware_passes_success_response_with_client_ip() {
        let app = Router::new()
            .route("/ok", get(ok_handler))
            .layer(axum::middleware::from_fn(log_error_with_client_ip));

        assert_eq!(send(app, "/ok", Some(test_client_addr())).await, StatusCode::OK);
    }
}