servers/http/
test_helpers.rs

// This file is adapted from https://github.com/tokio-rs/axum/blob/axum-v0.6.20/axum/src/test_helpers/test_client.rs
// and has since been modified for this crate.
2
3//! Axum Test Client
4//!
5//! ```rust
6//! use axum::Router;
7//! use axum::http::StatusCode;
8//! use axum::routing::get;
9//! use crate::servers::http::test_helpers::TestClient;
10//!
11//! let async_block = async {
12//!     // you can replace this Router with your own app
13//!     let app = Router::new().route("/", get(|| async {}));
14//!
15//!     // initiate the TestClient with the previous declared Router
16//!     let client = TestClient::new(app).await;
17//!
//!     let res = client.get("/").send().await;
19//!     assert_eq!(res.status(), StatusCode::OK);
20//! };
21//!
22//! // Create a runtime for executing the async block. This runtime is local
23//! // to the main function and does not require any global setup.
24//! let runtime = tokio::runtime::Builder::new_current_thread()
25//!     .enable_all()
26//!     .build()
27//!     .unwrap();
28//!
29//! // Use the local runtime to block on the async block.
30//! runtime.block_on(async_block);
31//! ```
32
33use std::convert::TryFrom;
34use std::net::SocketAddr;
35
36use axum::Router;
37use bytes::Bytes;
38use common_telemetry::info;
39use http::header::{HeaderName, HeaderValue};
40use http::{Method, StatusCode};
41use tokio::net::TcpListener;
42
43/// Test client to Axum servers.
44pub struct TestClient {
45    client: reqwest::Client,
46    addr: SocketAddr,
47}
48
49impl TestClient {
50    /// Create a new test client.
51    pub async fn new(svc: Router) -> Self {
52        let listener = TcpListener::bind("127.0.0.1:0")
53            .await
54            .expect("Could not bind ephemeral socket");
55        let addr = listener.local_addr().unwrap();
56        info!("Listening on {}", addr);
57
58        tokio::spawn(async move {
59            axum::serve(listener, svc).await.expect("server error");
60        });
61
62        let client = reqwest::Client::builder()
63            .redirect(reqwest::redirect::Policy::none())
64            .build()
65            .unwrap();
66
67        TestClient { client, addr }
68    }
69
70    /// Returns the base URL (http://ip:port) for this TestClient
71    ///
72    /// this is useful when trying to check if Location headers in responses
73    /// are generated correctly as Location contains an absolute URL
74    pub fn base_url(&self) -> String {
75        format!("http://{}", self.addr)
76    }
77
78    /// Create a GET request.
79    pub fn get(&self, url: &str) -> RequestBuilder {
80        common_telemetry::info!("GET {} {}", self.addr, url);
81
82        RequestBuilder {
83            builder: self.client.get(format!("http://{}{}", self.addr, url)),
84        }
85    }
86
87    /// Create a HEAD request.
88    pub fn head(&self, url: &str) -> RequestBuilder {
89        common_telemetry::info!("HEAD {} {}", self.addr, url);
90
91        RequestBuilder {
92            builder: self.client.head(format!("http://{}{}", self.addr, url)),
93        }
94    }
95
96    /// Create a POST request.
97    pub fn post(&self, url: &str) -> RequestBuilder {
98        common_telemetry::info!("POST {} {}", self.addr, url);
99
100        RequestBuilder {
101            builder: self.client.post(format!("http://{}{}", self.addr, url)),
102        }
103    }
104
105    /// Create a PUT request.
106    pub fn put(&self, url: &str) -> RequestBuilder {
107        common_telemetry::info!("PUT {} {}", self.addr, url);
108
109        RequestBuilder {
110            builder: self.client.put(format!("http://{}{}", self.addr, url)),
111        }
112    }
113
114    /// Create a PATCH request.
115    pub fn patch(&self, url: &str) -> RequestBuilder {
116        common_telemetry::info!("PATCH {} {}", self.addr, url);
117
118        RequestBuilder {
119            builder: self.client.patch(format!("http://{}{}", self.addr, url)),
120        }
121    }
122
123    /// Create a DELETE request.
124    pub fn delete(&self, url: &str) -> RequestBuilder {
125        common_telemetry::info!("DELETE {} {}", self.addr, url);
126
127        RequestBuilder {
128            builder: self.client.delete(format!("http://{}{}", self.addr, url)),
129        }
130    }
131
132    /// Options preflight request
133    pub fn options(&self, url: &str) -> RequestBuilder {
134        common_telemetry::info!("OPTIONS {} {}", self.addr, url);
135
136        RequestBuilder {
137            builder: self
138                .client
139                .request(Method::OPTIONS, format!("http://{}{}", self.addr, url)),
140        }
141    }
142}
143
144/// Builder for test requests.
145pub struct RequestBuilder {
146    builder: reqwest::RequestBuilder,
147}
148
149impl RequestBuilder {
150    pub async fn send(self) -> TestResponse {
151        TestResponse {
152            response: self.builder.send().await.unwrap(),
153        }
154    }
155
156    /// Set the request body.
157    pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
158        self.builder = self.builder.body(body);
159        self
160    }
161
162    /// Set the request forms.
163    pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
164        self.builder = self.builder.form(&form);
165        self
166    }
167
168    /// Set the request JSON body.
169    pub fn json<T>(mut self, json: &T) -> Self
170    where
171        T: serde::Serialize,
172    {
173        self.builder = self.builder.json(json);
174        self
175    }
176
177    /// Set a request header.
178    pub fn header<K, V>(mut self, key: K, value: V) -> Self
179    where
180        HeaderName: TryFrom<K>,
181        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
182        HeaderValue: TryFrom<V>,
183        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
184    {
185        self.builder = self.builder.header(key, value);
186
187        self
188    }
189
190    /// Set a request multipart form.
191    pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
192        self.builder = self.builder.multipart(form);
193        self
194    }
195}
196
197/// A wrapper around [`reqwest::Response`] that provides common methods with internal `unwrap()`s.
198///
199/// This is convenient for tests where panics are what you want. For access to
200/// non-panicking versions or the complete `Response` API use `into_inner()` or
201/// `as_ref()`.
202#[derive(Debug)]
203pub struct TestResponse {
204    response: reqwest::Response,
205}
206
207impl TestResponse {
208    /// Get the response body as text.
209    pub async fn text(self) -> String {
210        self.response.text().await.unwrap()
211    }
212
213    /// Get the response body as bytes.
214    pub async fn bytes(self) -> Bytes {
215        self.response.bytes().await.unwrap()
216    }
217
218    /// Get the response body as JSON.
219    pub async fn json<T>(self) -> T
220    where
221        T: serde::de::DeserializeOwned,
222    {
223        self.response.json().await.unwrap()
224    }
225
226    /// Get the response status.
227    pub fn status(&self) -> StatusCode {
228        StatusCode::from_u16(self.response.status().as_u16()).unwrap()
229    }
230
231    /// Get the response headers.
232    pub fn headers(&self) -> http::HeaderMap {
233        self.response.headers().clone()
234    }
235
236    /// Get the response in chunks.
237    pub async fn chunk(&mut self) -> Option<Bytes> {
238        self.response.chunk().await.unwrap()
239    }
240
241    /// Get the response in chunks as text.
242    pub async fn chunk_text(&mut self) -> Option<String> {
243        let chunk = self.chunk().await?;
244        Some(String::from_utf8(chunk.to_vec()).unwrap())
245    }
246
247    /// Get the inner [`reqwest::Response`] for less convenient but more complete access.
248    pub fn into_inner(self) -> reqwest::Response {
249        self.response
250    }
251}
252
253impl AsRef<reqwest::Response> for TestResponse {
254    fn as_ref(&self) -> &reqwest::Response {
255        &self.response
256    }
257}