servers/http/test_helpers.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
// This file is copied from https://github.com/tokio-rs/axum/blob/axum-v0.6.20/axum/src/test_helpers/test_client.rs
//! Axum Test Client
//!
//! ```rust
//! use axum::Router;
//! use axum::http::StatusCode;
//! use axum::routing::get;
//! use crate::servers::http::test_helpers::TestClient;
//!
//! let async_block = async {
//! // you can replace this Router with your own app
//! let app = Router::new().route("/", get(|| async {}));
//!
//! // initialize the TestClient with the previously declared Router
//! let client = TestClient::new(app).await;
//!
//! let res = client.get("/").await;
//! assert_eq!(res.status(), StatusCode::OK);
//! };
//!
//! // Create a runtime for executing the async block. This runtime is local
//! // to the main function and does not require any global setup.
//! let runtime = tokio::runtime::Builder::new_current_thread()
//! .enable_all()
//! .build()
//! .unwrap();
//!
//! // Use the local runtime to block on the async block.
//! runtime.block_on(async_block);
//! ```
use std::convert::TryFrom;
use std::net::SocketAddr;
use axum::Router;
use bytes::Bytes;
use common_telemetry::info;
use http::header::{HeaderName, HeaderValue};
use http::{Method, StatusCode};
use tokio::net::TcpListener;
/// Test client to Axum servers.
///
/// Pairs a `reqwest` client with the socket address that every request URL is
/// built from.
pub struct TestClient {
/// HTTP client used to build each outgoing request.
client: reqwest::Client,
/// Socket address interpolated into request URLs as `http://{addr}{url}`.
addr: SocketAddr,
}
impl TestClient {
/// Create a new test client.
pub async fn new(svc: Router) -> Self {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("Could not bind ephemeral socket");
let addr = listener.local_addr().unwrap();
info!("Listening on {}", addr);
tokio::spawn(async move {
axum::serve(listener, svc).await.expect("server error");
});
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();
TestClient { client, addr }
}
/// Returns the base URL (http://ip:port) for this TestClient
///
/// this is useful when trying to check if Location headers in responses
/// are generated correctly as Location contains an absolute URL
pub fn base_url(&self) -> String {
format!("http://{}", self.addr)
}
/// Create a GET request.
pub fn get(&self, url: &str) -> RequestBuilder {
common_telemetry::info!("GET {} {}", self.addr, url);
RequestBuilder {
builder: self.client.get(format!("http://{}{}", self.addr, url)),
}
}
/// Create a HEAD request.
pub fn head(&self, url: &str) -> RequestBuilder {
common_telemetry::info!("HEAD {} {}", self.addr, url);
RequestBuilder {
builder: self.client.head(format!("http://{}{}", self.addr, url)),
}
}
/// Create a POST request.
pub fn post(&self, url: &str) -> RequestBuilder {
common_telemetry::info!("POST {} {}", self.addr, url);
RequestBuilder {
builder: self.client.post(format!("http://{}{}", self.addr, url)),
}
}
/// Create a PUT request.
pub fn put(&self, url: &str) -> RequestBuilder {
common_telemetry::info!("PUT {} {}", self.addr, url);
RequestBuilder {
builder: self.client.put(format!("http://{}{}", self.addr, url)),
}
}
/// Create a PATCH request.
pub fn patch(&self, url: &str) -> RequestBuilder {
common_telemetry::info!("PATCH {} {}", self.addr, url);
RequestBuilder {
builder: self.client.patch(format!("http://{}{}", self.addr, url)),
}
}
/// Create a DELETE request.
pub fn delete(&self, url: &str) -> RequestBuilder {
common_telemetry::info!("DELETE {} {}", self.addr, url);
RequestBuilder {
builder: self.client.delete(format!("http://{}{}", self.addr, url)),
}
}
/// Options preflight request
pub fn options(&self, url: &str) -> RequestBuilder {
common_telemetry::info!("OPTIONS {} {}", self.addr, url);
RequestBuilder {
builder: self
.client
.request(Method::OPTIONS, format!("http://{}{}", self.addr, url)),
}
}
}
/// Builder for test requests.
///
/// Thin wrapper around [`reqwest::RequestBuilder`]; each setter forwards to the
/// wrapped builder and `send` yields a [`TestResponse`].
pub struct RequestBuilder {
/// The wrapped `reqwest` builder that accumulates the request.
builder: reqwest::RequestBuilder,
}
impl RequestBuilder {
    /// Send the request and wrap the result in a [`TestResponse`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` send returns an error.
    pub async fn send(self) -> TestResponse {
        TestResponse {
            response: self.builder.send().await.unwrap(),
        }
    }

    /// Set the request body.
    pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
        self.builder = self.builder.body(body);
        self
    }

    /// Set the request form body.
    pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
        // `form` is already a reference; forward it as-is instead of borrowing
        // it a second time.
        self.builder = self.builder.form(form);
        self
    }

    /// Set the request JSON body.
    ///
    /// Accepts unsized values as well, matching [`RequestBuilder::form`].
    pub fn json<T>(mut self, json: &T) -> Self
    where
        T: serde::Serialize + ?Sized,
    {
        self.builder = self.builder.json(json);
        self
    }

    /// Set a request header.
    ///
    /// `key` and `value` may be any types convertible into a [`HeaderName`]
    /// and a [`HeaderValue`] respectively.
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        HeaderValue: TryFrom<V>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
    {
        self.builder = self.builder.header(key, value);
        self
    }

    /// Set a request multipart form.
    pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
        self.builder = self.builder.multipart(form);
        self
    }
}
/// A wrapper around [`reqwest::Response`] that provides common methods with internal `unwrap()`s.
///
/// This is convenient for tests where panics are what you want. For access to
/// non-panicking versions or the complete `Response` API use `into_inner()` or
/// `as_ref()`.
#[derive(Debug)]
pub struct TestResponse {
/// The wrapped `reqwest` response.
response: reqwest::Response,
}
impl TestResponse {
    /// Consume the response and return its body as text.
    ///
    /// # Panics
    ///
    /// Panics if `reqwest` returns an error while producing the text.
    pub async fn text(self) -> String {
        let Self { response } = self;
        response.text().await.unwrap()
    }

    /// Consume the response and return its body as raw bytes.
    ///
    /// # Panics
    ///
    /// Panics if `reqwest` returns an error while producing the bytes.
    pub async fn bytes(self) -> Bytes {
        let Self { response } = self;
        response.bytes().await.unwrap()
    }

    /// Consume the response and deserialize its body from JSON into `T`.
    ///
    /// # Panics
    ///
    /// Panics if `reqwest` returns an error while producing the value.
    pub async fn json<T>(self) -> T
    where
        T: serde::de::DeserializeOwned,
    {
        let Self { response } = self;
        response.json::<T>().await.unwrap()
    }

    /// Return the response status code.
    ///
    /// The status is rebuilt from its numeric value via `StatusCode::from_u16`.
    pub fn status(&self) -> StatusCode {
        let code = self.response.status().as_u16();
        StatusCode::from_u16(code).unwrap()
    }

    /// Return a copy of the response headers.
    pub fn headers(&self) -> http::HeaderMap {
        let headers = self.response.headers();
        headers.clone()
    }

    /// Fetch the next chunk of the response body.
    ///
    /// The `Option` is whatever `reqwest::Response::chunk` yields.
    ///
    /// # Panics
    ///
    /// Panics if `reqwest` returns an error while fetching the chunk.
    pub async fn chunk(&mut self) -> Option<Bytes> {
        let next = self.response.chunk().await;
        next.unwrap()
    }

    /// Fetch the next chunk of the response body, decoded as UTF-8 text.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`TestResponse::chunk`], or if the
    /// chunk's bytes are not valid UTF-8.
    pub async fn chunk_text(&mut self) -> Option<String> {
        self.chunk()
            .await
            .map(|bytes| String::from_utf8(bytes.to_vec()).unwrap())
    }

    /// Unwrap into the inner [`reqwest::Response`] for less convenient but more complete access.
    pub fn into_inner(self) -> reqwest::Response {
        let Self { response } = self;
        response
    }
}
/// Borrowing access to the wrapped [`reqwest::Response`].
impl AsRef<reqwest::Response> for TestResponse {
    /// Return a shared reference to the inner response without consuming `self`.
    fn as_ref(&self) -> &reqwest::Response {
        let Self { response } = self;
        response
    }
}