common_grpc/
channel_manager.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use common_base::readable_size::ReadableSize;
20use common_telemetry::info;
21use dashmap::mapref::entry::Entry;
22use dashmap::DashMap;
23use lazy_static::lazy_static;
24use serde::{Deserialize, Serialize};
25use snafu::{OptionExt, ResultExt};
26use tokio_util::sync::CancellationToken;
27use tonic::transport::{
28    Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
29};
30use tower::Service;
31
32use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, InvalidTlsConfigSnafu, Result};
33
34const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
35pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
36pub const DEFAULT_GRPC_CONNECT_TIMEOUT_SECS: u64 = 1;
37pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
38pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
39
40lazy_static! {
41    static ref ID: AtomicU64 = AtomicU64::new(0);
42}
43
44#[derive(Clone, Debug, Default)]
45pub struct ChannelManager {
46    inner: Arc<Inner>,
47}
48
49#[derive(Debug)]
50struct Inner {
51    id: u64,
52    config: ChannelConfig,
53    client_tls_config: Option<ClientTlsConfig>,
54    pool: Arc<Pool>,
55    channel_recycle_started: AtomicBool,
56    cancel: CancellationToken,
57}
58
59impl Default for Inner {
60    fn default() -> Self {
61        Self::with_config(ChannelConfig::default())
62    }
63}
64
65impl Drop for Inner {
66    fn drop(&mut self) {
67        // Cancel the channel recycle task.
68        self.cancel.cancel();
69    }
70}
71
72impl Inner {
73    fn with_config(config: ChannelConfig) -> Self {
74        let id = ID.fetch_add(1, Ordering::Relaxed);
75        let pool = Arc::new(Pool::default());
76        let cancel = CancellationToken::new();
77
78        Self {
79            id,
80            config,
81            client_tls_config: None,
82            pool,
83            channel_recycle_started: AtomicBool::new(false),
84            cancel,
85        }
86    }
87}
88
89impl ChannelManager {
90    pub fn new() -> Self {
91        Default::default()
92    }
93
94    pub fn with_config(config: ChannelConfig) -> Self {
95        let inner = Inner::with_config(config);
96        Self {
97            inner: Arc::new(inner),
98        }
99    }
100
101    /// Read tls cert and key files and create a ChannelManager with TLS config.
102    pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
103        let mut inner = Inner::with_config(config.clone());
104
105        // setup tls
106        let path_config = config.client_tls.context(InvalidTlsConfigSnafu {
107            msg: "no config input",
108        })?;
109
110        if !path_config.enabled {
111            // if TLS not enabled, just ignore other tls config
112            // and not set `client_tls_config` hence not use TLS
113            return Ok(Self {
114                inner: Arc::new(inner),
115            });
116        }
117
118        let mut tls_config = ClientTlsConfig::new();
119
120        if let Some(server_ca) = path_config.server_ca_cert_path {
121            let server_root_ca_cert =
122                std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?;
123            let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
124            tls_config = tls_config.ca_certificate(server_root_ca_cert);
125        }
126
127        if let (Some(client_cert_path), Some(client_key_path)) =
128            (&path_config.client_cert_path, &path_config.client_key_path)
129        {
130            let client_cert =
131                std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?;
132            let client_key =
133                std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?;
134            let client_identity = Identity::from_pem(client_cert, client_key);
135            tls_config = tls_config.identity(client_identity);
136        }
137
138        inner.client_tls_config = Some(tls_config);
139
140        Ok(Self {
141            inner: Arc::new(inner),
142        })
143    }
144
145    pub fn config(&self) -> &ChannelConfig {
146        &self.inner.config
147    }
148
149    fn pool(&self) -> &Arc<Pool> {
150        &self.inner.pool
151    }
152
153    pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
154        self.trigger_channel_recycling();
155
156        let addr = addr.as_ref();
157        // It will acquire the read lock.
158        if let Some(inner_ch) = self.pool().get(addr) {
159            return Ok(inner_ch);
160        }
161
162        // It will acquire the write lock.
163        let entry = match self.pool().entry(addr.to_string()) {
164            Entry::Occupied(entry) => {
165                entry.get().increase_access();
166                entry.into_ref()
167            }
168            Entry::Vacant(entry) => {
169                let endpoint = self.build_endpoint(addr)?;
170                let inner_channel = endpoint.connect_lazy();
171
172                let channel = Channel {
173                    channel: inner_channel,
174                    access: AtomicUsize::new(1),
175                    use_default_connector: true,
176                };
177                entry.insert(channel)
178            }
179        };
180        Ok(entry.channel.clone())
181    }
182
183    pub fn reset_with_connector<C>(
184        &self,
185        addr: impl AsRef<str>,
186        connector: C,
187    ) -> Result<InnerChannel>
188    where
189        C: Service<Uri> + Send + 'static,
190        C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin,
191        C::Future: Send + 'static,
192        Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
193    {
194        let addr = addr.as_ref();
195        let endpoint = self.build_endpoint(addr)?;
196        let inner_channel = endpoint.connect_with_connector_lazy(connector);
197        let channel = Channel {
198            channel: inner_channel.clone(),
199            access: AtomicUsize::new(1),
200            use_default_connector: false,
201        };
202        self.pool().put(addr, channel);
203
204        Ok(inner_channel)
205    }
206
207    pub fn retain_channel<F>(&self, f: F)
208    where
209        F: FnMut(&String, &mut Channel) -> bool,
210    {
211        self.pool().retain_channel(f);
212    }
213
214    fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
215        let http_prefix = if self.inner.client_tls_config.is_some() {
216            "https"
217        } else {
218            "http"
219        };
220
221        let mut endpoint = Endpoint::new(format!("{http_prefix}://{addr}"))
222            .context(CreateChannelSnafu { addr })?;
223
224        if let Some(dur) = self.config().timeout {
225            endpoint = endpoint.timeout(dur);
226        }
227        if let Some(dur) = self.config().connect_timeout {
228            endpoint = endpoint.connect_timeout(dur);
229        }
230        if let Some(limit) = self.config().concurrency_limit {
231            endpoint = endpoint.concurrency_limit(limit);
232        }
233        if let Some((limit, dur)) = self.config().rate_limit {
234            endpoint = endpoint.rate_limit(limit, dur);
235        }
236        if let Some(size) = self.config().initial_stream_window_size {
237            endpoint = endpoint.initial_stream_window_size(size);
238        }
239        if let Some(size) = self.config().initial_connection_window_size {
240            endpoint = endpoint.initial_connection_window_size(size);
241        }
242        if let Some(dur) = self.config().http2_keep_alive_interval {
243            endpoint = endpoint.http2_keep_alive_interval(dur);
244        }
245        if let Some(dur) = self.config().http2_keep_alive_timeout {
246            endpoint = endpoint.keep_alive_timeout(dur);
247        }
248        if let Some(enabled) = self.config().http2_keep_alive_while_idle {
249            endpoint = endpoint.keep_alive_while_idle(enabled);
250        }
251        if let Some(enabled) = self.config().http2_adaptive_window {
252            endpoint = endpoint.http2_adaptive_window(enabled);
253        }
254        if let Some(tls_config) = &self.inner.client_tls_config {
255            endpoint = endpoint
256                .tls_config(tls_config.clone())
257                .context(CreateChannelSnafu { addr })?;
258        }
259
260        endpoint = endpoint
261            .tcp_keepalive(self.config().tcp_keepalive)
262            .tcp_nodelay(self.config().tcp_nodelay);
263
264        Ok(endpoint)
265    }
266
267    fn trigger_channel_recycling(&self) {
268        if self
269            .inner
270            .channel_recycle_started
271            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
272            .is_err()
273        {
274            return;
275        }
276
277        let pool = self.pool().clone();
278        let cancel = self.inner.cancel.clone();
279        let id = self.inner.id;
280        let _handle = common_runtime::spawn_global(async move {
281            recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
282        });
283        info!(
284            "ChannelManager: {}, channel recycle is started, running in the background!",
285            self.inner.id
286        );
287    }
288}
289
290#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
291pub struct ClientTlsOption {
292    /// Whether to enable TLS for client.
293    pub enabled: bool,
294    pub server_ca_cert_path: Option<String>,
295    pub client_cert_path: Option<String>,
296    pub client_key_path: Option<String>,
297}
298
299#[derive(Clone, Debug, PartialEq, Eq)]
300pub struct ChannelConfig {
301    pub timeout: Option<Duration>,
302    pub connect_timeout: Option<Duration>,
303    pub concurrency_limit: Option<usize>,
304    pub rate_limit: Option<(u64, Duration)>,
305    pub initial_stream_window_size: Option<u32>,
306    pub initial_connection_window_size: Option<u32>,
307    pub http2_keep_alive_interval: Option<Duration>,
308    pub http2_keep_alive_timeout: Option<Duration>,
309    pub http2_keep_alive_while_idle: Option<bool>,
310    pub http2_adaptive_window: Option<bool>,
311    pub tcp_keepalive: Option<Duration>,
312    pub tcp_nodelay: bool,
313    pub client_tls: Option<ClientTlsOption>,
314    // Max gRPC receiving(decoding) message size
315    pub max_recv_message_size: ReadableSize,
316    // Max gRPC sending(encoding) message size
317    pub max_send_message_size: ReadableSize,
318    pub send_compression: bool,
319    pub accept_compression: bool,
320}
321
322impl Default for ChannelConfig {
323    fn default() -> Self {
324        Self {
325            timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
326            connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
327            concurrency_limit: None,
328            rate_limit: None,
329            initial_stream_window_size: None,
330            initial_connection_window_size: None,
331            http2_keep_alive_interval: Some(Duration::from_secs(30)),
332            http2_keep_alive_timeout: None,
333            http2_keep_alive_while_idle: Some(true),
334            http2_adaptive_window: None,
335            tcp_keepalive: None,
336            tcp_nodelay: true,
337            client_tls: None,
338            max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
339            max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
340            send_compression: false,
341            accept_compression: false,
342        }
343    }
344}
345
346impl ChannelConfig {
347    pub fn new() -> Self {
348        Default::default()
349    }
350
351    /// A timeout to each request.
352    pub fn timeout(mut self, timeout: Duration) -> Self {
353        self.timeout = Some(timeout);
354        self
355    }
356
357    /// A timeout to connecting to the uri.
358    ///
359    /// Defaults to no timeout.
360    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
361        self.connect_timeout = Some(timeout);
362        self
363    }
364
365    /// A concurrency limit to each request.
366    pub fn concurrency_limit(mut self, limit: usize) -> Self {
367        self.concurrency_limit = Some(limit);
368        self
369    }
370
371    /// A rate limit to each request.
372    pub fn rate_limit(mut self, limit: u64, duration: Duration) -> Self {
373        self.rate_limit = Some((limit, duration));
374        self
375    }
376
377    /// Sets the SETTINGS_INITIAL_WINDOW_SIZE option for HTTP2 stream-level flow control.
378    /// Default is 65,535
379    pub fn initial_stream_window_size(mut self, size: u32) -> Self {
380        self.initial_stream_window_size = Some(size);
381        self
382    }
383
384    /// Sets the max connection-level flow control for HTTP2
385    ///
386    /// Default is 65,535
387    pub fn initial_connection_window_size(mut self, size: u32) -> Self {
388        self.initial_connection_window_size = Some(size);
389        self
390    }
391
392    /// Set http2 KEEP_ALIVE_INTERVAL. Uses hyper’s default otherwise.
393    pub fn http2_keep_alive_interval(mut self, duration: Duration) -> Self {
394        self.http2_keep_alive_interval = Some(duration);
395        self
396    }
397
398    /// Set http2 KEEP_ALIVE_TIMEOUT. Uses hyper’s default otherwise.
399    pub fn http2_keep_alive_timeout(mut self, duration: Duration) -> Self {
400        self.http2_keep_alive_timeout = Some(duration);
401        self
402    }
403
404    /// Set http2 KEEP_ALIVE_WHILE_IDLE. Uses hyper’s default otherwise.
405    pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> Self {
406        self.http2_keep_alive_while_idle = Some(enabled);
407        self
408    }
409
410    /// Sets whether to use an adaptive flow control. Uses hyper’s default otherwise.
411    pub fn http2_adaptive_window(mut self, enabled: bool) -> Self {
412        self.http2_adaptive_window = Some(enabled);
413        self
414    }
415
416    /// Set whether TCP keepalive messages are enabled on accepted connections.
417    ///
418    /// If None is specified, keepalive is disabled, otherwise the duration specified
419    /// will be the time to remain idle before sending TCP keepalive probes.
420    ///
421    /// Default is no keepalive (None)
422    pub fn tcp_keepalive(mut self, duration: Duration) -> Self {
423        self.tcp_keepalive = Some(duration);
424        self
425    }
426
427    /// Set the value of TCP_NODELAY option for accepted connections.
428    ///
429    /// Enabled by default.
430    pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
431        self.tcp_nodelay = enabled;
432        self
433    }
434
435    /// Set the value of tls client auth.
436    ///
437    /// Disabled by default.
438    pub fn client_tls_config(mut self, client_tls_option: ClientTlsOption) -> Self {
439        self.client_tls = Some(client_tls_option);
440        self
441    }
442}
443
444#[derive(Debug)]
445pub struct Channel {
446    channel: InnerChannel,
447    access: AtomicUsize,
448    use_default_connector: bool,
449}
450
451impl Channel {
452    #[inline]
453    pub fn access(&self) -> usize {
454        self.access.load(Ordering::Relaxed)
455    }
456
457    #[inline]
458    pub fn use_default_connector(&self) -> bool {
459        self.use_default_connector
460    }
461
462    #[inline]
463    pub fn increase_access(&self) {
464        let _ = self.access.fetch_add(1, Ordering::Relaxed);
465    }
466}
467
468#[derive(Debug, Default)]
469struct Pool {
470    channels: DashMap<String, Channel>,
471}
472
473impl Pool {
474    fn get(&self, addr: &str) -> Option<InnerChannel> {
475        let channel = self.channels.get(addr);
476        channel.map(|ch| {
477            ch.increase_access();
478            ch.channel.clone()
479        })
480    }
481
482    fn entry(&self, addr: String) -> Entry<String, Channel> {
483        self.channels.entry(addr)
484    }
485
486    #[cfg(test)]
487    fn get_access(&self, addr: &str) -> Option<usize> {
488        let channel = self.channels.get(addr);
489        channel.map(|ch| ch.access())
490    }
491
492    fn put(&self, addr: &str, channel: Channel) {
493        let _ = self.channels.insert(addr.to_string(), channel);
494    }
495
496    fn retain_channel<F>(&self, f: F)
497    where
498        F: FnMut(&String, &mut Channel) -> bool,
499    {
500        self.channels.retain(f);
501    }
502}
503
504async fn recycle_channel_in_loop(
505    pool: Arc<Pool>,
506    id: u64,
507    cancel: CancellationToken,
508    interval_secs: u64,
509) {
510    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
511
512    loop {
513        tokio::select! {
514            _ = cancel.cancelled() => {
515                info!("Stop channel recycle, ChannelManager id: {}", id);
516                break;
517            },
518            _ = interval.tick() => {}
519        }
520
521        pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
522    }
523}
524
#[cfg(test)]
mod tests {
    use tower::service_fn;

    use super::*;

    #[test]
    fn test_invalid_addr() {
        let mgr = ChannelManager::default();
        // `build_endpoint` prepends its own scheme, producing "http://http://test",
        // which is expected to be rejected as an invalid endpoint.
        let addr = "http://test";

        // Assert on the error itself rather than `#[should_panic]`, which would also
        // pass on any unrelated panic.
        assert!(mgr.get(addr).is_err());
    }

    #[tokio::test]
    async fn test_access_count() {
        let mgr = ChannelManager::new();
        // Do not start recycle
        mgr.inner
            .channel_recycle_started
            .store(true, Ordering::Relaxed);
        let mgr = Arc::new(mgr);
        let addr = "test_uri";

        let mut joins = Vec::with_capacity(10);
        for _ in 0..10 {
            let mgr_clone = mgr.clone();
            let join = tokio::spawn(async move {
                for _ in 0..100 {
                    let _ = mgr_clone.get(addr);
                }
            });
            joins.push(join);
        }
        for join in joins {
            join.await.unwrap();
        }

        assert_eq!(1000, mgr.pool().get_access(addr).unwrap());

        mgr.pool()
            .retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0);

        assert_eq!(0, mgr.pool().get_access(addr).unwrap());
    }

    #[test]
    fn test_config() {
        let default_cfg = ChannelConfig::new();
        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
                connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
                concurrency_limit: None,
                rate_limit: None,
                initial_stream_window_size: None,
                initial_connection_window_size: None,
                http2_keep_alive_interval: Some(Duration::from_secs(30)),
                http2_keep_alive_timeout: None,
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: None,
                tcp_keepalive: None,
                tcp_nodelay: true,
                client_tls: None,
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            default_cfg
        );

        let cfg = default_cfg
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(false)
            .client_tls_config(ClientTlsOption {
                enabled: true,
                server_ca_cert_path: Some("some_server_path".to_string()),
                client_cert_path: Some("some_cert_path".to_string()),
                client_key_path: Some("some_key_path".to_string()),
            });

        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(3)),
                connect_timeout: Some(Duration::from_secs(5)),
                concurrency_limit: Some(6),
                rate_limit: Some((5, Duration::from_secs(1))),
                initial_stream_window_size: Some(10),
                initial_connection_window_size: Some(20),
                http2_keep_alive_interval: Some(Duration::from_secs(1)),
                http2_keep_alive_timeout: Some(Duration::from_secs(3)),
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: Some(true),
                tcp_keepalive: Some(Duration::from_secs(2)),
                tcp_nodelay: false,
                client_tls: Some(ClientTlsOption {
                    enabled: true,
                    server_ca_cert_path: Some("some_server_path".to_string()),
                    client_cert_path: Some("some_cert_path".to_string()),
                    client_key_path: Some("some_key_path".to_string()),
                }),
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            cfg
        );
    }

    #[test]
    fn test_build_endpoint() {
        let config = ChannelConfig::new()
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(true);
        let mgr = ChannelManager::with_config(config);

        let res = mgr.build_endpoint("test_addr");

        let _ = res.unwrap();
    }

    #[tokio::test]
    async fn test_channel_with_connector() {
        let mgr = ChannelManager::new();

        let addr = "test_addr";
        let res = mgr.get(addr);
        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(channel.use_default_connector());
            true
        });

        let (client, _) = tokio::io::duplex(1024);
        let mut client = Some(hyper_util::rt::TokioIo::new(client));
        let res = mgr.reset_with_connector(
            addr,
            service_fn(move |_| {
                let client = client.take().unwrap();
                async move { Ok::<_, std::io::Error>(client) }
            }),
        );

        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(!channel.use_default_connector());
            true
        });
    }

    #[tokio::test]
    async fn test_pool_release_with_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        // start channel recycle task
        let addr = "test_addr";
        let _ = mgr.get(addr);

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr);

        // Dropping the last manager cancels the recycle task, which then releases its
        // clone of the pool. Poll against a generous deadline instead of relying on a
        // fixed short sleep, which is flaky on a loaded machine.
        let deadline = std::time::Instant::now() + Duration::from_secs(5);
        while Arc::strong_count(&pool_holder) > 1 && std::time::Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }

    #[tokio::test]
    async fn test_pool_release_without_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr);

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }
}