use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use axum::body::Body;
use axum::http::Request;
use axum::response::Response;
use pin_project::pin_project;
use tokio::time::{Instant, Sleep};
use tower::timeout::error::Elapsed;
use tower::{BoxError, Layer, Service};
use crate::http::header::constants::GREPTIME_DB_HEADER_TIMEOUT;
/// Future that drives an inner response future alongside a timer.
///
/// Both fields are structurally pinned so they can be polled in place
/// through the projection generated by `pin_project`.
#[derive(Debug)]
#[pin_project]
pub struct ResponseFuture<T> {
    /// The wrapped future producing the actual response.
    #[pin]
    response: T,
    /// Timer polled after `response` whenever the latter is still pending.
    #[pin]
    sleep: Sleep,
}

impl<T> ResponseFuture<T> {
    /// Pairs the `response` future with the `sleep` timer that bounds it.
    pub(crate) fn new(response: T, sleep: Sleep) -> Self {
        Self { response, sleep }
    }
}
impl<F, T, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<T, E>>,
    E: Into<BoxError>,
{
    type Output = Result<T, BoxError>;

    /// Polls the inner response first and only then the timer.
    ///
    /// A ready response wins (its error, if any, is converted into a
    /// [`BoxError`]); otherwise a fired timer resolves to an [`Elapsed`]
    /// error, and the future stays pending while neither is ready.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();

        // The response gets priority: if it is ready, the timer is not polled.
        if let Poll::Ready(outcome) = projected.response.poll(cx) {
            return Poll::Ready(outcome.map_err(Into::into));
        }

        // Response still pending: report a timeout once the timer completes.
        projected
            .sleep
            .poll(cx)
            .map(|_| Err(Elapsed::new().into()))
    }
}
/// Layer that wraps services in a [`DynamicTimeout`], handing each one the
/// configured default timeout.
#[derive(Debug, Clone)]
pub struct DynamicTimeoutLayer {
    /// Timeout passed on to every service produced by this layer.
    default_timeout: Duration,
}

impl DynamicTimeoutLayer {
    /// Creates a layer whose wrapped services use `default_timeout`.
    pub fn new(default_timeout: Duration) -> Self {
        Self { default_timeout }
    }
}
impl<S> Layer<S> for DynamicTimeoutLayer {
type Service = DynamicTimeout<S>;
fn layer(&self, service: S) -> Self::Service {
DynamicTimeout::new(service, self.default_timeout)
}
}
/// Service wrapper that bounds each call of `inner` with a timeout.
///
/// The timeout is taken from the request when it carries a parsable timeout
/// header and falls back to `default_timeout` otherwise.
#[derive(Debug, Clone)]
pub struct DynamicTimeout<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Timeout applied when the request does not supply a usable one.
    default_timeout: Duration,
}

impl<S> DynamicTimeout<S> {
    /// Wraps `inner`, using `default_timeout` for requests that do not
    /// specify their own timeout.
    pub fn new(inner: S, default_timeout: Duration) -> Self {
        Self {
            inner,
            default_timeout,
        }
    }
}
impl<S> Service<Request<Body>> for DynamicTimeout<S>
where
    S: Service<Request<Body>, Response = Response> + Send + 'static,
    S::Error: Into<BoxError>,
{
    type Response = S::Response;
    type Error = BoxError;
    type Future = ResponseFuture<S::Future>;

    /// Delegates readiness to the inner service, boxing its error.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    /// Dispatches `request` to the inner service and pairs the resulting
    /// future with a timer.
    ///
    /// The timeout is read from the `GREPTIME_DB_HEADER_TIMEOUT` header and
    /// parsed with `humantime`; a missing, non-string or unparsable header
    /// falls back to the default timeout. A zero timeout arms a timer roughly
    /// thirty years in the future instead of one that fires immediately.
    fn call(&mut self, request: Request<Body>) -> Self::Future {
        let requested = request
            .headers()
            .get(GREPTIME_DB_HEADER_TIMEOUT)
            .and_then(|raw| raw.to_str().ok())
            .and_then(|text| humantime::parse_duration(text).ok());
        let timeout = requested.unwrap_or(self.default_timeout);

        // Start the inner call before creating the timer.
        let response = self.inner.call(request);

        let sleep = if timeout.is_zero() {
            // Zero is mapped to a deadline far in the future.
            tokio::time::sleep_until(Instant::now() + Duration::from_secs(86400 * 365 * 30))
        } else {
            tokio::time::sleep(timeout)
        };
        ResponseFuture::new(response, sleep)
    }
}