common_grpc/
channel_manager.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use common_base::readable_size::ReadableSize;
20use common_telemetry::info;
21use dashmap::mapref::entry::Entry;
22use dashmap::DashMap;
23use lazy_static::lazy_static;
24use snafu::{OptionExt, ResultExt};
25use tokio_util::sync::CancellationToken;
26use tonic::transport::{
27    Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
28};
29use tower::Service;
30
31use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, InvalidTlsConfigSnafu, Result};
32
/// Period, in seconds, handed to the background channel recycle task as its sweep interval.
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
/// Per-request timeout, in seconds, used by the default `ChannelConfig`.
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
/// Connect timeout, in seconds, used by the default `ChannelConfig`.
pub const DEFAULT_GRPC_CONNECT_TIMEOUT_SECS: u64 = 1;
/// Value of `ChannelConfig::max_recv_message_size` in the default configuration.
pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
/// Value of `ChannelConfig::max_send_message_size` in the default configuration.
pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
38
// Process-wide counter: every `Inner` built through `Inner::with_config` takes the next
// value as its id. Within this file the id only appears in log messages.
lazy_static! {
    static ref ID: AtomicU64 = AtomicU64::new(0);
}
42
/// Hands out gRPC channels from a pool keyed by address.
///
/// The manager is a thin handle around shared state held in an `Arc`, so cloning it is
/// cheap and every clone works on the same pool.
#[derive(Clone, Debug, Default)]
pub struct ChannelManager {
    // Shared state; dropping the last handle drops it and cancels the recycle task.
    inner: Arc<Inner>,
}
47
/// State shared by every clone of a `ChannelManager`.
#[derive(Debug)]
struct Inner {
    // Identifier taken from the global `ID` counter; used in log messages.
    id: u64,
    // Settings applied to every endpoint this manager builds.
    config: ChannelConfig,
    // TLS settings; `None` means endpoints are built with the plain `http` scheme.
    client_tls_config: Option<ClientTlsConfig>,
    // Channels keyed by address; also shared with the background recycle task.
    pool: Arc<Pool>,
    // Set once the recycle task has been spawned so that it is spawned at most once.
    channel_recycle_started: AtomicBool,
    // Cancelled on drop to stop the recycle task.
    cancel: CancellationToken,
}
57
58impl Default for Inner {
59    fn default() -> Self {
60        Self::with_config(ChannelConfig::default())
61    }
62}
63
64impl Drop for Inner {
65    fn drop(&mut self) {
66        // Cancel the channel recycle task.
67        self.cancel.cancel();
68    }
69}
70
71impl Inner {
72    fn with_config(config: ChannelConfig) -> Self {
73        let id = ID.fetch_add(1, Ordering::Relaxed);
74        let pool = Arc::new(Pool::default());
75        let cancel = CancellationToken::new();
76
77        Self {
78            id,
79            config,
80            client_tls_config: None,
81            pool,
82            channel_recycle_started: AtomicBool::new(false),
83            cancel,
84        }
85    }
86}
87
88impl ChannelManager {
89    pub fn new() -> Self {
90        Default::default()
91    }
92
93    pub fn with_config(config: ChannelConfig) -> Self {
94        let inner = Inner::with_config(config);
95        Self {
96            inner: Arc::new(inner),
97        }
98    }
99
100    pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
101        let mut inner = Inner::with_config(config.clone());
102
103        // setup tls
104        let path_config = config.client_tls.context(InvalidTlsConfigSnafu {
105            msg: "no config input",
106        })?;
107
108        let server_root_ca_cert = std::fs::read_to_string(path_config.server_ca_cert_path)
109            .context(InvalidConfigFilePathSnafu)?;
110        let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
111        let client_cert = std::fs::read_to_string(path_config.client_cert_path)
112            .context(InvalidConfigFilePathSnafu)?;
113        let client_key = std::fs::read_to_string(path_config.client_key_path)
114            .context(InvalidConfigFilePathSnafu)?;
115        let client_identity = Identity::from_pem(client_cert, client_key);
116
117        inner.client_tls_config = Some(
118            ClientTlsConfig::new()
119                .ca_certificate(server_root_ca_cert)
120                .identity(client_identity),
121        );
122
123        Ok(Self {
124            inner: Arc::new(inner),
125        })
126    }
127
128    pub fn config(&self) -> &ChannelConfig {
129        &self.inner.config
130    }
131
132    fn pool(&self) -> &Arc<Pool> {
133        &self.inner.pool
134    }
135
136    pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
137        self.trigger_channel_recycling();
138
139        let addr = addr.as_ref();
140        // It will acquire the read lock.
141        if let Some(inner_ch) = self.pool().get(addr) {
142            return Ok(inner_ch);
143        }
144
145        // It will acquire the write lock.
146        let entry = match self.pool().entry(addr.to_string()) {
147            Entry::Occupied(entry) => {
148                entry.get().increase_access();
149                entry.into_ref()
150            }
151            Entry::Vacant(entry) => {
152                let endpoint = self.build_endpoint(addr)?;
153                let inner_channel = endpoint.connect_lazy();
154
155                let channel = Channel {
156                    channel: inner_channel,
157                    access: AtomicUsize::new(1),
158                    use_default_connector: true,
159                };
160                entry.insert(channel)
161            }
162        };
163        Ok(entry.channel.clone())
164    }
165
166    pub fn reset_with_connector<C>(
167        &self,
168        addr: impl AsRef<str>,
169        connector: C,
170    ) -> Result<InnerChannel>
171    where
172        C: Service<Uri> + Send + 'static,
173        C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin,
174        C::Future: Send + 'static,
175        Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
176    {
177        let addr = addr.as_ref();
178        let endpoint = self.build_endpoint(addr)?;
179        let inner_channel = endpoint.connect_with_connector_lazy(connector);
180        let channel = Channel {
181            channel: inner_channel.clone(),
182            access: AtomicUsize::new(1),
183            use_default_connector: false,
184        };
185        self.pool().put(addr, channel);
186
187        Ok(inner_channel)
188    }
189
190    pub fn retain_channel<F>(&self, f: F)
191    where
192        F: FnMut(&String, &mut Channel) -> bool,
193    {
194        self.pool().retain_channel(f);
195    }
196
197    fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
198        let http_prefix = if self.inner.client_tls_config.is_some() {
199            "https"
200        } else {
201            "http"
202        };
203
204        let mut endpoint =
205            Endpoint::new(format!("{http_prefix}://{addr}")).context(CreateChannelSnafu)?;
206
207        if let Some(dur) = self.config().timeout {
208            endpoint = endpoint.timeout(dur);
209        }
210        if let Some(dur) = self.config().connect_timeout {
211            endpoint = endpoint.connect_timeout(dur);
212        }
213        if let Some(limit) = self.config().concurrency_limit {
214            endpoint = endpoint.concurrency_limit(limit);
215        }
216        if let Some((limit, dur)) = self.config().rate_limit {
217            endpoint = endpoint.rate_limit(limit, dur);
218        }
219        if let Some(size) = self.config().initial_stream_window_size {
220            endpoint = endpoint.initial_stream_window_size(size);
221        }
222        if let Some(size) = self.config().initial_connection_window_size {
223            endpoint = endpoint.initial_connection_window_size(size);
224        }
225        if let Some(dur) = self.config().http2_keep_alive_interval {
226            endpoint = endpoint.http2_keep_alive_interval(dur);
227        }
228        if let Some(dur) = self.config().http2_keep_alive_timeout {
229            endpoint = endpoint.keep_alive_timeout(dur);
230        }
231        if let Some(enabled) = self.config().http2_keep_alive_while_idle {
232            endpoint = endpoint.keep_alive_while_idle(enabled);
233        }
234        if let Some(enabled) = self.config().http2_adaptive_window {
235            endpoint = endpoint.http2_adaptive_window(enabled);
236        }
237        if let Some(tls_config) = &self.inner.client_tls_config {
238            endpoint = endpoint
239                .tls_config(tls_config.clone())
240                .context(CreateChannelSnafu)?;
241        }
242
243        endpoint = endpoint
244            .tcp_keepalive(self.config().tcp_keepalive)
245            .tcp_nodelay(self.config().tcp_nodelay);
246
247        Ok(endpoint)
248    }
249
250    fn trigger_channel_recycling(&self) {
251        if self
252            .inner
253            .channel_recycle_started
254            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
255            .is_err()
256        {
257            return;
258        }
259
260        let pool = self.pool().clone();
261        let cancel = self.inner.cancel.clone();
262        let id = self.inner.id;
263        let _handle = common_runtime::spawn_global(async move {
264            recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
265        });
266        info!(
267            "ChannelManager: {}, channel recycle is started, running in the background!",
268            self.inner.id
269        );
270    }
271}
272
/// File paths of the TLS material used for client authentication.
///
/// `ChannelManager::with_tls_config` reads each file as text and treats it as PEM.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientTlsOption {
    /// Path of the root CA certificate used to verify the server.
    pub server_ca_cert_path: String,
    /// Path of the client certificate.
    pub client_cert_path: String,
    /// Path of the client private key.
    pub client_key_path: String,
}
279
/// Settings a `ChannelManager` applies to the endpoints it builds.
///
/// For the `Option` fields, `None` means the corresponding endpoint setter is not called.
/// The exception is `tcp_keepalive`, which is always forwarded as is.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelConfig {
    /// Timeout applied to each request.
    pub timeout: Option<Duration>,
    /// Timeout for establishing a connection.
    pub connect_timeout: Option<Duration>,
    /// Concurrency limit passed to the endpoint.
    pub concurrency_limit: Option<usize>,
    /// Rate limit passed to the endpoint, as `(limit, duration)`.
    pub rate_limit: Option<(u64, Duration)>,
    /// HTTP/2 stream-level initial window size.
    pub initial_stream_window_size: Option<u32>,
    /// HTTP/2 connection-level initial window size.
    pub initial_connection_window_size: Option<u32>,
    /// HTTP/2 keep-alive interval.
    pub http2_keep_alive_interval: Option<Duration>,
    /// HTTP/2 keep-alive timeout.
    pub http2_keep_alive_timeout: Option<Duration>,
    /// Whether HTTP/2 keep-alive is also used while the connection is idle.
    pub http2_keep_alive_while_idle: Option<bool>,
    /// Whether HTTP/2 adaptive flow control is used.
    pub http2_adaptive_window: Option<bool>,
    /// TCP keepalive setting, forwarded to the endpoint unchanged.
    pub tcp_keepalive: Option<Duration>,
    /// Value of the `TCP_NODELAY` option.
    pub tcp_nodelay: bool,
    /// Paths of the TLS material; read by `ChannelManager::with_tls_config`.
    pub client_tls: Option<ClientTlsOption>,
    /// Max gRPC receiving (decoding) message size.
    // NOTE(review): not read anywhere in this file; presumably consumed by client code.
    pub max_recv_message_size: ReadableSize,
    /// Max gRPC sending (encoding) message size.
    // NOTE(review): not read anywhere in this file; presumably consumed by client code.
    pub max_send_message_size: ReadableSize,
}
300
301impl Default for ChannelConfig {
302    fn default() -> Self {
303        Self {
304            timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
305            connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
306            concurrency_limit: None,
307            rate_limit: None,
308            initial_stream_window_size: None,
309            initial_connection_window_size: None,
310            http2_keep_alive_interval: Some(Duration::from_secs(30)),
311            http2_keep_alive_timeout: None,
312            http2_keep_alive_while_idle: Some(true),
313            http2_adaptive_window: None,
314            tcp_keepalive: None,
315            tcp_nodelay: true,
316            client_tls: None,
317            max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
318            max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
319        }
320    }
321}
322
323impl ChannelConfig {
324    pub fn new() -> Self {
325        Default::default()
326    }
327
328    /// A timeout to each request.
329    pub fn timeout(mut self, timeout: Duration) -> Self {
330        self.timeout = Some(timeout);
331        self
332    }
333
334    /// A timeout to connecting to the uri.
335    ///
336    /// Defaults to no timeout.
337    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
338        self.connect_timeout = Some(timeout);
339        self
340    }
341
342    /// A concurrency limit to each request.
343    pub fn concurrency_limit(mut self, limit: usize) -> Self {
344        self.concurrency_limit = Some(limit);
345        self
346    }
347
348    /// A rate limit to each request.
349    pub fn rate_limit(mut self, limit: u64, duration: Duration) -> Self {
350        self.rate_limit = Some((limit, duration));
351        self
352    }
353
354    /// Sets the SETTINGS_INITIAL_WINDOW_SIZE option for HTTP2 stream-level flow control.
355    /// Default is 65,535
356    pub fn initial_stream_window_size(mut self, size: u32) -> Self {
357        self.initial_stream_window_size = Some(size);
358        self
359    }
360
361    /// Sets the max connection-level flow control for HTTP2
362    ///
363    /// Default is 65,535
364    pub fn initial_connection_window_size(mut self, size: u32) -> Self {
365        self.initial_connection_window_size = Some(size);
366        self
367    }
368
369    /// Set http2 KEEP_ALIVE_INTERVAL. Uses hyper’s default otherwise.
370    pub fn http2_keep_alive_interval(mut self, duration: Duration) -> Self {
371        self.http2_keep_alive_interval = Some(duration);
372        self
373    }
374
375    /// Set http2 KEEP_ALIVE_TIMEOUT. Uses hyper’s default otherwise.
376    pub fn http2_keep_alive_timeout(mut self, duration: Duration) -> Self {
377        self.http2_keep_alive_timeout = Some(duration);
378        self
379    }
380
381    /// Set http2 KEEP_ALIVE_WHILE_IDLE. Uses hyper’s default otherwise.
382    pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> Self {
383        self.http2_keep_alive_while_idle = Some(enabled);
384        self
385    }
386
387    /// Sets whether to use an adaptive flow control. Uses hyper’s default otherwise.
388    pub fn http2_adaptive_window(mut self, enabled: bool) -> Self {
389        self.http2_adaptive_window = Some(enabled);
390        self
391    }
392
393    /// Set whether TCP keepalive messages are enabled on accepted connections.
394    ///
395    /// If None is specified, keepalive is disabled, otherwise the duration specified
396    /// will be the time to remain idle before sending TCP keepalive probes.
397    ///
398    /// Default is no keepalive (None)
399    pub fn tcp_keepalive(mut self, duration: Duration) -> Self {
400        self.tcp_keepalive = Some(duration);
401        self
402    }
403
404    /// Set the value of TCP_NODELAY option for accepted connections.
405    ///
406    /// Enabled by default.
407    pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
408        self.tcp_nodelay = enabled;
409        self
410    }
411
412    /// Set the value of tls client auth.
413    ///
414    /// Disabled by default.
415    pub fn client_tls_config(mut self, client_tls_option: ClientTlsOption) -> Self {
416        self.client_tls = Some(client_tls_option);
417        self
418    }
419}
420
/// A pooled channel together with the bookkeeping used for recycling.
#[derive(Debug)]
pub struct Channel {
    // The underlying tonic channel handed out to callers.
    channel: InnerChannel,
    // Number of accesses since the recycle task last reset it to zero.
    access: AtomicUsize,
    // `false` when the channel was created through `reset_with_connector`.
    use_default_connector: bool,
}
427
428impl Channel {
429    #[inline]
430    pub fn access(&self) -> usize {
431        self.access.load(Ordering::Relaxed)
432    }
433
434    #[inline]
435    pub fn use_default_connector(&self) -> bool {
436        self.use_default_connector
437    }
438
439    #[inline]
440    pub fn increase_access(&self) {
441        let _ = self.access.fetch_add(1, Ordering::Relaxed);
442    }
443}
444
/// Concurrent map from address to the channel registered for it.
#[derive(Debug, Default)]
struct Pool {
    // Keyed by the address string exactly as passed to the manager.
    channels: DashMap<String, Channel>,
}
449
450impl Pool {
451    fn get(&self, addr: &str) -> Option<InnerChannel> {
452        let channel = self.channels.get(addr);
453        channel.map(|ch| {
454            ch.increase_access();
455            ch.channel.clone()
456        })
457    }
458
459    fn entry(&self, addr: String) -> Entry<String, Channel> {
460        self.channels.entry(addr)
461    }
462
463    #[cfg(test)]
464    fn get_access(&self, addr: &str) -> Option<usize> {
465        let channel = self.channels.get(addr);
466        channel.map(|ch| ch.access())
467    }
468
469    fn put(&self, addr: &str, channel: Channel) {
470        let _ = self.channels.insert(addr.to_string(), channel);
471    }
472
473    fn retain_channel<F>(&self, f: F)
474    where
475        F: FnMut(&String, &mut Channel) -> bool,
476    {
477        self.channels.retain(f);
478    }
479}
480
481async fn recycle_channel_in_loop(
482    pool: Arc<Pool>,
483    id: u64,
484    cancel: CancellationToken,
485    interval_secs: u64,
486) {
487    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
488
489    loop {
490        tokio::select! {
491            _ = cancel.cancelled() => {
492                info!("Stop channel recycle, ChannelManager id: {}", id);
493                break;
494            },
495            _ = interval.tick() => {}
496        }
497
498        pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
499    }
500}
501
// Unit tests for the channel manager, its configuration builder and pool lifetime.
#[cfg(test)]
mod tests {
    use tower::service_fn;

    use super::*;

    // `get` is given an address that already contains a scheme, and unwrapping its result
    // is expected to panic.
    // NOTE(review): `should_panic` accepts any panic — confirm the panic really comes from
    // the rejected address and not from another part of `get`.
    #[should_panic]
    #[test]
    fn test_invalid_addr() {
        let mgr = ChannelManager::default();
        let addr = "http://test";

        let _ = mgr.get(addr).unwrap();
    }

    // Every `get` counts as one access: 10 tasks x 100 calls must add up to 1000, and a
    // recycle sweep resets the counter to zero while keeping the channel.
    #[tokio::test]
    async fn test_access_count() {
        let mgr = ChannelManager::new();
        // Do not start recycle
        mgr.inner
            .channel_recycle_started
            .store(true, Ordering::Relaxed);
        let mgr = Arc::new(mgr);
        let addr = "test_uri";

        let mut joins = Vec::with_capacity(10);
        for _ in 0..10 {
            let mgr_clone = mgr.clone();
            let join = tokio::spawn(async move {
                for _ in 0..100 {
                    let _ = mgr_clone.get(addr);
                }
            });
            joins.push(join);
        }
        for join in joins {
            join.await.unwrap();
        }

        assert_eq!(1000, mgr.pool().get_access(addr).unwrap());

        // Same closure as the recycle task: the channel was used, so it stays in the pool.
        mgr.pool()
            .retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0);

        assert_eq!(0, mgr.pool().get_access(addr).unwrap());
    }

    // Checks the default configuration field by field, then that every builder method
    // sets exactly its own field.
    #[test]
    fn test_config() {
        let default_cfg = ChannelConfig::new();
        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
                connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
                concurrency_limit: None,
                rate_limit: None,
                initial_stream_window_size: None,
                initial_connection_window_size: None,
                http2_keep_alive_interval: Some(Duration::from_secs(30)),
                http2_keep_alive_timeout: None,
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: None,
                tcp_keepalive: None,
                tcp_nodelay: true,
                client_tls: None,
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
            },
            default_cfg
        );

        let cfg = default_cfg
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(false)
            .client_tls_config(ClientTlsOption {
                server_ca_cert_path: "some_server_path".to_string(),
                client_cert_path: "some_cert_path".to_string(),
                client_key_path: "some_key_path".to_string(),
            });

        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(3)),
                connect_timeout: Some(Duration::from_secs(5)),
                concurrency_limit: Some(6),
                rate_limit: Some((5, Duration::from_secs(1))),
                initial_stream_window_size: Some(10),
                initial_connection_window_size: Some(20),
                http2_keep_alive_interval: Some(Duration::from_secs(1)),
                http2_keep_alive_timeout: Some(Duration::from_secs(3)),
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: Some(true),
                tcp_keepalive: Some(Duration::from_secs(2)),
                tcp_nodelay: false,
                client_tls: Some(ClientTlsOption {
                    server_ca_cert_path: "some_server_path".to_string(),
                    client_cert_path: "some_cert_path".to_string(),
                    client_key_path: "some_key_path".to_string(),
                }),
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
            },
            cfg
        );
    }

    // An endpoint can be built from a configuration with every option set (no TLS).
    #[test]
    fn test_build_endpoint() {
        let config = ChannelConfig::new()
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(true);
        let mgr = ChannelManager::with_config(config);

        let res = mgr.build_endpoint("test_addr");

        let _ = res.unwrap();
    }

    // A channel created by `get` is marked as using the default connector; after
    // `reset_with_connector` the channel stored for the same address is not.
    #[tokio::test]
    async fn test_channel_with_connector() {
        let mgr = ChannelManager::new();

        let addr = "test_addr";
        let res = mgr.get(addr);
        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(channel.use_default_connector());
            true
        });

        // In-memory duplex stream standing in for a network connection.
        let (client, _) = tokio::io::duplex(1024);
        let mut client = Some(hyper_util::rt::TokioIo::new(client));
        let res = mgr.reset_with_connector(
            addr,
            service_fn(move |_| {
                let client = client.take().unwrap();
                async move { Ok::<_, std::io::Error>(client) }
            }),
        );

        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(!channel.use_default_connector());
            true
        });
    }

    // With the recycle task running, the pool has three owners: the manager's shared
    // state, the recycle task and `pool_holder`. Cloning the manager does not add owners,
    // and dropping the last manager must also end the task.
    #[tokio::test]
    async fn test_pool_release_with_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        // start channel recycle task
        let addr = "test_addr";
        let _ = mgr.get(addr);

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr);

        // wait for the channel recycle task to finish
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }

    // Without the recycle task the pool has two owners (shared state and `pool_holder`),
    // and dropping the last manager releases the pool immediately.
    #[tokio::test]
    async fn test_pool_release_without_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr);

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }
}
717}