common_grpc/
channel_manager.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::path::Path;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
18use std::time::Duration;
19
20use common_base::readable_size::ReadableSize;
21use common_telemetry::info;
22use dashmap::DashMap;
23use dashmap::mapref::entry::Entry;
24use lazy_static::lazy_static;
25use serde::{Deserialize, Serialize};
26use snafu::ResultExt;
27use tokio_util::sync::CancellationToken;
28use tonic::transport::{
29    Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
30};
31use tower::Service;
32
33use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result};
34use crate::reloadable_tls::{ReloadableTlsConfig, TlsConfigLoader, maybe_watch_tls_config};
35
/// Seconds between two sweeps of the background task that evicts unused channels.
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
/// Default per-request timeout, in seconds (used by `ChannelConfig::default`).
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
/// Default timeout for establishing a connection, in seconds (used by `ChannelConfig::default`).
pub const DEFAULT_GRPC_CONNECT_TIMEOUT_SECS: u64 = 1;
39pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
40pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
41
42lazy_static! {
43    static ref ID: AtomicU64 = AtomicU64::new(0);
44}
45
/// Caches gRPC channels per address and recycles the ones that go unused.
///
/// Cloning is cheap: all clones share the same `Inner` (config, channel pool and
/// background recycle task) through an `Arc`.
#[derive(Clone, Debug, Default)]
pub struct ChannelManager {
    inner: Arc<Inner>,
}
50
/// State shared by all clones of a [`ChannelManager`].
#[derive(Debug)]
struct Inner {
    /// Id taken from the global `ID` counter; used in log messages.
    id: u64,
    /// Settings applied when building endpoints.
    config: ChannelConfig,
    /// Source of the TLS config for new endpoints. When absent (or when it yields no
    /// config), endpoints are built with the `http` scheme instead of `https`.
    reloadable_client_tls_config: Option<Arc<ReloadableClientTlsConfig>>,
    /// Cached channels; an `Arc` clone is also held by the background recycle task.
    pool: Arc<Pool>,
    /// Flipped to `true` the first time the recycle task is spawned, so it is spawned once.
    channel_recycle_started: AtomicBool,
    /// Cancelled when `Inner` is dropped, which stops the recycle task.
    cancel: CancellationToken,
}
60
61impl Default for Inner {
62    fn default() -> Self {
63        Self::with_config(ChannelConfig::default())
64    }
65}
66
67impl Drop for Inner {
68    fn drop(&mut self) {
69        // Cancel the channel recycle task.
70        self.cancel.cancel();
71    }
72}
73
74impl Inner {
75    fn with_config(config: ChannelConfig) -> Self {
76        let id = ID.fetch_add(1, Ordering::Relaxed);
77        let pool = Arc::new(Pool::default());
78        let cancel = CancellationToken::new();
79
80        Self {
81            id,
82            config,
83            reloadable_client_tls_config: None,
84            pool,
85            channel_recycle_started: AtomicBool::new(false),
86            cancel,
87        }
88    }
89}
90
91impl ChannelManager {
92    pub fn new() -> Self {
93        Default::default()
94    }
95
96    /// Create a ChannelManager with configuration and optional TLS config
97    ///
98    /// Use [`load_client_tls_config`] to create TLS configuration from `ClientTlsOption`.
99    /// The TLS config supports both static (watch disabled) and dynamic reloading (watch enabled).
100    /// If you want to use dynamic reloading, please **manually** invoke [`maybe_watch_client_tls_config`] after this method.
101    pub fn with_config(
102        config: ChannelConfig,
103        reloadable_tls_config: Option<Arc<ReloadableClientTlsConfig>>,
104    ) -> Self {
105        let mut inner = Inner::with_config(config.clone());
106        inner.reloadable_client_tls_config = reloadable_tls_config;
107        Self {
108            inner: Arc::new(inner),
109        }
110    }
111
112    pub fn config(&self) -> &ChannelConfig {
113        &self.inner.config
114    }
115
116    fn pool(&self) -> &Arc<Pool> {
117        &self.inner.pool
118    }
119
120    pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
121        self.trigger_channel_recycling();
122
123        let addr = addr.as_ref();
124        // It will acquire the read lock.
125        if let Some(inner_ch) = self.pool().get(addr) {
126            return Ok(inner_ch);
127        }
128
129        // It will acquire the write lock.
130        let entry = match self.pool().entry(addr.to_string()) {
131            Entry::Occupied(entry) => {
132                entry.get().increase_access();
133                entry.into_ref()
134            }
135            Entry::Vacant(entry) => {
136                let endpoint = self.build_endpoint(addr)?;
137                let inner_channel = endpoint.connect_lazy();
138
139                let channel = Channel {
140                    channel: inner_channel,
141                    access: AtomicUsize::new(1),
142                    use_default_connector: true,
143                };
144                entry.insert(channel)
145            }
146        };
147        Ok(entry.channel.clone())
148    }
149
150    pub fn reset_with_connector<C>(
151        &self,
152        addr: impl AsRef<str>,
153        connector: C,
154    ) -> Result<InnerChannel>
155    where
156        C: Service<Uri> + Send + 'static,
157        C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin,
158        C::Future: Send + 'static,
159        Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
160    {
161        let addr = addr.as_ref();
162        let endpoint = self.build_endpoint(addr)?;
163        let inner_channel = endpoint.connect_with_connector_lazy(connector);
164        let channel = Channel {
165            channel: inner_channel.clone(),
166            access: AtomicUsize::new(1),
167            use_default_connector: false,
168        };
169        self.pool().put(addr, channel);
170
171        Ok(inner_channel)
172    }
173
174    pub fn retain_channel<F>(&self, f: F)
175    where
176        F: FnMut(&String, &mut Channel) -> bool,
177    {
178        self.pool().retain_channel(f);
179    }
180
181    /// Clear all channels to force reconnection.
182    /// This should be called when TLS configuration changes to ensure new connections use updated certificates.
183    pub fn clear_all_channels(&self) {
184        self.pool().retain_channel(|_, _| false);
185    }
186
187    fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
188        // Get the latest TLS config from reloadable config (which handles both static and dynamic cases)
189        let tls_config = self
190            .inner
191            .reloadable_client_tls_config
192            .as_ref()
193            .and_then(|c| c.get_config());
194
195        let http_prefix = if tls_config.is_some() {
196            "https"
197        } else {
198            "http"
199        };
200
201        let mut endpoint = Endpoint::new(format!("{http_prefix}://{addr}"))
202            .context(CreateChannelSnafu { addr })?;
203
204        if let Some(dur) = self.config().timeout {
205            endpoint = endpoint.timeout(dur);
206        }
207        if let Some(dur) = self.config().connect_timeout {
208            endpoint = endpoint.connect_timeout(dur);
209        }
210        if let Some(limit) = self.config().concurrency_limit {
211            endpoint = endpoint.concurrency_limit(limit);
212        }
213        if let Some((limit, dur)) = self.config().rate_limit {
214            endpoint = endpoint.rate_limit(limit, dur);
215        }
216        if let Some(size) = self.config().initial_stream_window_size {
217            endpoint = endpoint.initial_stream_window_size(size);
218        }
219        if let Some(size) = self.config().initial_connection_window_size {
220            endpoint = endpoint.initial_connection_window_size(size);
221        }
222        if let Some(dur) = self.config().http2_keep_alive_interval {
223            endpoint = endpoint.http2_keep_alive_interval(dur);
224        }
225        if let Some(dur) = self.config().http2_keep_alive_timeout {
226            endpoint = endpoint.keep_alive_timeout(dur);
227        }
228        if let Some(enabled) = self.config().http2_keep_alive_while_idle {
229            endpoint = endpoint.keep_alive_while_idle(enabled);
230        }
231        if let Some(enabled) = self.config().http2_adaptive_window {
232            endpoint = endpoint.http2_adaptive_window(enabled);
233        }
234        if let Some(tls_config) = tls_config {
235            endpoint = endpoint
236                .tls_config(tls_config)
237                .context(CreateChannelSnafu { addr })?;
238        }
239
240        endpoint = endpoint
241            .tcp_keepalive(self.config().tcp_keepalive)
242            .tcp_nodelay(self.config().tcp_nodelay);
243
244        Ok(endpoint)
245    }
246
247    fn trigger_channel_recycling(&self) {
248        if self
249            .inner
250            .channel_recycle_started
251            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
252            .is_err()
253        {
254            return;
255        }
256
257        let pool = self.pool().clone();
258        let cancel = self.inner.cancel.clone();
259        let id = self.inner.id;
260        let _handle = common_runtime::spawn_global(async move {
261            recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
262        });
263        info!(
264            "ChannelManager: {}, channel recycle is started, running in the background!",
265            self.inner.id
266        );
267    }
268}
269
270fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
271    let path_config = match tls_option {
272        Some(path_config) if path_config.enabled => path_config,
273        _ => return Ok(None),
274    };
275
276    let mut tls_config = ClientTlsConfig::new();
277
278    if let Some(server_ca) = &path_config.server_ca_cert_path {
279        let server_root_ca_cert =
280            std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?;
281        let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
282        tls_config = tls_config.ca_certificate(server_root_ca_cert);
283    }
284
285    if let (Some(client_cert_path), Some(client_key_path)) =
286        (&path_config.client_cert_path, &path_config.client_key_path)
287    {
288        let client_cert =
289            std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?;
290        let client_key =
291            std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?;
292        let client_identity = Identity::from_pem(client_cert, client_key);
293        tls_config = tls_config.identity(client_identity);
294    }
295    Ok(Some(tls_config))
296}
297
298impl TlsConfigLoader<ClientTlsConfig> for ClientTlsOption {
299    type Error = crate::error::Error;
300
301    fn load(&self) -> Result<Option<ClientTlsConfig>> {
302        load_tls_config(Some(self))
303    }
304
305    fn watch_paths(&self) -> Vec<&Path> {
306        let mut paths = Vec::new();
307        if let Some(cert_path) = &self.client_cert_path {
308            paths.push(Path::new(cert_path.as_str()));
309        }
310        if let Some(key_path) = &self.client_key_path {
311            paths.push(Path::new(key_path.as_str()));
312        }
313        if let Some(ca_path) = &self.server_ca_cert_path {
314            paths.push(Path::new(ca_path.as_str()));
315        }
316        paths
317    }
318
319    fn watch_enabled(&self) -> bool {
320        self.enabled && self.watch
321    }
322}
323
/// Type alias for client-side reloadable TLS config.
///
/// It pairs a tonic [`ClientTlsConfig`] with the [`ClientTlsOption`] that
/// (via its `TlsConfigLoader` impl) describes how to load it.
pub type ReloadableClientTlsConfig = ReloadableTlsConfig<ClientTlsConfig, ClientTlsOption>;
326
327/// Load client TLS configuration from `ClientTlsOption` and return a `ReloadableClientTlsConfig`.
328/// This is the primary way to create TLS configuration for the ChannelManager.
329pub fn load_client_tls_config(
330    tls_option: Option<ClientTlsOption>,
331) -> Result<Option<Arc<ReloadableClientTlsConfig>>> {
332    match tls_option {
333        Some(option) if option.enabled => {
334            let reloadable = ReloadableClientTlsConfig::try_new(option)?;
335            Ok(Some(Arc::new(reloadable)))
336        }
337        _ => Ok(None),
338    }
339}
340
341pub fn maybe_watch_client_tls_config(
342    client_tls_config: Arc<ReloadableClientTlsConfig>,
343    channel_manager: ChannelManager,
344) -> Result<()> {
345    maybe_watch_tls_config(client_tls_config, move || {
346        // Clear all existing channels to force reconnection with new certificates
347        channel_manager.clear_all_channels();
348        info!("Cleared all existing channels to use new TLS certificates.");
349    })
350}
351
/// File-based client TLS settings (serializable, e.g. for configuration files).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ClientTlsOption {
    /// Whether to enable TLS for client.
    pub enabled: bool,
    /// Path to the PEM-encoded CA certificate used to verify the server.
    pub server_ca_cert_path: Option<String>,
    /// Path to the PEM-encoded client certificate; only used together with `client_key_path`.
    pub client_cert_path: Option<String>,
    /// Path to the PEM-encoded client private key; only used together with `client_cert_path`.
    pub client_key_path: Option<String>,
    /// Whether to watch the files above and reload TLS when they change (requires `enabled`).
    #[serde(default)]
    pub watch: bool,
}
362
/// Settings for gRPC channels: endpoint options plus message-size and compression options.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelConfig {
    /// Per-request timeout.
    pub timeout: Option<Duration>,
    /// Timeout for establishing a connection.
    pub connect_timeout: Option<Duration>,
    /// Concurrency limit passed to the endpoint.
    pub concurrency_limit: Option<usize>,
    /// Rate limit passed to the endpoint as `(limit, duration)`.
    pub rate_limit: Option<(u64, Duration)>,
    /// HTTP/2 stream-level initial flow-control window size.
    pub initial_stream_window_size: Option<u32>,
    /// HTTP/2 connection-level initial flow-control window size.
    pub initial_connection_window_size: Option<u32>,
    /// Interval between HTTP/2 keep-alive pings.
    pub http2_keep_alive_interval: Option<Duration>,
    /// Keep-alive timeout, applied through the endpoint's `keep_alive_timeout`.
    pub http2_keep_alive_timeout: Option<Duration>,
    /// Whether HTTP/2 keep-alive pings are sent while the connection is idle.
    pub http2_keep_alive_while_idle: Option<bool>,
    /// Whether HTTP/2 adaptive flow control is enabled.
    pub http2_adaptive_window: Option<bool>,
    /// TCP keepalive idle time; `None` leaves TCP keepalive disabled.
    pub tcp_keepalive: Option<Duration>,
    /// Value of the TCP_NODELAY socket option.
    pub tcp_nodelay: bool,
    /// File-based client TLS settings. NOTE: `ChannelManager` in this file builds TLS from its
    /// `ReloadableClientTlsConfig`, not from this field.
    pub client_tls: Option<ClientTlsOption>,
    /// Max gRPC receiving(decoding) message size
    pub max_recv_message_size: ReadableSize,
    /// Max gRPC sending(encoding) message size
    pub max_send_message_size: ReadableSize,
    /// Whether outgoing messages are compressed — not used in this file; presumably read by
    /// the gRPC client code that consumes this config.
    pub send_compression: bool,
    /// Whether compressed responses are accepted — not used in this file; presumably read by
    /// the gRPC client code that consumes this config.
    pub accept_compression: bool,
}
385
386impl Default for ChannelConfig {
387    fn default() -> Self {
388        Self {
389            timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
390            connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
391            concurrency_limit: None,
392            rate_limit: None,
393            initial_stream_window_size: None,
394            initial_connection_window_size: None,
395            http2_keep_alive_interval: Some(Duration::from_secs(30)),
396            http2_keep_alive_timeout: None,
397            http2_keep_alive_while_idle: Some(true),
398            http2_adaptive_window: None,
399            tcp_keepalive: None,
400            tcp_nodelay: true,
401            client_tls: None,
402            max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
403            max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
404            send_compression: false,
405            accept_compression: false,
406        }
407    }
408}
409
410impl ChannelConfig {
411    pub fn new() -> Self {
412        Default::default()
413    }
414
415    /// A timeout to each request.
416    pub fn timeout(mut self, timeout: Duration) -> Self {
417        self.timeout = Some(timeout);
418        self
419    }
420
421    /// A timeout to connecting to the uri.
422    ///
423    /// Defaults to no timeout.
424    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
425        self.connect_timeout = Some(timeout);
426        self
427    }
428
429    /// A concurrency limit to each request.
430    pub fn concurrency_limit(mut self, limit: usize) -> Self {
431        self.concurrency_limit = Some(limit);
432        self
433    }
434
435    /// A rate limit to each request.
436    pub fn rate_limit(mut self, limit: u64, duration: Duration) -> Self {
437        self.rate_limit = Some((limit, duration));
438        self
439    }
440
441    /// Sets the SETTINGS_INITIAL_WINDOW_SIZE option for HTTP2 stream-level flow control.
442    /// Default is 65,535
443    pub fn initial_stream_window_size(mut self, size: u32) -> Self {
444        self.initial_stream_window_size = Some(size);
445        self
446    }
447
448    /// Sets the max connection-level flow control for HTTP2
449    ///
450    /// Default is 65,535
451    pub fn initial_connection_window_size(mut self, size: u32) -> Self {
452        self.initial_connection_window_size = Some(size);
453        self
454    }
455
456    /// Set http2 KEEP_ALIVE_INTERVAL. Uses hyper’s default otherwise.
457    pub fn http2_keep_alive_interval(mut self, duration: Duration) -> Self {
458        self.http2_keep_alive_interval = Some(duration);
459        self
460    }
461
462    /// Set http2 KEEP_ALIVE_TIMEOUT. Uses hyper’s default otherwise.
463    pub fn http2_keep_alive_timeout(mut self, duration: Duration) -> Self {
464        self.http2_keep_alive_timeout = Some(duration);
465        self
466    }
467
468    /// Set http2 KEEP_ALIVE_WHILE_IDLE. Uses hyper’s default otherwise.
469    pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> Self {
470        self.http2_keep_alive_while_idle = Some(enabled);
471        self
472    }
473
474    /// Sets whether to use an adaptive flow control. Uses hyper’s default otherwise.
475    pub fn http2_adaptive_window(mut self, enabled: bool) -> Self {
476        self.http2_adaptive_window = Some(enabled);
477        self
478    }
479
480    /// Set whether TCP keepalive messages are enabled on accepted connections.
481    ///
482    /// If None is specified, keepalive is disabled, otherwise the duration specified
483    /// will be the time to remain idle before sending TCP keepalive probes.
484    ///
485    /// Default is no keepalive (None)
486    pub fn tcp_keepalive(mut self, duration: Duration) -> Self {
487        self.tcp_keepalive = Some(duration);
488        self
489    }
490
491    /// Set the value of TCP_NODELAY option for accepted connections.
492    ///
493    /// Enabled by default.
494    pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
495        self.tcp_nodelay = enabled;
496        self
497    }
498
499    /// Set the value of tls client auth.
500    ///
501    /// Disabled by default.
502    pub fn client_tls_config(mut self, client_tls_option: ClientTlsOption) -> Self {
503        self.client_tls = Some(client_tls_option);
504        self
505    }
506}
507
/// A cached channel together with its bookkeeping.
#[derive(Debug)]
pub struct Channel {
    /// The tonic channel handed out (cloned) to callers.
    channel: InnerChannel,
    /// Number of accesses since the last recycle sweep.
    ///
    /// A sweep swaps it to zero and evicts the channel if it was already zero.
    access: AtomicUsize,
    /// `false` when the channel was created by `ChannelManager::reset_with_connector`.
    use_default_connector: bool,
}
514
515impl Channel {
516    #[inline]
517    pub fn access(&self) -> usize {
518        self.access.load(Ordering::Relaxed)
519    }
520
521    #[inline]
522    pub fn use_default_connector(&self) -> bool {
523        self.use_default_connector
524    }
525
526    #[inline]
527    pub fn increase_access(&self) {
528        let _ = self.access.fetch_add(1, Ordering::Relaxed);
529    }
530}
531
/// Address-keyed cache of channels, shared by a `ChannelManager` and its recycle task.
#[derive(Debug, Default)]
struct Pool {
    channels: DashMap<String, Channel>,
}
536
537impl Pool {
538    fn get(&self, addr: &str) -> Option<InnerChannel> {
539        let channel = self.channels.get(addr);
540        channel.map(|ch| {
541            ch.increase_access();
542            ch.channel.clone()
543        })
544    }
545
546    fn entry(&self, addr: String) -> Entry<'_, String, Channel> {
547        self.channels.entry(addr)
548    }
549
550    #[cfg(test)]
551    fn get_access(&self, addr: &str) -> Option<usize> {
552        let channel = self.channels.get(addr);
553        channel.map(|ch| ch.access())
554    }
555
556    fn put(&self, addr: &str, channel: Channel) {
557        let _ = self.channels.insert(addr.to_string(), channel);
558    }
559
560    fn retain_channel<F>(&self, f: F)
561    where
562        F: FnMut(&String, &mut Channel) -> bool,
563    {
564        self.channels.retain(f);
565    }
566}
567
568async fn recycle_channel_in_loop(
569    pool: Arc<Pool>,
570    id: u64,
571    cancel: CancellationToken,
572    interval_secs: u64,
573) {
574    let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
575
576    loop {
577        tokio::select! {
578            _ = cancel.cancelled() => {
579                info!("Stop channel recycle, ChannelManager id: {}", id);
580                break;
581            },
582            _ = interval.tick() => {}
583        }
584
585        pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
586    }
587}
588
#[cfg(test)]
mod tests {
    use tower::service_fn;

    use super::*;

    /// An address that already includes a scheme gets another scheme prepended and is
    /// expected to be rejected.
    #[should_panic]
    #[test]
    fn test_invalid_addr() {
        let mgr = ChannelManager::default();
        let addr = "http://test";

        let _ = mgr.get(addr).unwrap();
    }

    /// Every `get` — through either the read path or the entry path — bumps the counter once.
    #[tokio::test]
    async fn test_access_count() {
        let mgr = ChannelManager::new();
        // Do not start recycle
        mgr.inner
            .channel_recycle_started
            .store(true, Ordering::Relaxed);
        let mgr = Arc::new(mgr);
        let addr = "test_uri";

        let mut joins = Vec::with_capacity(10);
        for _ in 0..10 {
            let mgr_clone = mgr.clone();
            let join = tokio::spawn(async move {
                for _ in 0..100 {
                    let _ = mgr_clone.get(addr);
                }
            });
            joins.push(join);
        }
        for join in joins {
            join.await.unwrap();
        }

        assert_eq!(1000, mgr.pool().get_access(addr).unwrap());

        mgr.pool()
            .retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0);

        assert_eq!(0, mgr.pool().get_access(addr).unwrap());
    }

    #[test]
    fn test_config() {
        let default_cfg = ChannelConfig::new();
        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
                connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
                concurrency_limit: None,
                rate_limit: None,
                initial_stream_window_size: None,
                initial_connection_window_size: None,
                http2_keep_alive_interval: Some(Duration::from_secs(30)),
                http2_keep_alive_timeout: None,
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: None,
                tcp_keepalive: None,
                tcp_nodelay: true,
                client_tls: None,
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            default_cfg
        );

        let cfg = default_cfg
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(false)
            .client_tls_config(ClientTlsOption {
                enabled: true,
                server_ca_cert_path: Some("some_server_path".to_string()),
                client_cert_path: Some("some_cert_path".to_string()),
                client_key_path: Some("some_key_path".to_string()),
                watch: false,
            });

        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(3)),
                connect_timeout: Some(Duration::from_secs(5)),
                concurrency_limit: Some(6),
                rate_limit: Some((5, Duration::from_secs(1))),
                initial_stream_window_size: Some(10),
                initial_connection_window_size: Some(20),
                http2_keep_alive_interval: Some(Duration::from_secs(1)),
                http2_keep_alive_timeout: Some(Duration::from_secs(3)),
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: Some(true),
                tcp_keepalive: Some(Duration::from_secs(2)),
                tcp_nodelay: false,
                client_tls: Some(ClientTlsOption {
                    enabled: true,
                    server_ca_cert_path: Some("some_server_path".to_string()),
                    client_cert_path: Some("some_cert_path".to_string()),
                    client_key_path: Some("some_key_path".to_string()),
                    watch: false,
                }),
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            cfg
        );
    }

    #[test]
    fn test_build_endpoint() {
        let config = ChannelConfig::new()
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(true);
        let mgr = ChannelManager::with_config(config, None);

        let res = mgr.build_endpoint("test_addr");

        let _ = res.unwrap();
    }

    /// `reset_with_connector` replaces the default-connector channel cached by `get`.
    #[tokio::test]
    async fn test_channel_with_connector() {
        let mgr = ChannelManager::new();

        let addr = "test_addr";
        let res = mgr.get(addr);
        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(channel.use_default_connector());
            true
        });

        let (client, _) = tokio::io::duplex(1024);
        let mut client = Some(hyper_util::rt::TokioIo::new(client));
        let res = mgr.reset_with_connector(
            addr,
            service_fn(move |_| {
                let client = client.take().unwrap();
                async move { Ok::<_, std::io::Error>(client) }
            }),
        );

        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(!channel.use_default_connector());
            true
        });
    }

    /// Dropping the last manager cancels the recycle task, which then releases its pool handle.
    #[tokio::test]
    async fn test_pool_release_with_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        // start channel recycle task
        let addr = "test_addr";
        let _ = mgr.get(addr);

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr);

        // The recycle task is spawned via `common_runtime::spawn_global` and drops its pool
        // handle only after observing the cancellation. Poll with a generous deadline instead
        // of a single fixed sleep, which is flaky on loaded machines.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
        while Arc::strong_count(&pool_holder) > 1 && tokio::time::Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }

    /// Without a recycle task, dropping the manager releases the pool immediately.
    #[tokio::test]
    async fn test_pool_release_without_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr);

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }
}