1use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18
19use common_base::readable_size::ReadableSize;
20use common_telemetry::info;
21use dashmap::mapref::entry::Entry;
22use dashmap::DashMap;
23use lazy_static::lazy_static;
24use snafu::{OptionExt, ResultExt};
25use tokio_util::sync::CancellationToken;
26use tonic::transport::{
27 Certificate, Channel as InnerChannel, ClientTlsConfig, Endpoint, Identity, Uri,
28};
29use tower::Service;
30
31use crate::error::{CreateChannelSnafu, InvalidConfigFilePathSnafu, InvalidTlsConfigSnafu, Result};
32
/// Seconds between two passes of the background channel-recycle task.
const RECYCLE_CHANNEL_INTERVAL_SECS: u64 = 60;
/// Default value, in seconds, of the `ChannelConfig::timeout` field.
pub const DEFAULT_GRPC_REQUEST_TIMEOUT_SECS: u64 = 10;
/// Default value, in seconds, of the `ChannelConfig::connect_timeout` field.
pub const DEFAULT_GRPC_CONNECT_TIMEOUT_SECS: u64 = 1;
/// Default value of the `ChannelConfig::max_recv_message_size` field.
pub const DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
/// Default value of the `ChannelConfig::max_send_message_size` field.
pub const DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE: ReadableSize = ReadableSize::mb(512);
38
// Process-wide counter that hands a distinct id to every `Inner`, i.e. to every
// `ChannelManager` built by a constructor (clones share their `Inner`). The id is
// only used to tag log messages.
lazy_static! {
    static ref ID: AtomicU64 = AtomicU64::new(0);
}
42
/// Caches gRPC channels keyed by address and hands out clones of them.
///
/// Cloning a `ChannelManager` is cheap: every clone shares the same `Inner`
/// (configuration, TLS settings and channel pool) through an `Arc`.
#[derive(Clone, Debug, Default)]
pub struct ChannelManager {
    inner: Arc<Inner>,
}
47
/// State shared by every clone of a [`ChannelManager`].
#[derive(Debug)]
struct Inner {
    /// Id taken from the global `ID` counter; used in log messages.
    id: u64,
    /// Settings applied to every endpoint this manager builds.
    config: ChannelConfig,
    /// TLS settings; when `Some`, endpoints are built with the `https` scheme.
    client_tls_config: Option<ClientTlsConfig>,
    /// Channel cache keyed by address; also held by the background recycle task.
    pool: Arc<Pool>,
    /// Flipped to `true` when the recycle task is spawned, so it is spawned at most once.
    channel_recycle_started: AtomicBool,
    /// Cancelled when `Inner` is dropped, which stops the recycle task.
    cancel: CancellationToken,
}
57
58impl Default for Inner {
59 fn default() -> Self {
60 Self::with_config(ChannelConfig::default())
61 }
62}
63
64impl Drop for Inner {
65 fn drop(&mut self) {
66 self.cancel.cancel();
68 }
69}
70
71impl Inner {
72 fn with_config(config: ChannelConfig) -> Self {
73 let id = ID.fetch_add(1, Ordering::Relaxed);
74 let pool = Arc::new(Pool::default());
75 let cancel = CancellationToken::new();
76
77 Self {
78 id,
79 config,
80 client_tls_config: None,
81 pool,
82 channel_recycle_started: AtomicBool::new(false),
83 cancel,
84 }
85 }
86}
87
88impl ChannelManager {
89 pub fn new() -> Self {
90 Default::default()
91 }
92
93 pub fn with_config(config: ChannelConfig) -> Self {
94 let inner = Inner::with_config(config);
95 Self {
96 inner: Arc::new(inner),
97 }
98 }
99
100 pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
101 let mut inner = Inner::with_config(config.clone());
102
103 let path_config = config.client_tls.context(InvalidTlsConfigSnafu {
105 msg: "no config input",
106 })?;
107
108 let server_root_ca_cert = std::fs::read_to_string(path_config.server_ca_cert_path)
109 .context(InvalidConfigFilePathSnafu)?;
110 let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
111 let client_cert = std::fs::read_to_string(path_config.client_cert_path)
112 .context(InvalidConfigFilePathSnafu)?;
113 let client_key = std::fs::read_to_string(path_config.client_key_path)
114 .context(InvalidConfigFilePathSnafu)?;
115 let client_identity = Identity::from_pem(client_cert, client_key);
116
117 inner.client_tls_config = Some(
118 ClientTlsConfig::new()
119 .ca_certificate(server_root_ca_cert)
120 .identity(client_identity),
121 );
122
123 Ok(Self {
124 inner: Arc::new(inner),
125 })
126 }
127
128 pub fn config(&self) -> &ChannelConfig {
129 &self.inner.config
130 }
131
132 fn pool(&self) -> &Arc<Pool> {
133 &self.inner.pool
134 }
135
136 pub fn get(&self, addr: impl AsRef<str>) -> Result<InnerChannel> {
137 self.trigger_channel_recycling();
138
139 let addr = addr.as_ref();
140 if let Some(inner_ch) = self.pool().get(addr) {
142 return Ok(inner_ch);
143 }
144
145 let entry = match self.pool().entry(addr.to_string()) {
147 Entry::Occupied(entry) => {
148 entry.get().increase_access();
149 entry.into_ref()
150 }
151 Entry::Vacant(entry) => {
152 let endpoint = self.build_endpoint(addr)?;
153 let inner_channel = endpoint.connect_lazy();
154
155 let channel = Channel {
156 channel: inner_channel,
157 access: AtomicUsize::new(1),
158 use_default_connector: true,
159 };
160 entry.insert(channel)
161 }
162 };
163 Ok(entry.channel.clone())
164 }
165
166 pub fn reset_with_connector<C>(
167 &self,
168 addr: impl AsRef<str>,
169 connector: C,
170 ) -> Result<InnerChannel>
171 where
172 C: Service<Uri> + Send + 'static,
173 C::Response: hyper::rt::Read + hyper::rt::Write + Send + Unpin,
174 C::Future: Send + 'static,
175 Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
176 {
177 let addr = addr.as_ref();
178 let endpoint = self.build_endpoint(addr)?;
179 let inner_channel = endpoint.connect_with_connector_lazy(connector);
180 let channel = Channel {
181 channel: inner_channel.clone(),
182 access: AtomicUsize::new(1),
183 use_default_connector: false,
184 };
185 self.pool().put(addr, channel);
186
187 Ok(inner_channel)
188 }
189
190 pub fn retain_channel<F>(&self, f: F)
191 where
192 F: FnMut(&String, &mut Channel) -> bool,
193 {
194 self.pool().retain_channel(f);
195 }
196
197 fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
198 let http_prefix = if self.inner.client_tls_config.is_some() {
199 "https"
200 } else {
201 "http"
202 };
203
204 let mut endpoint = Endpoint::new(format!("{http_prefix}://{addr}"))
205 .context(CreateChannelSnafu { addr })?;
206
207 if let Some(dur) = self.config().timeout {
208 endpoint = endpoint.timeout(dur);
209 }
210 if let Some(dur) = self.config().connect_timeout {
211 endpoint = endpoint.connect_timeout(dur);
212 }
213 if let Some(limit) = self.config().concurrency_limit {
214 endpoint = endpoint.concurrency_limit(limit);
215 }
216 if let Some((limit, dur)) = self.config().rate_limit {
217 endpoint = endpoint.rate_limit(limit, dur);
218 }
219 if let Some(size) = self.config().initial_stream_window_size {
220 endpoint = endpoint.initial_stream_window_size(size);
221 }
222 if let Some(size) = self.config().initial_connection_window_size {
223 endpoint = endpoint.initial_connection_window_size(size);
224 }
225 if let Some(dur) = self.config().http2_keep_alive_interval {
226 endpoint = endpoint.http2_keep_alive_interval(dur);
227 }
228 if let Some(dur) = self.config().http2_keep_alive_timeout {
229 endpoint = endpoint.keep_alive_timeout(dur);
230 }
231 if let Some(enabled) = self.config().http2_keep_alive_while_idle {
232 endpoint = endpoint.keep_alive_while_idle(enabled);
233 }
234 if let Some(enabled) = self.config().http2_adaptive_window {
235 endpoint = endpoint.http2_adaptive_window(enabled);
236 }
237 if let Some(tls_config) = &self.inner.client_tls_config {
238 endpoint = endpoint
239 .tls_config(tls_config.clone())
240 .context(CreateChannelSnafu { addr })?;
241 }
242
243 endpoint = endpoint
244 .tcp_keepalive(self.config().tcp_keepalive)
245 .tcp_nodelay(self.config().tcp_nodelay);
246
247 Ok(endpoint)
248 }
249
250 fn trigger_channel_recycling(&self) {
251 if self
252 .inner
253 .channel_recycle_started
254 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
255 .is_err()
256 {
257 return;
258 }
259
260 let pool = self.pool().clone();
261 let cancel = self.inner.cancel.clone();
262 let id = self.inner.id;
263 let _handle = common_runtime::spawn_global(async move {
264 recycle_channel_in_loop(pool, id, cancel, RECYCLE_CHANNEL_INTERVAL_SECS).await;
265 });
266 info!(
267 "ChannelManager: {}, channel recycle is started, running in the background!",
268 self.inner.id
269 );
270 }
271}
272
/// Filesystem paths of the PEM files read by `ChannelManager::with_tls_config`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientTlsOption {
    /// CA certificate used to verify the server's certificate.
    pub server_ca_cert_path: String,
    /// Client certificate, presented together with the key as the client identity.
    pub client_cert_path: String,
    /// Private key belonging to the certificate at `client_cert_path`.
    pub client_key_path: String,
}
279
/// Settings applied when building gRPC channels.
///
/// `ChannelManager::build_endpoint` maps the transport fields onto the tonic
/// `Endpoint` builder. Most `Option` fields are applied only when `Some`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelConfig {
    /// Per-request timeout (`Endpoint::timeout`).
    pub timeout: Option<Duration>,
    /// Connection-establishment timeout (`Endpoint::connect_timeout`).
    pub connect_timeout: Option<Duration>,
    /// Passed to `Endpoint::concurrency_limit`.
    pub concurrency_limit: Option<usize>,
    /// `(limit, duration)` pair passed to `Endpoint::rate_limit`.
    pub rate_limit: Option<(u64, Duration)>,
    /// Passed to `Endpoint::initial_stream_window_size`.
    pub initial_stream_window_size: Option<u32>,
    /// Passed to `Endpoint::initial_connection_window_size`.
    pub initial_connection_window_size: Option<u32>,
    /// Passed to `Endpoint::http2_keep_alive_interval`.
    pub http2_keep_alive_interval: Option<Duration>,
    /// Passed to `Endpoint::keep_alive_timeout`.
    pub http2_keep_alive_timeout: Option<Duration>,
    /// Passed to `Endpoint::keep_alive_while_idle`.
    pub http2_keep_alive_while_idle: Option<bool>,
    /// Passed to `Endpoint::http2_adaptive_window`.
    pub http2_adaptive_window: Option<bool>,
    /// Passed to `Endpoint::tcp_keepalive` unconditionally, including when `None`.
    pub tcp_keepalive: Option<Duration>,
    /// Passed to `Endpoint::tcp_nodelay`.
    pub tcp_nodelay: bool,
    /// PEM file paths. Only `ChannelManager::with_tls_config` reads this field.
    pub client_tls: Option<ClientTlsOption>,
    /// Not used by `build_endpoint`. NOTE(review): presumably the maximum decoded
    /// message size applied where gRPC clients are built; TODO confirm.
    pub max_recv_message_size: ReadableSize,
    /// Not used by `build_endpoint`. NOTE(review): presumably the maximum encoded
    /// message size applied where gRPC clients are built; TODO confirm.
    pub max_send_message_size: ReadableSize,
    /// Not used in this file. NOTE(review): presumably toggles request compression;
    /// TODO confirm.
    pub send_compression: bool,
    /// Not used in this file. NOTE(review): presumably toggles accepting compressed
    /// responses; TODO confirm.
    pub accept_compression: bool,
}
302
303impl Default for ChannelConfig {
304 fn default() -> Self {
305 Self {
306 timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
307 connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
308 concurrency_limit: None,
309 rate_limit: None,
310 initial_stream_window_size: None,
311 initial_connection_window_size: None,
312 http2_keep_alive_interval: Some(Duration::from_secs(30)),
313 http2_keep_alive_timeout: None,
314 http2_keep_alive_while_idle: Some(true),
315 http2_adaptive_window: None,
316 tcp_keepalive: None,
317 tcp_nodelay: true,
318 client_tls: None,
319 max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
320 max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
321 send_compression: false,
322 accept_compression: false,
323 }
324 }
325}
326
327impl ChannelConfig {
328 pub fn new() -> Self {
329 Default::default()
330 }
331
332 pub fn timeout(mut self, timeout: Duration) -> Self {
334 self.timeout = Some(timeout);
335 self
336 }
337
338 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
342 self.connect_timeout = Some(timeout);
343 self
344 }
345
346 pub fn concurrency_limit(mut self, limit: usize) -> Self {
348 self.concurrency_limit = Some(limit);
349 self
350 }
351
352 pub fn rate_limit(mut self, limit: u64, duration: Duration) -> Self {
354 self.rate_limit = Some((limit, duration));
355 self
356 }
357
358 pub fn initial_stream_window_size(mut self, size: u32) -> Self {
361 self.initial_stream_window_size = Some(size);
362 self
363 }
364
365 pub fn initial_connection_window_size(mut self, size: u32) -> Self {
369 self.initial_connection_window_size = Some(size);
370 self
371 }
372
373 pub fn http2_keep_alive_interval(mut self, duration: Duration) -> Self {
375 self.http2_keep_alive_interval = Some(duration);
376 self
377 }
378
379 pub fn http2_keep_alive_timeout(mut self, duration: Duration) -> Self {
381 self.http2_keep_alive_timeout = Some(duration);
382 self
383 }
384
385 pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> Self {
387 self.http2_keep_alive_while_idle = Some(enabled);
388 self
389 }
390
391 pub fn http2_adaptive_window(mut self, enabled: bool) -> Self {
393 self.http2_adaptive_window = Some(enabled);
394 self
395 }
396
397 pub fn tcp_keepalive(mut self, duration: Duration) -> Self {
404 self.tcp_keepalive = Some(duration);
405 self
406 }
407
408 pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
412 self.tcp_nodelay = enabled;
413 self
414 }
415
416 pub fn client_tls_config(mut self, client_tls_option: ClientTlsOption) -> Self {
420 self.client_tls = Some(client_tls_option);
421 self
422 }
423}
424
/// A pooled tonic channel plus the bookkeeping used by the recycler.
#[derive(Debug)]
pub struct Channel {
    /// The underlying tonic channel; callers receive clones of it.
    channel: InnerChannel,
    /// Accesses since the last recycle pass. A pass removes the channel when this is 0.
    access: AtomicUsize,
    /// `true` for channels created by `ChannelManager::get`, `false` for channels
    /// installed by `ChannelManager::reset_with_connector`.
    use_default_connector: bool,
}
431
432impl Channel {
433 #[inline]
434 pub fn access(&self) -> usize {
435 self.access.load(Ordering::Relaxed)
436 }
437
438 #[inline]
439 pub fn use_default_connector(&self) -> bool {
440 self.use_default_connector
441 }
442
443 #[inline]
444 pub fn increase_access(&self) {
445 let _ = self.access.fetch_add(1, Ordering::Relaxed);
446 }
447}
448
/// Address-keyed cache of [`Channel`]s, stored in a concurrent `DashMap`.
#[derive(Debug, Default)]
struct Pool {
    channels: DashMap<String, Channel>,
}
453
454impl Pool {
455 fn get(&self, addr: &str) -> Option<InnerChannel> {
456 let channel = self.channels.get(addr);
457 channel.map(|ch| {
458 ch.increase_access();
459 ch.channel.clone()
460 })
461 }
462
463 fn entry(&self, addr: String) -> Entry<String, Channel> {
464 self.channels.entry(addr)
465 }
466
467 #[cfg(test)]
468 fn get_access(&self, addr: &str) -> Option<usize> {
469 let channel = self.channels.get(addr);
470 channel.map(|ch| ch.access())
471 }
472
473 fn put(&self, addr: &str, channel: Channel) {
474 let _ = self.channels.insert(addr.to_string(), channel);
475 }
476
477 fn retain_channel<F>(&self, f: F)
478 where
479 F: FnMut(&String, &mut Channel) -> bool,
480 {
481 self.channels.retain(f);
482 }
483}
484
485async fn recycle_channel_in_loop(
486 pool: Arc<Pool>,
487 id: u64,
488 cancel: CancellationToken,
489 interval_secs: u64,
490) {
491 let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
492
493 loop {
494 tokio::select! {
495 _ = cancel.cancelled() => {
496 info!("Stop channel recycle, ChannelManager id: {}", id);
497 break;
498 },
499 _ = interval.tick() => {}
500 }
501
502 pool.retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0)
503 }
504}
505
#[cfg(test)]
mod tests {
    use tower::service_fn;

    use super::*;

    // `get` prepends a scheme itself, so an address that already carries one is
    // expected to fail to produce a usable channel; `unwrap` turns that into the
    // expected panic. NOTE(review): confirm which step fails (URI parsing or later).
    #[should_panic]
    #[test]
    fn test_invalid_addr() {
        let mgr = ChannelManager::default();
        let addr = "http://test";

        let _ = mgr.get(addr).unwrap();
    }

    // 10 tasks x 100 `get` calls on one address must record exactly 1000 accesses.
    // This holds both for the task that inserts the channel and for tasks that
    // race with it. One recycle pass must then reset the counter to 0 while
    // keeping the channel.
    #[tokio::test]
    async fn test_access_count() {
        let mgr = ChannelManager::new();
        // Pretend the recycler is already running so no background pass interferes.
        mgr.inner
            .channel_recycle_started
            .store(true, Ordering::Relaxed);
        let mgr = Arc::new(mgr);
        let addr = "test_uri";

        let mut joins = Vec::with_capacity(10);
        for _ in 0..10 {
            let mgr_clone = mgr.clone();
            let join = tokio::spawn(async move {
                for _ in 0..100 {
                    let _ = mgr_clone.get(addr);
                }
            });
            joins.push(join);
        }
        for join in joins {
            join.await.unwrap();
        }

        assert_eq!(1000, mgr.pool().get_access(addr).unwrap());

        mgr.pool()
            .retain_channel(|_, c| c.access.swap(0, Ordering::Relaxed) != 0);

        assert_eq!(0, mgr.pool().get_access(addr).unwrap());
    }

    // Pins the default values and checks that every builder method sets its field.
    #[test]
    fn test_config() {
        let default_cfg = ChannelConfig::new();
        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(DEFAULT_GRPC_REQUEST_TIMEOUT_SECS)),
                connect_timeout: Some(Duration::from_secs(DEFAULT_GRPC_CONNECT_TIMEOUT_SECS)),
                concurrency_limit: None,
                rate_limit: None,
                initial_stream_window_size: None,
                initial_connection_window_size: None,
                http2_keep_alive_interval: Some(Duration::from_secs(30)),
                http2_keep_alive_timeout: None,
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: None,
                tcp_keepalive: None,
                tcp_nodelay: true,
                client_tls: None,
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            default_cfg
        );

        let cfg = default_cfg
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(false)
            .client_tls_config(ClientTlsOption {
                server_ca_cert_path: "some_server_path".to_string(),
                client_cert_path: "some_cert_path".to_string(),
                client_key_path: "some_key_path".to_string(),
            });

        assert_eq!(
            ChannelConfig {
                timeout: Some(Duration::from_secs(3)),
                connect_timeout: Some(Duration::from_secs(5)),
                concurrency_limit: Some(6),
                rate_limit: Some((5, Duration::from_secs(1))),
                initial_stream_window_size: Some(10),
                initial_connection_window_size: Some(20),
                http2_keep_alive_interval: Some(Duration::from_secs(1)),
                http2_keep_alive_timeout: Some(Duration::from_secs(3)),
                http2_keep_alive_while_idle: Some(true),
                http2_adaptive_window: Some(true),
                tcp_keepalive: Some(Duration::from_secs(2)),
                tcp_nodelay: false,
                client_tls: Some(ClientTlsOption {
                    server_ca_cert_path: "some_server_path".to_string(),
                    client_cert_path: "some_cert_path".to_string(),
                    client_key_path: "some_key_path".to_string(),
                }),
                max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
                max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
                send_compression: false,
                accept_compression: false,
            },
            cfg
        );
    }

    // An endpoint with every optional setting applied must still build successfully.
    #[test]
    fn test_build_endpoint() {
        let config = ChannelConfig::new()
            .timeout(Duration::from_secs(3))
            .connect_timeout(Duration::from_secs(5))
            .concurrency_limit(6)
            .rate_limit(5, Duration::from_secs(1))
            .initial_stream_window_size(10)
            .initial_connection_window_size(20)
            .http2_keep_alive_interval(Duration::from_secs(1))
            .http2_keep_alive_timeout(Duration::from_secs(3))
            .http2_keep_alive_while_idle(true)
            .http2_adaptive_window(true)
            .tcp_keepalive(Duration::from_secs(2))
            .tcp_nodelay(true);
        let mgr = ChannelManager::with_config(config);

        let res = mgr.build_endpoint("test_addr");

        let _ = res.unwrap();
    }

    // `get` stores a default-connector channel; `reset_with_connector` then
    // replaces it under the same address with a custom-connector channel.
    #[tokio::test]
    async fn test_channel_with_connector() {
        let mgr = ChannelManager::new();

        let addr = "test_addr";
        let res = mgr.get(addr);
        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(channel.use_default_connector());
            true
        });

        let (client, _) = tokio::io::duplex(1024);
        let mut client = Some(hyper_util::rt::TokioIo::new(client));
        let res = mgr.reset_with_connector(
            addr,
            service_fn(move |_| {
                let client = client.take().unwrap();
                async move { Ok::<_, std::io::Error>(client) }
            }),
        );

        let _ = res.unwrap();

        mgr.retain_channel(|addr, channel| {
            assert_eq!("test_addr", addr);
            assert!(!channel.use_default_connector());
            true
        });
    }

    // Once `get` has spawned the recycler, the pool has three owners: `Inner`,
    // `pool_holder` and the recycle task. Dropping the last manager cancels the
    // task, which should release its reference shortly afterwards.
    // NOTE(review): relies on a 10 ms sleep and may be timing-sensitive.
    #[tokio::test]
    async fn test_pool_release_with_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        let addr = "test_addr";
        let _ = mgr.get(addr);

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(3, Arc::strong_count(mgr.pool()));

        drop(mgr);

        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }

    // Without a `get` call no recycler exists, so the pool has only two owners
    // (`Inner` and `pool_holder`). Manager clones share `Inner` and add none.
    #[tokio::test]
    async fn test_pool_release_without_channel_recycle() {
        let mgr = ChannelManager::new();

        let pool_holder = mgr.pool().clone();

        let mgr_clone_1 = mgr.clone();
        let mgr_clone_2 = mgr.clone();
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr_clone_1);
        drop(mgr_clone_2);
        assert_eq!(2, Arc::strong_count(mgr.pool()));

        drop(mgr);

        assert_eq!(1, Arc::strong_count(&pool_holder));
    }
}