// servers/grpc/memory_limit.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::task::{Context, Poll};
16
17use futures::future::BoxFuture;
18use tonic::server::NamedService;
19use tower::{Layer, Service};
20
21use crate::request_limiter::RequestMemoryLimiter;
22
23#[derive(Clone)]
24pub struct MemoryLimiterExtensionLayer {
25    limiter: RequestMemoryLimiter,
26}
27
28impl MemoryLimiterExtensionLayer {
29    pub fn new(limiter: RequestMemoryLimiter) -> Self {
30        Self { limiter }
31    }
32}
33
34impl<S> Layer<S> for MemoryLimiterExtensionLayer {
35    type Service = MemoryLimiterExtensionService<S>;
36
37    fn layer(&self, service: S) -> Self::Service {
38        MemoryLimiterExtensionService {
39            inner: service,
40            limiter: self.limiter.clone(),
41        }
42    }
43}
44
45#[derive(Clone)]
46pub struct MemoryLimiterExtensionService<S> {
47    inner: S,
48    limiter: RequestMemoryLimiter,
49}
50
51impl<S: NamedService> NamedService for MemoryLimiterExtensionService<S> {
52    const NAME: &'static str = S::NAME;
53}
54
55impl<S, ReqBody> Service<http::Request<ReqBody>> for MemoryLimiterExtensionService<S>
56where
57    S: Service<http::Request<ReqBody>>,
58    S::Future: Send + 'static,
59{
60    type Response = S::Response;
61    type Error = S::Error;
62    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
63
64    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65        self.inner.poll_ready(cx)
66    }
67
68    fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
69        req.extensions_mut().insert(self.limiter.clone());
70        Box::pin(self.inner.call(req))
71    }
72}