1use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
17use std::time::Duration;
18
19use common_base::readable_size::ReadableSize;
20use common_telemetry::info;
21use dashmap::DashMap;
22use dashmap::mapref::entry::Entry;
23use lazy_static::lazy_static;
24use serde::{Deserialize, Serialize};
25use snafu::ResultExt;
26use tokio_util::sync::CancellationToken;
27use tonic::transport::{
28 Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
29};
30use tower::Service;
31
32use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, Result};
33
/// Period, in seconds, handed to the background channel-recycling loop.
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
/// Request timeout, in seconds, used by `ChannelConfig::default`.
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
/// Connect timeout, in seconds, used by `ChannelConfig::default`.
pub const DEFAULT_GRPC_CONNECT_TIMEOUT_SECS: u64 = 1;
/// Value of `ChannelConfig::max_recv_message_size` used by `ChannelConfig::default`.
pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
/// Value of `ChannelConfig::max_send_message_size` used by `ChannelConfig::default`.
pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
39
lazy_static! {
    // Process-wide counter; `Inner::with_config` takes the next value as the
    // manager id, which is only used in log messages in this file.
    static ref ID: AtomicU64 = AtomicU64::new(0);
}
43
/// Hands out gRPC channels that are pooled per target address.
///
/// All clones share the same `Inner` (and therefore the same pool) through an
/// `Arc`, so cloning does not duplicate any channel.
#[derive(Clone, Debug, Default)]
pub struct ChannelManager {
    inner: Arc<Inner>,
}
48
/// Shared state behind every clone of a `ChannelManager`.
#[derive(Debug)]
struct Inner {
    // Identifier taken from the global `ID` counter; used in log messages.
    id: u64,
    // Settings applied to every endpoint built by the manager.
    config: ChannelConfig,
    // When `Some`, endpoints use the `https` scheme and this TLS configuration.
    client_tls_config: Option<ClientTlsConfig>,
    // Address -> channel map; also shared with the background recycle task.
    pool: Arc<Pool>,
    // Set to `true` once the background recycle task has been spawned.
    channel_recycle_started: AtomicBool,
    // Cancelled in `Drop` so the background recycle task stops.
    cancel: CancellationToken,
}
58
59impl Default for Inner {
60 fn default() -> Self {
61 Self::with_config(ChannelConfig::default())
62 }
63}
64
65impl Drop for Inner {
66 fn drop(&mut self) {
67 self.cancel.cancel();
69 }
70}
71
72impl Inner {
73 fn with_config(config: ChannelConfig) -> Self {
74 let id = ID.fetch_add(1, Ordering::Relaxed);
75 let pool = Arc::new(Pool::default());
76 let cancel = CancellationToken::new();
77
78 Self {
79 id,
80 config,
81 client_tls_config: None,
82 pool,
83 channel_recycle_started: AtomicBool::new(false),
84 cancel,
85 }
86 }
87}
88
89impl ChannelManager {
90 pub fn new() -> Self {
91 Default::default()
92 }
93
94 pub fn with_config(config: ChannelConfig, tls_config: Option<ClientTlsConfig>) -> Self {
97 let mut inner = Inner::with_config(config.clone());
98 if let Some(tls_config) = tls_config {
99 inner.client_tls_config = Some(tls_config);
100 }
101 Self {
102 inner: Arc::new(inner),
103 }
104 }
105
106 pub fn config(&self) -> &ChannelConfig {
107 &self.inner.config
108 }
109
110 fn pool(&self) -> &Arc<Pool> {
111 &self.inner.pool
112 }
113
114 pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
115 self.trigger_channel_recycling();
116
117 let addr = addr.as_ref();
118 if let Some(inner_ch) = self.pool().get(addr) {
120 return Ok(inner_ch);
121 }
122
123 let entry = match self.pool().entry(addr.to_string()) {
125 Entry::Occupied(entry) => {
126 entry.get().increase_access();
127 entry.into_ref()
128 }
129 Entry::Vacant(entry) => {
130 let endpoint = self.build_endpoint(addr)?;
131 let inner_channel = endpoint.connect_lazy();
132
133 let channel = Channel {
134 channel: inner_channel,
135 access: AtomicUsize::new(1),
136 use_default_connector: true,
137 };
138 entry.insert(channel)
139 }
140 };
141 Ok(entry.channel.clone())
142 }
143
144 pub fn reset_with_connector<C>(
145 &self,
146 addr: impl AsRef<str>,
147 connector: C,
148 ) -> Result<InnerChannel>
149 where
150 C: Service<Uri> + Send + 'static,
151 C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin,
152 C::Future: Send + 'static,
153 Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
154 {
155 let addr = addr.as_ref();
156 let endpoint = self.build_endpoint(addr)?;
157 let inner_channel = endpoint.connect_with_connector_lazy(connector);
158 let channel = Channel {
159 channel: inner_channel.clone(),
160 access: AtomicUsize::new(1),
161 use_default_connector: false,
162 };
163 self.pool().put(addr, channel);
164
165 Ok(inner_channel)
166 }
167
168 pub fn retain_channel<F>(&self, f: F)
169 where
170 F: FnMut(&String, &mut Channel) -> bool,
171 {
172 self.pool().retain_channel(f);
173 }
174
175 fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
176 let http_prefix = if self.inner.client_tls_config.is_some() {
177 "https"
178 } else {
179 "http"
180 };
181
182 let mut endpoint = Endpoint::new(format!("{http_prefix}://{addr}"))
183 .context(CreateChannelSnafu { addr })?;
184
185 if let Some(dur) = self.config().timeout {
186 endpoint = endpoint.timeout(dur);
187 }
188 if let Some(dur) = self.config().connect_timeout {
189 endpoint = endpoint.connect_timeout(dur);
190 }
191 if let Some(limit) = self.config().concurrency_limit {
192 endpoint = endpoint.concurrency_limit(limit);
193 }
194 if let Some((limit, dur)) = self.config().rate_limit {
195 endpoint = endpoint.rate_limit(limit, dur);
196 }
197 if let Some(size) = self.config().initial_stream_window_size {
198 endpoint = endpoint.initial_stream_window_size(size);
199 }
200 if let Some(size) = self.config().initial_connection_window_size {
201 endpoint = endpoint.initial_connection_window_size(size);
202 }
203 if let Some(dur) = self.config().http2_keep_alive_interval {
204 endpoint = endpoint.http2_keep_alive_interval(dur);
205 }
206 if let Some(dur) = self.config().http2_keep_alive_timeout {
207 endpoint = endpoint.keep_alive_timeout(dur);
208 }
209 if let Some(enabled) = self.config().http2_keep_alive_while_idle {
210 endpoint = endpoint.keep_alive_while_idle(enabled);
211 }
212 if let Some(enabled) = self.config().http2_adaptive_window {
213 endpoint = endpoint.http2_adaptive_window(enabled);
214 }
215 if let Some(tls_config) = &self.inner.client_tls_config {
216 endpoint = endpoint
217 .tls_config(tls_config.clone())
218 .context(CreateChannelSnafu { addr })?;
219 }
220
221 endpoint = endpoint
222 .tcp_keepalive(self.config().tcp_keepalive)
223 .tcp_nodelay(self.config().tcp_nodelay);
224
225 Ok(endpoint)
226 }
227
228 fn trigger_channel_recycling(&self) {
229 if self
230 .inner
231 .channel_recycle_started
232 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
233 .is_err()
234 {
235 return;
236 }
237
238 let pool = self.pool().clone();
239 let cancel = self.inner.cancel.clone();
240 let id = self.inner.id;
241 let _handle = common_runtime::spawn_global(async move {
242 recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
243 });
244 info!(
245 "ChannelManager: {}, channel recycle is started, running in the background!",
246 self.inner.id
247 );
248 }
249}
250
251pub fn load_tls_config(tls_option: Option<&ClientTlsOption>) -> Result<Option<ClientTlsConfig>> {
252 let path_config = match tls_option {
253 Some(path_config) if path_config.enabled => path_config,
254 _ => return Ok(None),
255 };
256
257 let mut tls_config = ClientTlsConfig::new();
258
259 if let Some(server_ca) = &path_config.server_ca_cert_path {
260 let server_root_ca_cert =
261 std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?;
262 let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
263 tls_config = tls_config.ca_certificate(server_root_ca_cert);
264 }
265
266 if let (Some(client_cert_path), Some(client_key_path)) =
267 (&path_config.client_cert_path, &path_config.client_key_path)
268 {
269 let client_cert =
270 std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?;
271 let client_key =
272 std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?;
273 let client_identity = Identity::from_pem(client_cert, client_key);
274 tls_config = tls_config.identity(client_identity);
275 }
276 Ok(Some(tls_config))
277}
278
/// File-path based TLS settings consumed by `load_tls_config`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientTlsOption {
    /// When `false`, `load_tls_config` returns `None` and ignores the paths.
    pub enabled: bool,
    /// Path to a PEM file with the CA certificate used to verify the server.
    pub server_ca_cert_path: Option<String>,
    /// Path to the client certificate PEM file; only used together with
    /// `client_key_path`.
    pub client_cert_path: Option<String>,
    /// Path to the client key PEM file; only used together with
    /// `client_cert_path`.
    pub client_key_path: Option<String>,
}
287
/// Settings for the channels created by a `ChannelManager`.
///
/// Unless noted otherwise, each field is forwarded to the tonic `Endpoint`
/// setter used in `ChannelManager::build_endpoint`; `Option` fields are only
/// forwarded when they are `Some`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelConfig {
    /// Forwarded to `Endpoint::timeout`.
    pub timeout: Option<Duration>,
    /// Forwarded to `Endpoint::connect_timeout`.
    pub connect_timeout: Option<Duration>,
    /// Forwarded to `Endpoint::concurrency_limit`.
    pub concurrency_limit: Option<usize>,
    /// `(limit, duration)` pair forwarded to `Endpoint::rate_limit`.
    pub rate_limit: Option<(u64, Duration)>,
    /// Forwarded to `Endpoint::initial_stream_window_size`.
    pub initial_stream_window_size: Option<u32>,
    /// Forwarded to `Endpoint::initial_connection_window_size`.
    pub initial_connection_window_size: Option<u32>,
    /// Forwarded to `Endpoint::http2_keep_alive_interval`.
    pub http2_keep_alive_interval: Option<Duration>,
    /// Forwarded to `Endpoint::keep_alive_timeout`.
    pub http2_keep_alive_timeout: Option<Duration>,
    /// Forwarded to `Endpoint::keep_alive_while_idle`.
    pub http2_keep_alive_while_idle: Option<bool>,
    /// Forwarded to `Endpoint::http2_adaptive_window`.
    pub http2_adaptive_window: Option<bool>,
    /// Always forwarded (including `None`) to `Endpoint::tcp_keepalive`.
    pub tcp_keepalive: Option<Duration>,
    /// Always forwarded to `Endpoint::tcp_nodelay`.
    pub tcp_nodelay: bool,
    /// TLS file paths. NOTE(review): not read by the code visible in this
    /// file; `ChannelManager::with_config` takes the TLS config separately.
    pub client_tls: Option<ClientTlsOption>,
    /// NOTE(review): not read by the code visible in this file — presumably
    /// applied by the gRPC clients built on these channels; confirm.
    pub max_recv_message_size: ReadableSize,
    /// NOTE(review): not read by the code visible in this file — presumably
    /// applied by the gRPC clients built on these channels; confirm.
    pub max_send_message_size: ReadableSize,
    /// NOTE(review): not read by the code visible in this file.
    pub send_compression: bool,
    /// NOTE(review): not read by the code visible in this file.
    pub accept_compression: bool,
}
310
311impl Default for ChannelConfig {
312 fn default() -> Self {
313 Self {
314 timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
315 connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
316 concurrency_limit: None,
317 rate_limit: None,
318 initial_stream_window_size: None,
319 initial_connection_window_size: None,
320 http2_keep_alive_interval: Some(Duration::from_secs(30)),
321 http2_keep_alive_timeout: None,
322 http2_keep_alive_while_idle: Some(true),
323 http2_adaptive_window: None,
324 tcp_keepalive: None,
325 tcp_nodelay: true,
326 client_tls: None,
327 max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
328 max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
329 send_compression: false,
330 accept_compression: false,
331 }
332 }
333}
334
335impl ChannelConfig {
336 pub fn new() -> Self {
337 Default::default()
338 }
339
340 pub fn timeout(mut self, timeout: Duration) -> Self {
342 self.timeout = Some(timeout);
343 self
344 }
345
346 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
350 self.connect_timeout = Some(timeout);
351 self
352 }
353
354 pub fn concurrency_limit(mut self, limit: usize) -> Self {
356 self.concurrency_limit = Some(limit);
357 self
358 }
359
360 pub fn rate_limit(mut self, limit: u64, duration: Duration) -> Self {
362 self.rate_limit = Some((limit, duration));
363 self
364 }
365
366 pub fn initial_stream_window_size(mut self, size: u32) -> Self {
369 self.initial_stream_window_size = Some(size);
370 self
371 }
372
373 pub fn initial_connection_window_size(mut self, size: u32) -> Self {
377 self.initial_connection_window_size = Some(size);
378 self
379 }
380
381 pub fn http2_keep_alive_interval(mut self, duration: Duration) -> Self {
383 self.http2_keep_alive_interval = Some(duration);
384 self
385 }
386
387 pub fn http2_keep_alive_timeout(mut self, duration: Duration) -> Self {
389 self.http2_keep_alive_timeout = Some(duration);
390 self
391 }
392
393 pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> Self {
395 self.http2_keep_alive_while_idle = Some(enabled);
396 self
397 }
398
399 pub fn http2_adaptive_window(mut self, enabled: bool) -> Self {
401 self.http2_adaptive_window = Some(enabled);
402 self
403 }
404
405 pub fn tcp_keepalive(mut self, duration: Duration) -> Self {
412 self.tcp_keepalive = Some(duration);
413 self
414 }
415
416 pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
420 self.tcp_nodelay = enabled;
421 self
422 }
423
424 pub fn client_tls_config(mut self, client_tls_option: ClientTlsOption) -> Self {
428 self.client_tls = Some(client_tls_option);
429 self
430 }
431}
432
/// A pooled tonic channel together with its bookkeeping.
#[derive(Debug)]
pub struct Channel {
    // The underlying tonic channel handed out to callers (as clones).
    channel: InnerChannel,
    // Number of times the channel was handed out since the recycle loop last
    // reset the counter to zero.
    access: AtomicUsize,
    // `true` when created by `ChannelManager::get`, `false` when created by
    // `ChannelManager::reset_with_connector`.
    use_default_connector: bool,
}
439
440impl Channel {
441 #[inline]
442 pub fn access(&self) -> usize {
443 self.access.load(Ordering::Relaxed)
444 }
445
446 #[inline]
447 pub fn use_default_connector(&self) -> bool {
448 self.use_default_connector
449 }
450
451 #[inline]
452 pub fn increase_access(&self) {
453 let _ = self.access.fetch_add(1, Ordering::Relaxed);
454 }
455}
456
/// Concurrent map from target address to its pooled `Channel`.
#[derive(Debug, Default)]
struct Pool {
    // Keyed by the address string exactly as passed to `ChannelManager::get`
    // or `ChannelManager::reset_with_connector`.
    channels: DashMap<String, Channel>,
}
461
462impl Pool {
463 fn get(&self, addr: &str) -> Option<InnerChannel> {
464 let channel = self.channels.get(addr);
465 channel.map(|ch| {
466 ch.increase_access();
467 ch.channel.clone()
468 })
469 }
470
471 fn entry(&self, addr: String) -> Entry<'_, String, Channel> {
472 self.channels.entry(addr)
473 }
474
475 #[cfg(test)]
476 fn get_access(&self, addr: &str) -> Option<usize> {
477 let channel = self.channels.get(addr);
478 channel.map(|ch| ch.access())
479 }
480
481 fn put(&self, addr: &str, channel: Channel) {
482 let _ = self.channels.insert(addr.to_string(), channel);
483 }
484
485 fn retain_channel<F>(&self, f: F)
486 where
487 F: FnMut(&String, &mut Channel) -> bool,
488 {
489 self.channels.retain(f);
490 }
491}
492
493async fn recycle_channel_in_loop(
494 pool: Arc<Pool>,
495 id: u64,
496 cancel: CancellationToken,
497 interval_secs: u64,
498) {
499 let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
500
501 loop {
502 tokio::select! {
503 _ = cancel.cancelled() => {
504 info!("Stop channel recycle, ChannelManager id: {}", id);
505 break;
506 },
507 _ = interval.tick() => {}
508 }
509
510 pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
511 }
512}
513
// Unit tests for the channel manager, its pool bookkeeping and the
// `ChannelConfig` builder.
#[cfg(test)]
mod tests {
    use tower::service_fn;

    use super::*;

    // The manager prepends its own scheme to the address, so an address that
    // already carries one is expected to make `get` fail and `unwrap` panic.
    #[should_panic]
    #[test]
    fn test_invalid_addr() {
        let mgr = ChannelManager::default();
        let addr = "http://test";

        let _ = mgr.get(addr).unwrap();
    }

    // 10 tasks x 100 `get` calls must leave the access counter at 1000, and
    // one manual sweep must reset it to 0 while keeping the channel.
    #[tokio::test]
    async fn test_access_count() {
        let mgr = ChannelManager::new();
        // Mark the recycle task as already started so `get` never spawns it
        // and the counter is not reset behind the test's back.
        mgr.inner
            .channel_recycle_started
            .store(true, Ordering::Relaxed);
        let mgr = Arc::new(mgr);
        let addr = "test_uri";

        let mut joins = Vec::with_capacity(10);
        for _ in 0..10 {
            let mgr_clone = mgr.clone();
            let join = tokio::spawn(async move {
                for _ in 0..100 {
                    let _ = mgr_clone.get(addr);
                }
            });
            joins.push(join);
        }
        for join in joins {
            join.await.unwrap();
        }

        assert_eq!(1000, mgr.pool().get_access(addr).unwrap());

        // Same closure as the background recycle loop uses.
        mgr.pool()
            .retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0);

        assert_eq!(0, mgr.pool().get_access(addr).unwrap());
    }

    // Checks the default configuration and that every builder method sets
    // exactly its own field.
    #[test]
    fn test_config() {
        let default_cfg = ChannelConfig::new();
        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
                connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
                concurrency_limit: None,
                rate_limit: None,
                initial_stream_window_size: None,
                initial_connection_window_size: None,
                http2_keep_alive_interval: Some(Duration::from_secs(30)),
                http2_keep_alive_timeout: None,
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: None,
                tcp_keepalive: None,
                tcp_nodelay: true,
                client_tls: None,
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            default_cfg
        );

        let cfg = default_cfg
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(false)
            .client_tls_config(ClientTlsOption {
                enabled: true,
                server_ca_cert_path: Some("some_server_path".to_string()),
                client_cert_path: Some("some_cert_path".to_string()),
                client_key_path: Some("some_key_path".to_string()),
            });

        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(3)),
                connect_timeout: Some(Duration::from_secs(5)),
                concurrency_limit: Some(6),
                rate_limit: Some((5, Duration::from_secs(1))),
                initial_stream_window_size: Some(10),
                initial_connection_window_size: Some(20),
                http2_keep_alive_interval: Some(Duration::from_secs(1)),
                http2_keep_alive_timeout: Some(Duration::from_secs(3)),
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: Some(true),
                tcp_keepalive: Some(Duration::from_secs(2)),
                tcp_nodelay: false,
                client_tls: Some(ClientTlsOption {
                    enabled: true,
                    server_ca_cert_path: Some("some_server_path".to_string()),
                    client_cert_path: Some("some_cert_path".to_string()),
                    client_key_path: Some("some_key_path".to_string()),
                }),
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            cfg
        );
    }

    // An endpoint must build successfully when every optional setting is
    // populated and no TLS is configured.
    #[test]
    fn test_build_endpoint() {
        let config = ChannelConfig::new()
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(true);
        let mgr = ChannelManager::with_config(config, None);

        let res = mgr.build_endpoint("test_addr");

        let _ = res.unwrap();
    }

    // `get` must pool a default-connector channel, and
    // `reset_with_connector` must replace it with a custom-connector one.
    #[tokio::test]
    async fn test_channel_with_connector() {
        let mgr = ChannelManager::new();

        let addr = "test_addr";
        let res = mgr.get(addr);
        let _ = res.unwrap();

        // `retain_channel` is used here only to inspect the pool; returning
        // `true` keeps every entry.
        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(channel.use_default_connector());
            true
        });

        // In-memory duplex stream standing in for a real connection.
        let (client, _) = tokio::io::duplex(1024);
        let mut client = Some(hyper_util::rt::TokioIo::new(client));
        let res = mgr.reset_with_connector(
            addr,
            service_fn(move |_| {
                let client = client.take().unwrap();
                async move { Ok::<_, std::io::Error>(client) }
            }),
        );

        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(!channel.use_default_connector());
            true
        });
    }

    // When the recycle task is running it holds one extra reference to the
    // pool; dropping the last manager must stop the task and release it.
    #[tokio::test]
    async fn test_pool_release_with_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        let addr = "test_addr";
        // The first `get` spawns the recycle task, which clones the pool.
        let _ = mgr.get(addr);

        // Cloning the manager clones `Arc<Inner>`, not the pool itself, so
        // the count stays at: `Inner` + `pool_holder` + recycle task.
        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        // Dropping the last manager drops `Inner`, which cancels the task.
        drop(mgr);

        // Give the recycle task time to observe the cancellation and exit.
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }

    // Without any `get` call no recycle task exists, so the pool is released
    // as soon as the last manager is dropped.
    #[tokio::test]
    async fn test_pool_release_without_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        // Count is: `Inner` + `pool_holder`; manager clones do not add to it.
        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr);

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }
}