servers/grpc/
memory_limit.rs1use std::task::{Context, Poll};
16
17use futures::future::BoxFuture;
18use tonic::server::NamedService;
19use tower::{Layer, Service};
20
21use crate::request_limiter::RequestMemoryLimiter;
22
23#[derive(Clone)]
24pub struct MemoryLimiterExtensionLayer {
25 limiter: RequestMemoryLimiter,
26}
27
28impl MemoryLimiterExtensionLayer {
29 pub fn new(limiter: RequestMemoryLimiter) -> Self {
30 Self { limiter }
31 }
32}
33
34impl<S> Layer<S> for MemoryLimiterExtensionLayer {
35 type Service = MemoryLimiterExtensionService<S>;
36
37 fn layer(&self, service: S) -> Self::Service {
38 MemoryLimiterExtensionService {
39 inner: service,
40 limiter: self.limiter.clone(),
41 }
42 }
43}
44
45#[derive(Clone)]
46pub struct MemoryLimiterExtensionService<S> {
47 inner: S,
48 limiter: RequestMemoryLimiter,
49}
50
51impl<S: NamedService> NamedService for MemoryLimiterExtensionService<S> {
52 const NAME: &'static str = S::NAME;
53}
54
55impl<S, ReqBody> Service<http::Request<ReqBody>> for MemoryLimiterExtensionService<S>
56where
57 S: Service<http::Request<ReqBody>>,
58 S::Future: Send + 'static,
59{
60 type Response = S::Response;
61 type Error = S::Error;
62 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
63
64 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65 self.inner.poll_ready(cx)
66 }
67
68 fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
69 req.extensions_mut().insert(self.limiter.clone());
70 Box::pin(self.inner.call(req))
71 }
72}