1use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use common_base::readable_size::ReadableSize;
20use common_telemetry::info;
21use dashmap::mapref::entry::Entry;
22use dashmap::DashMap;
23use lazy_static::lazy_static;
24use serde::{Deserialize, Serialize};
25use snafu::{OptionExt, ResultExt};
26use tokio_util::sync::CancellationToken;
27use tonic::transport::{
28 Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
29};
30use tower::Service;
31
32use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, InvalidTlsConfigSnafu, Result};
33
/// Period, in seconds, handed to the background channel recycle task.
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
/// Request timeout, in seconds, used by `ChannelConfig::default`.
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
/// Connect timeout, in seconds, used by `ChannelConfig::default`.
pub const DEFAULT_GRPC_CONNECT_TIMEOUT_SECS: u64 = 1;
/// Value of `ChannelConfig::max_recv_message_size` used by `ChannelConfig::default`.
pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
/// Value of `ChannelConfig::max_send_message_size` used by `ChannelConfig::default`.
pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
39
// Process-wide counter; `Inner::with_config` takes the next value as the id of
// the manager it is building.
lazy_static! {
    static ref ID: AtomicU64 = AtomicU64::new(0);
}
43
/// Hands out gRPC channels, pooled by address.
///
/// All clones of a manager share the same `Inner` (and therefore the same
/// pool) through an `Arc`.
#[derive(Clone, Debug, Default)]
pub struct ChannelManager {
    inner: Arc<Inner>,
}
48
/// State shared by every clone of a `ChannelManager`.
#[derive(Debug)]
struct Inner {
    /// Value taken from the global `ID` counter; it appears in the log messages
    /// of the recycle task.
    id: u64,
    /// Settings applied to every endpoint built by the manager.
    config: ChannelConfig,
    /// Set only by `ChannelManager::with_tls_config` when TLS is enabled.
    client_tls_config: Option<ClientTlsConfig>,
    /// Channels keyed by address; the recycle task holds its own `Arc` to it.
    pool: Arc<Pool>,
    /// Flipped to `true` by the first call that spawns the recycle task.
    channel_recycle_started: AtomicBool,
    /// Cancelled in `Drop`; the recycle task watches a clone of this token.
    cancel: CancellationToken,
}
58
59impl Default for Inner {
60 fn default() -> Self {
61 Self::with_config(ChannelConfig::default())
62 }
63}
64
65impl Drop for Inner {
66 fn drop(&mut self) {
67 self.cancel.cancel();
69 }
70}
71
72impl Inner {
73 fn with_config(config: ChannelConfig) -> Self {
74 let id = ID.fetch_add(1, Ordering::Relaxed);
75 let pool = Arc::new(Pool::default());
76 let cancel = CancellationToken::new();
77
78 Self {
79 id,
80 config,
81 client_tls_config: None,
82 pool,
83 channel_recycle_started: AtomicBool::new(false),
84 cancel,
85 }
86 }
87}
88
89impl ChannelManager {
90 pub fn new() -> Self {
91 Default::default()
92 }
93
94 pub fn with_config(config: ChannelConfig) -> Self {
95 let inner = Inner::with_config(config);
96 Self {
97 inner: Arc::new(inner),
98 }
99 }
100
101 pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
103 let mut inner = Inner::with_config(config.clone());
104
105 let path_config = config.client_tls.context(InvalidTlsConfigSnafu {
107 msg: "no config input",
108 })?;
109
110 if !path_config.enabled {
111 return Ok(Self {
114 inner: Arc::new(inner),
115 });
116 }
117
118 let mut tls_config = ClientTlsConfig::new();
119
120 if let Some(server_ca) = path_config.server_ca_cert_path {
121 let server_root_ca_cert =
122 std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?;
123 let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
124 tls_config = tls_config.ca_certificate(server_root_ca_cert);
125 }
126
127 if let (Some(client_cert_path), Some(client_key_path)) =
128 (&path_config.client_cert_path, &path_config.client_key_path)
129 {
130 let client_cert =
131 std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?;
132 let client_key =
133 std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?;
134 let client_identity = Identity::from_pem(client_cert, client_key);
135 tls_config = tls_config.identity(client_identity);
136 }
137
138 inner.client_tls_config = Some(tls_config);
139
140 Ok(Self {
141 inner: Arc::new(inner),
142 })
143 }
144
145 pub fn config(&self) -> &ChannelConfig {
146 &self.inner.config
147 }
148
149 fn pool(&self) -> &Arc<Pool> {
150 &self.inner.pool
151 }
152
153 pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
154 self.trigger_channel_recycling();
155
156 let addr = addr.as_ref();
157 if let Some(inner_ch) = self.pool().get(addr) {
159 return Ok(inner_ch);
160 }
161
162 let entry = match self.pool().entry(addr.to_string()) {
164 Entry::Occupied(entry) => {
165 entry.get().increase_access();
166 entry.into_ref()
167 }
168 Entry::Vacant(entry) => {
169 let endpoint = self.build_endpoint(addr)?;
170 let inner_channel = endpoint.connect_lazy();
171
172 let channel = Channel {
173 channel: inner_channel,
174 access: AtomicUsize::new(1),
175 use_default_connector: true,
176 };
177 entry.insert(channel)
178 }
179 };
180 Ok(entry.channel.clone())
181 }
182
183 pub fn reset_with_connector<C>(
184 &self,
185 addr: impl AsRef<str>,
186 connector: C,
187 ) -> Result<InnerChannel>
188 where
189 C: Service<Uri> + Send + 'static,
190 C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin,
191 C::Future: Send + 'static,
192 Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
193 {
194 let addr = addr.as_ref();
195 let endpoint = self.build_endpoint(addr)?;
196 let inner_channel = endpoint.connect_with_connector_lazy(connector);
197 let channel = Channel {
198 channel: inner_channel.clone(),
199 access: AtomicUsize::new(1),
200 use_default_connector: false,
201 };
202 self.pool().put(addr, channel);
203
204 Ok(inner_channel)
205 }
206
207 pub fn retain_channel<F>(&self, f: F)
208 where
209 F: FnMut(&String, &mut Channel) -> bool,
210 {
211 self.pool().retain_channel(f);
212 }
213
214 fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
215 let http_prefix = if self.inner.client_tls_config.is_some() {
216 "https"
217 } else {
218 "http"
219 };
220
221 let mut endpoint = Endpoint::new(format!("{http_prefix}://{addr}"))
222 .context(CreateChannelSnafu { addr })?;
223
224 if let Some(dur) = self.config().timeout {
225 endpoint = endpoint.timeout(dur);
226 }
227 if let Some(dur) = self.config().connect_timeout {
228 endpoint = endpoint.connect_timeout(dur);
229 }
230 if let Some(limit) = self.config().concurrency_limit {
231 endpoint = endpoint.concurrency_limit(limit);
232 }
233 if let Some((limit, dur)) = self.config().rate_limit {
234 endpoint = endpoint.rate_limit(limit, dur);
235 }
236 if let Some(size) = self.config().initial_stream_window_size {
237 endpoint = endpoint.initial_stream_window_size(size);
238 }
239 if let Some(size) = self.config().initial_connection_window_size {
240 endpoint = endpoint.initial_connection_window_size(size);
241 }
242 if let Some(dur) = self.config().http2_keep_alive_interval {
243 endpoint = endpoint.http2_keep_alive_interval(dur);
244 }
245 if let Some(dur) = self.config().http2_keep_alive_timeout {
246 endpoint = endpoint.keep_alive_timeout(dur);
247 }
248 if let Some(enabled) = self.config().http2_keep_alive_while_idle {
249 endpoint = endpoint.keep_alive_while_idle(enabled);
250 }
251 if let Some(enabled) = self.config().http2_adaptive_window {
252 endpoint = endpoint.http2_adaptive_window(enabled);
253 }
254 if let Some(tls_config) = &self.inner.client_tls_config {
255 endpoint = endpoint
256 .tls_config(tls_config.clone())
257 .context(CreateChannelSnafu { addr })?;
258 }
259
260 endpoint = endpoint
261 .tcp_keepalive(self.config().tcp_keepalive)
262 .tcp_nodelay(self.config().tcp_nodelay);
263
264 Ok(endpoint)
265 }
266
267 fn trigger_channel_recycling(&self) {
268 if self
269 .inner
270 .channel_recycle_started
271 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
272 .is_err()
273 {
274 return;
275 }
276
277 let pool = self.pool().clone();
278 let cancel = self.inner.cancel.clone();
279 let id = self.inner.id;
280 let _handle = common_runtime::spawn_global(async move {
281 recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
282 });
283 info!(
284 "ChannelManager: {}, channel recycle is started, running in the background!",
285 self.inner.id
286 );
287 }
288}
289
/// Client-side TLS settings, expressed as file system paths.
///
/// The files are read by `ChannelManager::with_tls_config`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientTlsOption {
    /// When `false`, `ChannelManager::with_tls_config` reads no file and
    /// builds a manager without TLS.
    pub enabled: bool,
    /// Path of the PEM file used as the CA certificate, if any.
    pub server_ca_cert_path: Option<String>,
    /// Path of the PEM client certificate; only used together with
    /// `client_key_path`.
    pub client_cert_path: Option<String>,
    /// Path of the PEM client key; only used together with
    /// `client_cert_path`.
    pub client_key_path: Option<String>,
}
298
/// Settings applied to the endpoints built by a `ChannelManager`.
///
/// An `Option` field left as `None` means the corresponding endpoint setter is
/// not called at all in `ChannelManager::build_endpoint`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelConfig {
    /// Passed to `Endpoint::timeout`.
    pub timeout: Option<Duration>,
    /// Passed to `Endpoint::connect_timeout`.
    pub connect_timeout: Option<Duration>,
    /// Passed to `Endpoint::concurrency_limit`.
    pub concurrency_limit: Option<usize>,
    /// `(limit, duration)` pair passed to `Endpoint::rate_limit`.
    pub rate_limit: Option<(u64, Duration)>,
    /// Passed to `Endpoint::initial_stream_window_size`.
    pub initial_stream_window_size: Option<u32>,
    /// Passed to `Endpoint::initial_connection_window_size`.
    pub initial_connection_window_size: Option<u32>,
    /// Passed to `Endpoint::http2_keep_alive_interval`.
    pub http2_keep_alive_interval: Option<Duration>,
    /// Passed to `Endpoint::keep_alive_timeout`.
    pub http2_keep_alive_timeout: Option<Duration>,
    /// Passed to `Endpoint::keep_alive_while_idle`.
    pub http2_keep_alive_while_idle: Option<bool>,
    /// Passed to `Endpoint::http2_adaptive_window`.
    pub http2_adaptive_window: Option<bool>,
    /// Always passed, as an `Option`, to `Endpoint::tcp_keepalive`.
    pub tcp_keepalive: Option<Duration>,
    /// Always passed to `Endpoint::tcp_nodelay`.
    pub tcp_nodelay: bool,
    /// TLS settings; only read by `ChannelManager::with_tls_config`.
    pub client_tls: Option<ClientTlsOption>,
    // NOTE(review): the four fields below are not read anywhere in this file;
    // presumably they are consumed by the gRPC clients built on top of these
    // channels - confirm against the callers.
    pub max_recv_message_size: ReadableSize,
    pub max_send_message_size: ReadableSize,
    pub send_compression: bool,
    pub accept_compression: bool,
}
321
322impl Default for ChannelConfig {
323 fn default() -> Self {
324 Self {
325 timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
326 connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
327 concurrency_limit: None,
328 rate_limit: None,
329 initial_stream_window_size: None,
330 initial_connection_window_size: None,
331 http2_keep_alive_interval: Some(Duration::from_secs(30)),
332 http2_keep_alive_timeout: None,
333 http2_keep_alive_while_idle: Some(true),
334 http2_adaptive_window: None,
335 tcp_keepalive: None,
336 tcp_nodelay: true,
337 client_tls: None,
338 max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
339 max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
340 send_compression: false,
341 accept_compression: false,
342 }
343 }
344}
345
346impl ChannelConfig {
347 pub fn new() -> Self {
348 Default::default()
349 }
350
351 pub fn timeout(mut self, timeout: Duration) -> Self {
353 self.timeout = Some(timeout);
354 self
355 }
356
357 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
361 self.connect_timeout = Some(timeout);
362 self
363 }
364
365 pub fn concurrency_limit(mut self, limit: usize) -> Self {
367 self.concurrency_limit = Some(limit);
368 self
369 }
370
371 pub fn rate_limit(mut self, limit: u64, duration: Duration) -> Self {
373 self.rate_limit = Some((limit, duration));
374 self
375 }
376
377 pub fn initial_stream_window_size(mut self, size: u32) -> Self {
380 self.initial_stream_window_size = Some(size);
381 self
382 }
383
384 pub fn initial_connection_window_size(mut self, size: u32) -> Self {
388 self.initial_connection_window_size = Some(size);
389 self
390 }
391
392 pub fn http2_keep_alive_interval(mut self, duration: Duration) -> Self {
394 self.http2_keep_alive_interval = Some(duration);
395 self
396 }
397
398 pub fn http2_keep_alive_timeout(mut self, duration: Duration) -> Self {
400 self.http2_keep_alive_timeout = Some(duration);
401 self
402 }
403
404 pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> Self {
406 self.http2_keep_alive_while_idle = Some(enabled);
407 self
408 }
409
410 pub fn http2_adaptive_window(mut self, enabled: bool) -> Self {
412 self.http2_adaptive_window = Some(enabled);
413 self
414 }
415
416 pub fn tcp_keepalive(mut self, duration: Duration) -> Self {
423 self.tcp_keepalive = Some(duration);
424 self
425 }
426
427 pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
431 self.tcp_nodelay = enabled;
432 self
433 }
434
435 pub fn client_tls_config(mut self, client_tls_option: ClientTlsOption) -> Self {
439 self.client_tls = Some(client_tls_option);
440 self
441 }
442}
443
/// A pooled channel together with its bookkeeping.
#[derive(Debug)]
pub struct Channel {
    /// The underlying tonic channel handed out to callers (as a clone).
    channel: InnerChannel,
    /// Number of accesses since the counter was last reset by a recycle pass.
    access: AtomicUsize,
    /// `true` for channels created by `ChannelManager::get`, `false` for
    /// those created by `ChannelManager::reset_with_connector`.
    use_default_connector: bool,
}
450
451impl Channel {
452 #[inline]
453 pub fn access(&self) -> usize {
454 self.access.load(Ordering::Relaxed)
455 }
456
457 #[inline]
458 pub fn use_default_connector(&self) -> bool {
459 self.use_default_connector
460 }
461
462 #[inline]
463 pub fn increase_access(&self) {
464 let _ = self.access.fetch_add(1, Ordering::Relaxed);
465 }
466}
467
/// Concurrent map from address to pooled channel.
#[derive(Debug, Default)]
struct Pool {
    channels: DashMap<String, Channel>,
}
472
473impl Pool {
474 fn get(&self, addr: &str) -> Option<InnerChannel> {
475 let channel = self.channels.get(addr);
476 channel.map(|ch| {
477 ch.increase_access();
478 ch.channel.clone()
479 })
480 }
481
482 fn entry(&self, addr: String) -> Entry<String, Channel> {
483 self.channels.entry(addr)
484 }
485
486 #[cfg(test)]
487 fn get_access(&self, addr: &str) -> Option<usize> {
488 let channel = self.channels.get(addr);
489 channel.map(|ch| ch.access())
490 }
491
492 fn put(&self, addr: &str, channel: Channel) {
493 let _ = self.channels.insert(addr.to_string(), channel);
494 }
495
496 fn retain_channel<F>(&self, f: F)
497 where
498 F: FnMut(&String, &mut Channel) -> bool,
499 {
500 self.channels.retain(f);
501 }
502}
503
504async fn recycle_channel_in_loop(
505 pool: Arc<Pool>,
506 id: u64,
507 cancel: CancellationToken,
508 interval_secs: u64,
509) {
510 let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
511
512 loop {
513 tokio::select! {
514 _ = cancel.cancelled() => {
515 info!("Stop channel recycle, ChannelManager id: {}", id);
516 break;
517 },
518 _ = interval.tick() => {}
519 }
520
521 pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
522 }
523}
524
#[cfg(test)]
mod tests {
    use tower::service_fn;

    use super::*;

    // `build_endpoint` prepends its own scheme, so an address that already
    // carries one is expected to make `get` fail and the `unwrap` panic.
    #[should_panic]
    #[test]
    fn test_invalid_addr() {
        let mgr = ChannelManager::default();
        let addr = "http://test";

        let _ = mgr.get(addr).unwrap();
    }

    // Ten tasks calling `get` a hundred times each must add up to exactly
    // 1000 accesses on the single pooled channel.
    #[tokio::test]
    async fn test_access_count() {
        let mgr = ChannelManager::new();
        // Pretend the recycle task is already running so that `get` does not
        // spawn it and nothing resets the counter behind the test's back.
        mgr.inner
            .channel_recycle_started
            .store(true, Ordering::Relaxed);
        let mgr = Arc::new(mgr);
        let addr = "test_uri";

        let mut joins = Vec::with_capacity(10);
        for _ in 0..10 {
            let mgr_clone = mgr.clone();
            let join = tokio::spawn(async move {
                for _ in 0..100 {
                    let _ = mgr_clone.get(addr);
                }
            });
            joins.push(join);
        }
        for join in joins {
            join.await.unwrap();
        }

        assert_eq!(1000, mgr.pool().get_access(addr).unwrap());

        // Same closure as the recycle task: the channel is kept (its counter
        // was non-zero) and the counter is reset.
        mgr.pool()
            .retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0);

        assert_eq!(0, mgr.pool().get_access(addr).unwrap());
    }

    // Checks the default values and that every builder method sets exactly
    // its own field.
    #[test]
    fn test_config() {
        let default_cfg = ChannelConfig::new();
        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
                connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
                concurrency_limit: None,
                rate_limit: None,
                initial_stream_window_size: None,
                initial_connection_window_size: None,
                http2_keep_alive_interval: Some(Duration::from_secs(30)),
                http2_keep_alive_timeout: None,
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: None,
                tcp_keepalive: None,
                tcp_nodelay: true,
                client_tls: None,
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            default_cfg
        );

        let cfg = default_cfg
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(false)
            .client_tls_config(ClientTlsOption {
                enabled: true,
                server_ca_cert_path: Some("some_server_path".to_string()),
                client_cert_path: Some("some_cert_path".to_string()),
                client_key_path: Some("some_key_path".to_string()),
            });

        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(3)),
                connect_timeout: Some(Duration::from_secs(5)),
                concurrency_limit: Some(6),
                rate_limit: Some((5, Duration::from_secs(1))),
                initial_stream_window_size: Some(10),
                initial_connection_window_size: Some(20),
                http2_keep_alive_interval: Some(Duration::from_secs(1)),
                http2_keep_alive_timeout: Some(Duration::from_secs(3)),
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: Some(true),
                tcp_keepalive: Some(Duration::from_secs(2)),
                tcp_nodelay: false,
                client_tls: Some(ClientTlsOption {
                    enabled: true,
                    server_ca_cert_path: Some("some_server_path".to_string()),
                    client_cert_path: Some("some_cert_path".to_string()),
                    client_key_path: Some("some_key_path".to_string()),
                }),
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            cfg
        );
    }

    // An endpoint with every (non-TLS) option set must build successfully.
    #[test]
    fn test_build_endpoint() {
        let config = ChannelConfig::new()
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(true);
        let mgr = ChannelManager::with_config(config);

        let res = mgr.build_endpoint("test_addr");

        let _ = res.unwrap();
    }

    // `get` pools a default-connector channel; `reset_with_connector` must
    // replace it with a custom-connector one under the same address.
    #[tokio::test]
    async fn test_channel_with_connector() {
        let mgr = ChannelManager::new();

        let addr = "test_addr";
        let res = mgr.get(addr);
        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(channel.use_default_connector());
            true
        });

        let (client, _) = tokio::io::duplex(1024);
        let mut client = Some(hyper_util::rt::TokioIo::new(client));
        let res = mgr.reset_with_connector(
            addr,
            service_fn(move |_| {
                let client = client.take().unwrap();
                async move { Ok::<_, std::io::Error>(client) }
            }),
        );

        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(!channel.use_default_connector());
            true
        });
    }

    // With the recycle task running, the pool has three owners: the manager's
    // `Inner`, `pool_holder`, and the task. Dropping the last manager must
    // cancel the task so that only `pool_holder` remains.
    #[tokio::test]
    async fn test_pool_release_with_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        let addr = "test_addr";
        // The first `get` starts the recycle task, which clones the pool `Arc`.
        let _ = mgr.get(addr);

        // Manager clones share one `Inner`, so they do not add pool owners.
        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr);

        // Give the recycle task time to observe the cancellation and exit.
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }

    // Without any `get`, no recycle task exists: the pool is owned only by the
    // manager's `Inner` and `pool_holder`, and is released synchronously.
    #[tokio::test]
    async fn test_pool_release_without_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr);

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }
}